finetune.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
import os
import sys
from typing import List, Optional

import fire
import torch
import transformers
from datasets import load_dataset
  8. """
  9. Unused imports:
  10. import torch.nn as nn
  11. import bitsandbytes as bnb
  12. """
  13. # Catch when user should re-install transformers library
  14. assert (
  15. "LlamaTokenizer" in transformers._import_structure["models.llama"]
  16. ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
  17. from peft import ( # noqa: E402
  18. LoraConfig,
  19. get_peft_model,
  20. get_peft_model_state_dict,
  21. prepare_model_for_int8_training,
  22. set_peft_model_state_dict,
  23. )
  24. from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
  25. def train(
  26. # model/data params
  27. base_model: str = "", # the only required argument
  28. data_path: str = "./alpaca_data_cleaned.json",
  29. output_dir: str = "./lora-alpaca",
  30. # training hyperparams
  31. batch_size: int = 128,
  32. micro_batch_size: int = 4,
  33. num_epochs: int = 3,
  34. learning_rate: float = 3e-4,
  35. cutoff_len: int = 256,
  36. val_set_size: int = 2000,
  37. # lora hyperparams
  38. lora_r: int = 8,
  39. lora_alpha: int = 16,
  40. lora_dropout: float = 0.05,
  41. lora_target_modules: List[str] = [
  42. "q_proj",
  43. "v_proj",
  44. ],
  45. # llm hyperparams
  46. train_on_inputs: bool = True, # if False, masks out inputs in loss
  47. group_by_length: bool = False, # faster, but produces an odd training loss curve
  48. # wandb params
  49. wandb_project: str = "",
  50. wandb_run_name: str = "",
  51. wandb_watch: str = "", # options: false | gradients | all
  52. wandb_log_model: str = "", # options: false | true
  53. resume_from_checkpoint: str = None, # either training checkpoint or final adapter
  54. ):
  55. print(
  56. f"Training Alpaca-LoRA model with params:\n"
  57. f"base_model: {base_model}\n"
  58. f"data_path: {data_path}\n"
  59. f"output_dir: {output_dir}\n"
  60. f"batch_size: {batch_size}\n"
  61. f"micro_batch_size: {micro_batch_size}\n"
  62. f"num_epochs: {num_epochs}\n"
  63. f"learning_rate: {learning_rate}\n"
  64. f"cutoff_len: {cutoff_len}\n"
  65. f"val_set_size: {val_set_size}\n"
  66. f"lora_r: {lora_r}\n"
  67. f"lora_alpha: {lora_alpha}\n"
  68. f"lora_dropout: {lora_dropout}\n"
  69. f"lora_target_modules: {lora_target_modules}\n"
  70. f"train_on_inputs: {train_on_inputs}\n"
  71. f"group_by_length: {group_by_length}\n"
  72. f"wandb_project: {wandb_project}\n"
  73. f"wandb_run_name: {wandb_run_name}\n"
  74. f"wandb_watch: {wandb_watch}\n"
  75. f"wandb_log_model: {wandb_log_model}\n"
  76. f"resume_from_checkpoint: {resume_from_checkpoint}\n"
  77. )
  78. assert (
  79. base_model
  80. ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
  81. gradient_accumulation_steps = batch_size // micro_batch_size
  82. device_map = "auto"
  83. world_size = int(os.environ.get("WORLD_SIZE", 1))
  84. ddp = world_size != 1
  85. if ddp:
  86. device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
  87. gradient_accumulation_steps = gradient_accumulation_steps // world_size
  88. # Check if parameter passed or if set within environ
  89. use_wandb = len(wandb_project) > 0 or (
  90. "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
  91. )
  92. # Only overwrite environ if wandb param passed
  93. if len(wandb_project) > 0:
  94. os.environ["WANDB_PROJECT"] = wandb_project
  95. if len(wandb_watch) > 0:
  96. os.environ["WANDB_WATCH"] = wandb_watch
  97. if len(wandb_log_model) > 0:
  98. os.environ["WANDB_LOG_MODEL"] = wandb_log_model
  99. model = LlamaForCausalLM.from_pretrained(
  100. base_model,
  101. load_in_8bit=True,
  102. device_map=device_map,
  103. )
  104. tokenizer = LlamaTokenizer.from_pretrained(base_model)
  105. tokenizer.pad_token_id = (
  106. 0 # unk. we want this to be different from the eos token
  107. )
  108. tokenizer.padding_side = "left" # Allow batched inference
  109. def tokenize(prompt, add_eos_token=True):
  110. # there's probably a way to do this with the tokenizer settings
  111. # but again, gotta move fast
  112. result = tokenizer(
  113. prompt,
  114. truncation=True,
  115. max_length=cutoff_len,
  116. padding=False,
  117. return_tensors=None,
  118. )
  119. if (
  120. result["input_ids"][-1] != tokenizer.eos_token_id
  121. and len(result["input_ids"]) < cutoff_len
  122. and add_eos_token
  123. ):
  124. result["input_ids"].append(tokenizer.eos_token_id)
  125. result["attention_mask"].append(1)
  126. result["labels"] = result["input_ids"].copy()
  127. return result
  128. def generate_and_tokenize_prompt(data_point):
  129. full_prompt = generate_prompt(data_point)
  130. tokenized_full_prompt = tokenize(full_prompt)
  131. if not train_on_inputs:
  132. user_prompt = generate_prompt({**data_point, "output": ""})
  133. tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
  134. user_prompt_len = len(tokenized_user_prompt["input_ids"])
  135. tokenized_full_prompt["labels"] = [
  136. -100
  137. ] * user_prompt_len + tokenized_full_prompt["labels"][
  138. user_prompt_len:
  139. ] # could be sped up, probably
  140. return tokenized_full_prompt
  141. model = prepare_model_for_int8_training(model)
  142. config = LoraConfig(
  143. r=lora_r,
  144. lora_alpha=lora_alpha,
  145. target_modules=lora_target_modules,
  146. lora_dropout=lora_dropout,
  147. bias="none",
  148. task_type="CAUSAL_LM",
  149. )
  150. model = get_peft_model(model, config)
  151. if data_path.endswith(".json"): # todo: support jsonl
  152. data = load_dataset("json", data_files=data_path)
  153. else:
  154. data = load_dataset(data_path)
  155. if resume_from_checkpoint:
  156. # Check the available weights and load them
  157. checkpoint_name = os.path.join(
  158. resume_from_checkpoint, "pytorch_model.bin"
  159. ) # Full checkpoint
  160. if not os.path.exists(checkpoint_name):
  161. checkpoint_name = os.path.join(
  162. resume_from_checkpoint, "adapter_model.bin"
  163. ) # only LoRA model - LoRA config above has to fit
  164. resume_from_checkpoint = (
  165. False # So the trainer won't try loading its state
  166. )
  167. # The two files above have a different name depending on how they were saved, but are actually the same.
  168. if os.path.exists(checkpoint_name):
  169. print(f"Restarting from {checkpoint_name}")
  170. adapters_weights = torch.load(checkpoint_name)
  171. model = set_peft_model_state_dict(model, adapters_weights)
  172. else:
  173. print(f"Checkpoint {checkpoint_name} not found")
  174. model.print_trainable_parameters() # Be more transparent about the % of trainable params.
  175. if val_set_size > 0:
  176. train_val = data["train"].train_test_split(
  177. test_size=val_set_size, shuffle=True, seed=42
  178. )
  179. train_data = (
  180. train_val["train"].shuffle().map(generate_and_tokenize_prompt)
  181. )
  182. val_data = (
  183. train_val["test"].shuffle().map(generate_and_tokenize_prompt)
  184. )
  185. else:
  186. train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
  187. val_data = None
  188. trainer = transformers.Trainer(
  189. model=model,
  190. train_dataset=train_data,
  191. eval_dataset=val_data,
  192. args=transformers.TrainingArguments(
  193. per_device_train_batch_size=micro_batch_size,
  194. gradient_accumulation_steps=gradient_accumulation_steps,
  195. warmup_steps=100,
  196. num_train_epochs=num_epochs,
  197. learning_rate=learning_rate,
  198. fp16=True,
  199. logging_steps=10,
  200. optim="adamw_torch",
  201. evaluation_strategy="steps" if val_set_size > 0 else "no",
  202. save_strategy="steps",
  203. eval_steps=200 if val_set_size > 0 else None,
  204. save_steps=200,
  205. output_dir=output_dir,
  206. save_total_limit=3,
  207. load_best_model_at_end=True if val_set_size > 0 else False,
  208. ddp_find_unused_parameters=False if ddp else None,
  209. group_by_length=group_by_length,
  210. report_to="wandb" if use_wandb else None,
  211. run_name=wandb_run_name if use_wandb else None,
  212. ),
  213. data_collator=transformers.DataCollatorForSeq2Seq(
  214. tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
  215. ),
  216. )
  217. model.config.use_cache = False
  218. old_state_dict = model.state_dict
  219. model.state_dict = (
  220. lambda self, *_, **__: get_peft_model_state_dict(
  221. self, old_state_dict()
  222. )
  223. ).__get__(model, type(model))
  224. if torch.__version__ >= "2" and sys.platform != "win32":
  225. model = torch.compile(model)
  226. trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  227. model.save_pretrained(output_dir)
  228. print(
  229. "\n If there's a warning about missing keys above, please disregard :)"
  230. )
  231. def generate_prompt(data_point):
  232. # sorry about the formatting disaster gotta move fast
  233. if data_point["input"]:
  234. return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
  235. ### Instruction:
  236. {data_point["instruction"]}
  237. ### Input:
  238. {data_point["input"]}
  239. ### Response:
  240. {data_point["output"]}"""
  241. else:
  242. return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
  243. ### Instruction:
  244. {data_point["instruction"]}
  245. ### Response:
  246. {data_point["output"]}"""
  247. if __name__ == "__main__":
  248. fire.Fire(train)