# finetune.py — LoRA fine-tuning of LLaMA-7B on the cleaned Alpaca dataset.
import os
import random
import sys
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers

# Fail fast if this transformers build predates native LLaMA support: the
# names below (LlamaForCausalLM, LlamaTokenizer) only exist on recent main.
# NOTE: this check must run BEFORE the `from transformers import ...` line so
# the user sees the reinstall hint instead of an ImportError.
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaForCausalLM, LlamaTokenizer, TrainerCallback
from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
)
  19. # optimized for RTX 4090. for larger GPUs, increase some of these?
  20. MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
  21. BATCH_SIZE = 128
  22. GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
  23. EPOCHS = 3 # remember, we're loading the best checkpoint with the val set
  24. LEARNING_RATE = 3e-4 # the Karpathy constant
  25. CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
  26. LORA_R = 8
  27. LORA_ALPHA = 16
  28. LORA_DROPOUT = 0.05
  29. VAL_SET_SIZE = 2000
  30. TARGET_MODULES = [
  31. "q_proj",
  32. "v_proj",
  33. ]
  34. DATA_PATH = "alpaca_data_cleaned.json"
  35. OUTPUT_DIR = "lora-alpaca"
  36. device_map = "auto"
  37. world_size = int(os.environ.get("WORLD_SIZE", 1))
  38. ddp = world_size != 1
  39. if ddp:
  40. device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
  41. GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
# Load the base LLaMA-7B checkpoint quantized to 8-bit; device_map lets HF
# place layers automatically (or pin to LOCAL_RANK's GPU under DDP).
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    device_map=device_map,
)
# add_eos_token=True: every encoded sequence gets a trailing eos token.
tokenizer = LlamaTokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf", add_eos_token=True
)
# peft helper that readies the int8 base model for training (so only the
# LoRA adapters added below receive gradients).
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,  # attention q/v projections only
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
tokenizer.pad_token_id = 1  # unk. we want this to be different from the eos token
# Expects a JSON file of records with "instruction"/"input"/"output" fields
# (see generate_prompt below).
data = load_dataset("json", data_files=DATA_PATH)
  62. def generate_prompt(data_point):
  63. # sorry about the formatting disaster gotta move fast
  64. if data_point["input"]:
  65. return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  66. ### Instruction:
  67. {data_point["instruction"]}
  68. ### Input:
  69. {data_point["input"]}
  70. ### Response:
  71. {data_point["output"]}"""
  72. else:
  73. return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
  74. ### Instruction:
  75. {data_point["instruction"]}
  76. ### Response:
  77. {data_point["output"]}"""
  78. def tokenize(prompt):
  79. # there's probably a way to do this with the tokenizer settings
  80. # but again, gotta move fast
  81. result = tokenizer(
  82. prompt,
  83. truncation=True,
  84. max_length=CUTOFF_LEN + 1,
  85. padding="max_length",
  86. )
  87. return {
  88. "input_ids": result["input_ids"][:-1],
  89. "attention_mask": result["attention_mask"][:-1],
  90. }
  91. def generate_and_tokenize_prompt(data_point):
  92. # This function masks out the labels for the input,
  93. # so that our loss is computed only on the response.
  94. user_prompt = (
  95. (
  96. f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  97. ### Instruction:
  98. {data_point["instruction"]}
  99. ### Input:
  100. {data_point["input"]}
  101. ### Response:
  102. """
  103. )
  104. if data_point["input"]
  105. else (
  106. f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
  107. ### Instruction:
  108. {data_point["instruction"]}
  109. ### Response:
  110. """
  111. )
  112. )
  113. len_user_prompt_tokens = (
  114. len(
  115. tokenizer(
  116. user_prompt,
  117. truncation=True,
  118. max_length=CUTOFF_LEN + 1,
  119. )["input_ids"]
  120. )
  121. - 1
  122. ) # no eos token
  123. full_tokens = tokenizer(
  124. user_prompt + data_point["output"],
  125. truncation=True,
  126. max_length=CUTOFF_LEN + 1,
  127. padding="max_length",
  128. )["input_ids"][:-1]
  129. return {
  130. "input_ids": full_tokens,
  131. "labels": [-100] * len_user_prompt_tokens # mask out the user prompt
  132. + [
  133. token if token != tokenizer.pad_token_id else -100
  134. for token in full_tokens[len_user_prompt_tokens:]
  135. ], # mask out the padding
  136. "attention_mask": [1] * (len(full_tokens)),
  137. }
# Hold out VAL_SET_SIZE examples for evaluation when requested; otherwise
# train on everything and skip evaluation. seed=42 keeps the split
# reproducible across runs.
if VAL_SET_SIZE > 0:
    train_val = data["train"].train_test_split(
        test_size=VAL_SET_SIZE, shuffle=True, seed=42
    )
    train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
else:
    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = None
  147. class SampleCallback(TrainerCallback):
  148. def on_evaluate(self, args, state, control, **kwargs):
  149. model = kwargs["model"]
  150. input_ids = tokenizer(
  151. generate_prompt(random.choice(train_val["test"])).split("### Response:")[0]
  152. + "### Response:",
  153. truncation=True,
  154. max_length=CUTOFF_LEN + 1,
  155. return_tensors="pt",
  156. )["input_ids"][:, :-1]
  157. s = model.generate(input_ids=input_ids, max_new_tokens=100)
  158. print(tokenizer.decode(s[0]))
# Build the HF Trainer. Per-device micro-batches plus accumulation reproduce
# the global BATCH_SIZE; eval/checkpoint every 200 steps when a val set
# exists, keeping only the 3 most recent checkpoints.
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    # callbacks=[SampleCallback()],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=20,
        evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
        save_strategy="steps",
        eval_steps=200 if VAL_SET_SIZE > 0 else None,
        save_steps=200,
        output_dir=OUTPUT_DIR,
        save_total_limit=3,
        load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
        ddp_find_unused_parameters=False if ddp else None,
    ),
)
# Turn off the generation KV cache for training — presumably to avoid
# warnings/overhead during forward passes; TODO confirm.
model.config.use_cache = False
# Monkeypatch state_dict so Trainer checkpoints contain only the LoRA adapter
# weights (via get_peft_model_state_dict) instead of the full base model.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# torch.compile is available from PyTorch 2 and not supported on Windows.
# NOTE(review): lexicographic version compare is fragile (e.g. "10.0" < "2").
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)
trainer.train()
model.save_pretrained(OUTPUT_DIR)
print("\n If there's a warning about missing keys above, please disregard :)")