finetune.py
import os
import sys
from typing import List

import fire
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"

from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
)


def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "./alpaca_data_cleaned.json",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
):
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"group_by_length: {group_by_length}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
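    # For reference: the effective global batch size works out to
    # micro_batch_size * gradient_accumulation_steps * world_size,
    # e.g. with the defaults on a single GPU, 4 * (128 // 4) * 1 = 128.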
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result
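    # A rough sketch of what tokenize returns (the token ids below are
    # illustrative, not real LLaMA vocabulary ids):
    #   tokenize("hi")
    #   -> {"input_ids": [1, 7251, 2],        # bos, "hi", appended eos
    #       "attention_mask": [1, 1, 1],
    #       "labels": [1, 7251, 2]}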
    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt
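    # -100 is the default ignore_index of PyTorch's CrossEntropyLoss, so the
    # masked prompt tokens contribute nothing to the loss. Toy illustration
    # (values are made up): a 6-token sequence whose first 3 tokens are prompt
    #   labels = [-100, -100, -100, 664, 379, 2]  # loss on the last 3 only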
    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
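    # The LoRA update is scaled by lora_alpha / lora_r (2.0 with the defaults
    # above); peft's model.print_trainable_parameters() can confirm that only
    # the small adapter matrices on q_proj/v_proj are trainable.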
    data = load_dataset("json", data_files=data_path)

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
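    # DataCollatorForSeq2Seq pads the "labels" field with -100 (its default
    # label_pad_token_id) while padding input_ids with the pad token, so
    # padding positions also stay out of the loss.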
    model.config.use_cache = False

    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))
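    # The monkeypatch above makes save_pretrained serialize only the LoRA
    # adapter weights rather than the full base model, which is why the saved
    # checkpoint is on the order of megabytes instead of gigabytes.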
    # crude string check for PyTorch >= 2.0; torch.compile is not supported
    # on Windows
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train()

    model.save_pretrained(output_dir)

    print("\n If there's a warning about missing keys above, please disregard :)")

def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""
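# For example, a (hypothetical) data point like
#   {"instruction": "Add the numbers.", "input": "2 and 3", "output": "5"}
# takes the first branch and renders as the instruction/input template with
# "Add the numbers.", "2 and 3", and "5" filled into the three sections.
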
if __name__ == "__main__":
    fire.Fire(train)
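# A typical invocation might look like the following (the model name and paths
# are just examples; the base model is the one suggested by the assert above):
#   python finetune.py \
#       --base_model='decapoda-research/llama-7b-hf' \
#       --data_path='./alpaca_data_cleaned.json' \
#       --output_dir='./lora-alpaca'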