Ver código fonte

Mask out prompt tokens for real

Eric Wang 3 anos atrás
pai
commit
4a712d4d8e
2 arquivos alterados com 29 adições e 9 exclusões
  1. 2 1
      .gitignore
  2. 27 8
      finetune.py

+ 2 - 1
.gitignore

@@ -7,4 +7,5 @@ minimal-llama**
 upload.py
 lora-**
 *ckpt
-wandb
+wandb
+test_data.json

+ 27 - 8
finetune.py

@@ -1,4 +1,5 @@
 import os
+import random
 import sys
 
 import torch
@@ -10,7 +11,7 @@ import transformers
 assert (
     "LlamaTokenizer" in transformers._import_structure["models.llama"]
 ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaForCausalLM, LlamaTokenizer, TrainerCallback
 from peft import (
     prepare_model_for_int8_training,
     LoraConfig,
@@ -23,7 +24,7 @@ from peft import (
 MICRO_BATCH_SIZE = 4  # this could actually be 5 but i like powers of 2
 BATCH_SIZE = 128
 GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
-EPOCHS = 3  # we don't always need 3 tbh
+EPOCHS = 5  # remember, we're loading the best checkpoint with the val set
 LEARNING_RATE = 3e-4  # the Karpathy constant
 CUTOFF_LEN = 256  # 256 accounts for about 96% of the data
 LORA_R = 8
@@ -64,7 +65,7 @@ config = LoraConfig(
     task_type="CAUSAL_LM",
 )
 model = get_peft_model(model, config)
-tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
+tokenizer.pad_token_id = 1  # unk. we want this to be different from the eos token
 data = load_dataset("json", data_files=DATA_PATH)
 
 
@@ -151,8 +152,11 @@ def generate_and_tokenize_prompt(data_point):
     )["input_ids"][:-1]
     return {
         "input_ids": full_tokens,
-        "labels": [-100] * len_user_prompt_tokens
-        + full_tokens[len_user_prompt_tokens:],
+        "labels": [-100] * len_user_prompt_tokens  # mask out the user prompt
+        + [
+            token if token != tokenizer.pad_token_id else -100
+            for token in full_tokens[len_user_prompt_tokens:]
+        ],  # mask out the padding
         "attention_mask": [1] * (len(full_tokens)),
     }
 
@@ -164,13 +168,29 @@ if VAL_SET_SIZE > 0:
     train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
     val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
 else:
-    train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)
+    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
     val_data = None
 
+
+class SampleCallback(TrainerCallback):
+    def on_evaluate(self, args, state, control, **kwargs):
+        model = kwargs["model"]
+        input_ids = tokenizer(
+            generate_prompt(random.choice(train_val["test"])).split("### Response:")[0]
+            + "### Response:",
+            truncation=True,
+            max_length=CUTOFF_LEN + 1,
+            return_tensors="pt",
+        )["input_ids"][:, :-1]
+        s = model.generate(input_ids=input_ids, max_new_tokens=100)
+        print(tokenizer.decode(s[0]))
+
+
 trainer = transformers.Trainer(
     model=model,
     train_dataset=train_data,
     eval_dataset=val_data,
+    # callbacks=[SampleCallback()],
     args=transformers.TrainingArguments(
         per_device_train_batch_size=MICRO_BATCH_SIZE,
         gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
@@ -188,7 +208,6 @@ trainer = transformers.Trainer(
         load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
         ddp_find_unused_parameters=False if ddp else None,
     ),
-    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
 )
 model.config.use_cache = False
 
@@ -197,7 +216,7 @@ model.state_dict = (
     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
 ).__get__(model, type(model))
 
-if torch.__version__ >= "2" and sys.platform != 'win32':
+if torch.__version__ >= "2" and sys.platform != "win32":
     model = torch.compile(model)
 
 trainer.train()