# finetune.py — LoRA fine-tuning of LLaMA-7B on the Alpaca instruction dataset.
  1. import os
  2. import torch
  3. import torch.nn as nn
  4. import bitsandbytes as bnb
  5. from datasets import load_dataset
  6. import transformers
# Fail fast if the installed `transformers` build predates native LLaMA
# support: the LlamaTokenizer/LlamaForCausalLM imports below would otherwise
# raise a much less helpful ImportError.
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
  10. from transformers import LlamaForCausalLM, LlamaTokenizer
  11. from peft import (
  12. prepare_model_for_int8_training,
  13. LoraConfig,
  14. get_peft_model,
  15. get_peft_model_state_dict,
  16. )
  17. # optimized for RTX 4090. for larger GPUs, increase some of these?
  18. MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
  19. BATCH_SIZE = 128
  20. GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
  21. EPOCHS = 3 # we don't always need 3 tbh
  22. LEARNING_RATE = 3e-4 # the Karpathy constant
  23. CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
  24. LORA_R = 8
  25. LORA_ALPHA = 16
  26. LORA_DROPOUT = 0.05
  27. VAL_SET_SIZE = 2000
  28. TARGET_MODULES = [
  29. "q_proj",
  30. "v_proj",
  31. ]
  32. DATA_PATH = "alpaca_data_cleaned.json"
  33. device_map = "auto"
  34. world_size = int(os.environ.get("WORLD_SIZE", 1))
  35. ddp = world_size != 1
  36. if ddp:
  37. device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
  38. GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
  39. model = LlamaForCausalLM.from_pretrained(
  40. "decapoda-research/llama-7b-hf",
  41. load_in_8bit=True,
  42. device_map=device_map,
  43. )
  44. tokenizer = LlamaTokenizer.from_pretrained(
  45. "decapoda-research/llama-7b-hf", add_eos_token=True
  46. )
  47. model = prepare_model_for_int8_training(model)
  48. config = LoraConfig(
  49. r=LORA_R,
  50. lora_alpha=LORA_ALPHA,
  51. target_modules=TARGET_MODULES,
  52. lora_dropout=LORA_DROPOUT,
  53. bias="none",
  54. task_type="CAUSAL_LM",
  55. )
  56. model = get_peft_model(model, config)
  57. tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
  58. data = load_dataset("json", data_files=DATA_PATH)
  59. train_val = data["train"].train_test_split(
  60. test_size=VAL_SET_SIZE, shuffle=True, seed=42
  61. )
  62. train_data = train_val["train"]
  63. val_data = train_val["test"]
  64. def generate_prompt(data_point):
  65. # sorry about the formatting disaster gotta move fast
  66. if data_point["input"]:
  67. return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  68. ### Instruction:
  69. {data_point["instruction"]}
  70. ### Input:
  71. {data_point["input"]}
  72. ### Response:
  73. {data_point["output"]}"""
  74. else:
  75. return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
  76. ### Instruction:
  77. {data_point["instruction"]}
  78. ### Response:
  79. {data_point["output"]}"""
  80. def tokenize(prompt):
  81. # there's probably a way to do this with the tokenizer settings
  82. # but again, gotta move fast
  83. result = tokenizer(
  84. prompt,
  85. truncation=True,
  86. max_length=CUTOFF_LEN + 1,
  87. padding="max_length",
  88. )
  89. return {
  90. "input_ids": result["input_ids"][:-1],
  91. "attention_mask": result["attention_mask"][:-1],
  92. }
  93. def generate_and_tokenize_prompt(data_point):
  94. # This function masks out the labels for the input,
  95. # so that our loss is computed only on the response.
  96. user_prompt = (
  97. (
  98. f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  99. ### Instruction:
  100. {data_point["instruction"]}
  101. ### Input:
  102. {data_point["input"]}
  103. ### Response:
  104. """
  105. )
  106. if data_point["input"]
  107. else (
  108. f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
  109. ### Instruction:
  110. {data_point["instruction"]}
  111. ### Response:
  112. """
  113. )
  114. )
  115. len_user_prompt_tokens = (
  116. len(
  117. tokenizer(
  118. user_prompt,
  119. truncation=True,
  120. max_length=CUTOFF_LEN + 1,
  121. padding="max_length",
  122. )["input_ids"]
  123. )
  124. - 1
  125. ) # no eos token
  126. full_tokens = tokenizer(
  127. user_prompt + data_point["output"],
  128. truncation=True,
  129. max_length=CUTOFF_LEN + 1,
  130. padding="max_length",
  131. )["input_ids"][:-1]
  132. return {
  133. "input_ids": full_tokens,
  134. "labels": [-100] * len_user_prompt_tokens
  135. + full_tokens[len_user_prompt_tokens:],
  136. "attention_mask": [1] * (len(full_tokens)),
  137. }
  138. train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
  139. val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
  140. trainer = transformers.Trainer(
  141. model=model,
  142. train_dataset=train_data,
  143. eval_dataset=val_data,
  144. args=transformers.TrainingArguments(
  145. per_device_train_batch_size=MICRO_BATCH_SIZE,
  146. gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
  147. warmup_steps=100,
  148. num_train_epochs=EPOCHS,
  149. learning_rate=LEARNING_RATE,
  150. fp16=True,
  151. logging_steps=20,
  152. evaluation_strategy="steps",
  153. save_strategy="steps",
  154. eval_steps=200,
  155. save_steps=200,
  156. output_dir="lora-alpaca",
  157. save_total_limit=3,
  158. load_best_model_at_end=True,
  159. ddp_find_unused_parameters=False if ddp else None,
  160. ),
  161. data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
  162. )
# Disable the decoder cache while training — NOTE(review): presumably to cut
# memory/avoid warnings; re-enable for generation.
model.config.use_cache = False

# Monkey-patch state_dict so that any checkpoint the Trainer writes contains
# only the LoRA adapter weights (via get_peft_model_state_dict) instead of
# the full 7B base model. `.__get__` rebinds the lambda as a bound method.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))

# torch.compile exists from PyTorch 2.x; string comparison on the version —
# works for "1.x"/"2.x" prefixes, not a full semver check.
if torch.__version__ >= "2":
    model = torch.compile(model)

trainer.train()
model.save_pretrained("lora-alpaca")
print("\n If there's a warning about missing keys above, please disregard :)")