فهرست منبع

Unwind input masking to avoid confusion

Eric Wang 3 سال پیش
والد
کامیت
b12c3b90f8
1 فایل تغییر یافته به همراه 2 افزوده شده و 48 حذف شده
  1. 2 48
      finetune.py

+ 2 - 48
finetune.py

@@ -107,54 +107,8 @@ def tokenize(prompt):
 
 
 def generate_and_tokenize_prompt(data_point):
-    # This function masks out the labels for the input,
-    # so that our loss is computed only on the response.
-    user_prompt = (
-        (
-            f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-
-### Instruction:
-{data_point["instruction"]}
-
-### Input:
-{data_point["input"]}
-
-### Response:
-"""
-        )
-        if data_point["input"]
-        else (
-            f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
-
-### Instruction:
-{data_point["instruction"]}
-
-### Response:
-"""
-        )
-    )
-    len_user_prompt_tokens = (
-        len(
-            tokenizer(
-                user_prompt,
-                truncation=True,
-                max_length=CUTOFF_LEN + 1,
-            )["input_ids"]
-        )
-        - 1
-    )  # no eos token
-    full_tokens = tokenizer(
-        user_prompt + data_point["output"],
-        truncation=True,
-        max_length=CUTOFF_LEN + 1,
-        padding="max_length",
-    )["input_ids"][:-1]
-    return {
-        "input_ids": full_tokens,
-        "labels": [-100] * len_user_prompt_tokens
-        + full_tokens[len_user_prompt_tokens:],
-        "attention_mask": [1] * (len(full_tokens)),
-    }
+    prompt = generate_prompt(data_point)
+    return tokenize(prompt)
 
 
 if VAL_SET_SIZE > 0: