Pārlūkot izejas kodu

mask prompt in loss

Eric Wang 3 gadi atpakaļ
vecāks
revīzija
cfad895aa1
3 mainītis faili ar 65 papildinājumiem un 6 dzēšanām
  1. 4 1
      .gitignore
  2. 2 0
      README.md
  3. 59 5
      finetune.py

+ 4 - 1
.gitignore

@@ -4,4 +4,7 @@ out/
 __pycache__/
 checkpoint**
 minimal-llama**
-upload.py
+upload.py
+lora-**
+*ckpt
+wandb

+ 2 - 0
README.md

@@ -2,6 +2,8 @@
 
 **Try the pretrained model out on Colab [here](https://colab.research.google.com/drive/1eWAmesrW99p7e1nah5bipn0zikMb8XYC)!**
 
+_**Update 2023-03-19:** weights have been updated with cleaned data and prompts masked out in the loss. This should reduce the number of template artifacts in outputs._
+
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
 and the code can be easily extended to the `13b`, `30b`, and `65b` models.

+ 59 - 5
finetune.py

@@ -1,6 +1,5 @@
 import os
 
-# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import torch
 import torch.nn as nn
 import bitsandbytes as bnb
@@ -37,10 +36,10 @@ TARGET_MODULES = [
 DATA_PATH = "alpaca_data_cleaned.json"
 
 device_map = "auto"
-world_size = int(os.environ.get('WORLD_SIZE', 1))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
 ddp = world_size != 1
 if ddp:
-    device_map = {'':int(os.environ.get('LOCAL_RANK') or 0)}
+    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
     GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
 
 model = LlamaForCausalLM.from_pretrained(
@@ -111,8 +110,60 @@ def tokenize(prompt):
     }
 
 
-train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
-val_data = val_data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
+def generate_and_tokenize_prompt(data_point):
+    # This function masks out the labels for the input,
+    # so that our loss is computed only on the response.
+    user_prompt = (
+        (
+            f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Input:
+{data_point["input"]}
+
+### Response:
+"""
+        )
+        if data_point["input"]
+        else (
+            f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{data_point["instruction"]}
+
+### Response:
+"""
+        )
+    )
+    len_user_prompt_tokens = (
+        len(
+            tokenizer(
+                user_prompt,
+                truncation=True,
+                max_length=CUTOFF_LEN + 1,
+                padding="max_length",
+            )["input_ids"]
+        )
+        - 1
+    )  # no eos token
+    full_tokens = tokenizer(
+        user_prompt + data_point["output"],
+        truncation=True,
+        max_length=CUTOFF_LEN + 1,
+        padding="max_length",
+    )["input_ids"][:-1]
+    return {
+        "input_ids": full_tokens,
+        "labels": [-100] * len_user_prompt_tokens
+        + full_tokens[len_user_prompt_tokens:],
+        "attention_mask": [1] * (len(full_tokens)),
+    }
+
+
+train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
+val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
 
 trainer = transformers.Trainer(
     model=model,
@@ -144,6 +195,9 @@ model.state_dict = (
     lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
 ).__get__(model, type(model))
 
+if torch.__version__ >= "2":
+    model = torch.compile(model)
+
 trainer.train()
 
 model.save_pretrained("lora-alpaca")