finetune.py

import os
import sys
from typing import List

import fire
import torch
import torch.nn as nn
import bitsandbytes as bnb
from datasets import load_dataset
import transformers

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import (
    prepare_model_for_int8_training,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "./alpaca_data_cleaned.json",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"group_by_length: {group_by_length}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    data = load_dataset("json", data_files=data_path)
    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = False  # So the trainer won't try loading its state
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            # set_peft_model_state_dict loads the adapter weights in place; don't
            # reassign model to its return value (it is not the model on newer peft)
            set_peft_model_state_dict(model, adapters_weights)

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    # Override state_dict so that saving the model only writes the LoRA adapter weights
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print("\n If there's a warning about missing keys above, please disregard :)")

def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""
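
# A record in the JSON file passed via --data_path is expected to carry the three
# fields that generate_prompt reads; a minimal, made-up example record might look like:
#
#     {
#         "instruction": "Rewrite the sentence in the past tense.",
#         "input": "She walks to school.",
#         "output": "She walked to school."
#     }
#
# "input" may be an empty string, in which case the shorter template (without the
# ### Input section) is used.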

if __name__ == "__main__":
    fire.Fire(train)
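
# Example invocations (untested sketches): fire.Fire(train) exposes the keyword
# arguments of train() as command-line flags, so a single-GPU run could look like:
#
#     python finetune.py --base_model 'decapoda-research/llama-7b-hf' \
#         --data_path './alpaca_data_cleaned.json' --output_dir './lora-alpaca'
#
# The script reads WORLD_SIZE and LOCAL_RANK for data-parallel training, so a
# multi-GPU run can be launched with torchrun, which sets those variables, e.g.:
#
#     torchrun --nproc_per_node=4 finetune.py --base_model 'decapoda-research/llama-7b-hf'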