finetune.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import os
  2. import sys
  3. from typing import List
  4. import fire
  5. import torch
  6. import transformers
  7. from datasets import load_dataset
  8. """
  9. Unused imports:
  10. import torch.nn as nn
  11. import bitsandbytes as bnb
  12. """
  13. from peft import (
  14. LoraConfig,
  15. get_peft_model,
  16. get_peft_model_state_dict,
  17. prepare_model_for_int8_training,
  18. set_peft_model_state_dict,
  19. )
  20. from transformers import LlamaForCausalLM, LlamaTokenizer
  21. from utils.prompter import Prompter
  22. def train(
  23. # model/data params
  24. base_model: str = "", # the only required argument
  25. data_path: str = "yahma/alpaca-cleaned",
  26. output_dir: str = "./lora-alpaca",
  27. # training hyperparams
  28. batch_size: int = 128,
  29. micro_batch_size: int = 4,
  30. num_epochs: int = 3,
  31. learning_rate: float = 3e-4,
  32. cutoff_len: int = 256,
  33. val_set_size: int = 2000,
  34. # lora hyperparams
  35. lora_r: int = 8,
  36. lora_alpha: int = 16,
  37. lora_dropout: float = 0.05,
  38. lora_target_modules: List[str] = [
  39. "q_proj",
  40. "v_proj",
  41. ],
  42. # llm hyperparams
  43. train_on_inputs: bool = True, # if False, masks out inputs in loss
  44. add_eos_token: bool = False,
  45. group_by_length: bool = False, # faster, but produces an odd training loss curve
  46. # wandb params
  47. wandb_project: str = "",
  48. wandb_run_name: str = "",
  49. wandb_watch: str = "", # options: false | gradients | all
  50. wandb_log_model: str = "", # options: false | true
  51. resume_from_checkpoint: str = None, # either training checkpoint or final adapter
  52. prompt_template_name: str = "alpaca", # The prompt template to use, will default to alpaca.
  53. ):
  54. if int(os.environ.get("LOCAL_RANK", 0)) == 0:
  55. print(
  56. f"Training Alpaca-LoRA model with params:\n"
  57. f"base_model: {base_model}\n"
  58. f"data_path: {data_path}\n"
  59. f"output_dir: {output_dir}\n"
  60. f"batch_size: {batch_size}\n"
  61. f"micro_batch_size: {micro_batch_size}\n"
  62. f"num_epochs: {num_epochs}\n"
  63. f"learning_rate: {learning_rate}\n"
  64. f"cutoff_len: {cutoff_len}\n"
  65. f"val_set_size: {val_set_size}\n"
  66. f"lora_r: {lora_r}\n"
  67. f"lora_alpha: {lora_alpha}\n"
  68. f"lora_dropout: {lora_dropout}\n"
  69. f"lora_target_modules: {lora_target_modules}\n"
  70. f"train_on_inputs: {train_on_inputs}\n"
  71. f"add_eos_token: {add_eos_token}\n"
  72. f"group_by_length: {group_by_length}\n"
  73. f"wandb_project: {wandb_project}\n"
  74. f"wandb_run_name: {wandb_run_name}\n"
  75. f"wandb_watch: {wandb_watch}\n"
  76. f"wandb_log_model: {wandb_log_model}\n"
  77. f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
  78. f"prompt template: {prompt_template_name}\n"
  79. )
  80. assert (
  81. base_model
  82. ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
  83. gradient_accumulation_steps = batch_size // micro_batch_size
  84. prompter = Prompter(prompt_template_name)
  85. device_map = "auto"
  86. world_size = int(os.environ.get("WORLD_SIZE", 1))
  87. ddp = world_size != 1
  88. if ddp:
  89. device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
  90. gradient_accumulation_steps = gradient_accumulation_steps // world_size
  91. # Check if parameter passed or if set within environ
  92. use_wandb = len(wandb_project) > 0 or (
  93. "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
  94. )
  95. # Only overwrite environ if wandb param passed
  96. if len(wandb_project) > 0:
  97. os.environ["WANDB_PROJECT"] = wandb_project
  98. if len(wandb_watch) > 0:
  99. os.environ["WANDB_WATCH"] = wandb_watch
  100. if len(wandb_log_model) > 0:
  101. os.environ["WANDB_LOG_MODEL"] = wandb_log_model
  102. model = LlamaForCausalLM.from_pretrained(
  103. base_model,
  104. load_in_8bit=True,
  105. torch_dtype=torch.float16,
  106. device_map=device_map,
  107. )
  108. tokenizer = LlamaTokenizer.from_pretrained(base_model)
  109. tokenizer.pad_token_id = (
  110. 0 # unk. we want this to be different from the eos token
  111. )
  112. tokenizer.padding_side = "left" # Allow batched inference
  113. def tokenize(prompt, add_eos_token=True):
  114. # there's probably a way to do this with the tokenizer settings
  115. # but again, gotta move fast
  116. result = tokenizer(
  117. prompt,
  118. truncation=True,
  119. max_length=cutoff_len,
  120. padding=False,
  121. return_tensors=None,
  122. )
  123. if (
  124. result["input_ids"][-1] != tokenizer.eos_token_id
  125. and len(result["input_ids"]) < cutoff_len
  126. and add_eos_token
  127. ):
  128. result["input_ids"].append(tokenizer.eos_token_id)
  129. result["attention_mask"].append(1)
  130. result["labels"] = result["input_ids"].copy()
  131. return result
  132. def generate_and_tokenize_prompt(data_point):
  133. full_prompt = prompter.generate_prompt(
  134. data_point["instruction"],
  135. data_point["input"],
  136. data_point["output"],
  137. )
  138. tokenized_full_prompt = tokenize(full_prompt)
  139. if not train_on_inputs:
  140. user_prompt = prompter.generate_prompt(
  141. data_point["instruction"], data_point["input"]
  142. )
  143. tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos_token)
  144. user_prompt_len = len(tokenized_user_prompt["input_ids"])
  145. if add_eos_token:
  146. user_prompt_len -= 1
  147. tokenized_full_prompt["labels"] = [
  148. -100
  149. ] * user_prompt_len + tokenized_full_prompt["labels"][
  150. user_prompt_len:
  151. ] # could be sped up, probably
  152. return tokenized_full_prompt
  153. model = prepare_model_for_int8_training(model)
  154. config = LoraConfig(
  155. r=lora_r,
  156. lora_alpha=lora_alpha,
  157. target_modules=lora_target_modules,
  158. lora_dropout=lora_dropout,
  159. bias="none",
  160. task_type="CAUSAL_LM",
  161. )
  162. model = get_peft_model(model, config)
  163. if data_path.endswith(".json") or data_path.endswith(".jsonl"):
  164. data = load_dataset("json", data_files=data_path)
  165. else:
  166. data = load_dataset(data_path)
  167. if resume_from_checkpoint:
  168. # Check the available weights and load them
  169. checkpoint_name = os.path.join(
  170. resume_from_checkpoint, "pytorch_model.bin"
  171. ) # Full checkpoint
  172. if not os.path.exists(checkpoint_name):
  173. checkpoint_name = os.path.join(
  174. resume_from_checkpoint, "adapter_model.bin"
  175. ) # only LoRA model - LoRA config above has to fit
  176. resume_from_checkpoint = (
  177. False # So the trainer won't try loading its state
  178. )
  179. # The two files above have a different name depending on how they were saved, but are actually the same.
  180. if os.path.exists(checkpoint_name):
  181. print(f"Restarting from {checkpoint_name}")
  182. adapters_weights = torch.load(checkpoint_name)
  183. set_peft_model_state_dict(model, adapters_weights)
  184. else:
  185. print(f"Checkpoint {checkpoint_name} not found")
  186. model.print_trainable_parameters() # Be more transparent about the % of trainable params.
  187. if val_set_size > 0:
  188. train_val = data["train"].train_test_split(
  189. test_size=val_set_size, shuffle=True, seed=42
  190. )
  191. train_data = (
  192. train_val["train"].shuffle().map(generate_and_tokenize_prompt)
  193. )
  194. val_data = (
  195. train_val["test"].shuffle().map(generate_and_tokenize_prompt)
  196. )
  197. else:
  198. train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
  199. val_data = None
  200. if not ddp and torch.cuda.device_count() > 1:
  201. # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
  202. model.is_parallelizable = True
  203. model.model_parallel = True
  204. trainer = transformers.Trainer(
  205. model=model,
  206. train_dataset=train_data,
  207. eval_dataset=val_data,
  208. args=transformers.TrainingArguments(
  209. per_device_train_batch_size=micro_batch_size,
  210. gradient_accumulation_steps=gradient_accumulation_steps,
  211. warmup_steps=100,
  212. num_train_epochs=num_epochs,
  213. learning_rate=learning_rate,
  214. fp16=True,
  215. logging_steps=10,
  216. optim="adamw_torch",
  217. evaluation_strategy="steps" if val_set_size > 0 else "no",
  218. save_strategy="steps",
  219. eval_steps=200 if val_set_size > 0 else None,
  220. save_steps=200,
  221. output_dir=output_dir,
  222. save_total_limit=3,
  223. load_best_model_at_end=True if val_set_size > 0 else False,
  224. ddp_find_unused_parameters=False if ddp else None,
  225. group_by_length=group_by_length,
  226. report_to="wandb" if use_wandb else None,
  227. run_name=wandb_run_name if use_wandb else None,
  228. ),
  229. data_collator=transformers.DataCollatorForSeq2Seq(
  230. tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
  231. ),
  232. )
  233. model.config.use_cache = False
  234. old_state_dict = model.state_dict
  235. model.state_dict = (
  236. lambda self, *_, **__: get_peft_model_state_dict(
  237. self, old_state_dict()
  238. )
  239. ).__get__(model, type(model))
  240. if torch.__version__ >= "2" and sys.platform != "win32":
  241. model = torch.compile(model)
  242. trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  243. model.save_pretrained(output_dir)
  244. print(
  245. "\n If there's a warning about missing keys above, please disregard :)"
  246. )
  247. if __name__ == "__main__":
  248. fire.Fire(train)