Explorar o código

Fix label masking length when setting add_eos_token=True and train_on_inputs=False (#306)

Co-authored-by: muximus3 <[email protected]>
Toshiro Mifune hai 3 anos
pai
achega
179f3974f8
Modificouse 1 ficheiro con 6 adicións e 1 borrado
  1. 6 1
      finetune.py

+ 6 - 1
finetune.py

@@ -47,6 +47,7 @@ def train(
     ],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
+    add_eos_token: bool = False,
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
     # wandb params
     wandb_project: str = "",
@@ -73,6 +74,7 @@ def train(
             f"lora_dropout: {lora_dropout}\n"
             f"lora_target_modules: {lora_target_modules}\n"
             f"train_on_inputs: {train_on_inputs}\n"
+            f"add_eos_token: {add_eos_token}\n"
             f"group_by_length: {group_by_length}\n"
             f"wandb_project: {wandb_project}\n"
             f"wandb_run_name: {wandb_run_name}\n"
@@ -154,9 +156,12 @@ def train(
             user_prompt = prompter.generate_prompt(
                 data_point["instruction"], data_point["input"]
             )
-            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
+            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos_token)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
 
+            if add_eos_token:
+                user_prompt_len -= 1
+
             tokenized_full_prompt["labels"] = [
                 -100
             ] * user_prompt_len + tokenized_full_prompt["labels"][