# finetune.py — LoRA fine-tuning of a LLaMA base model on the Alpaca dataset.
  1. import os
  2. import sys
  3. import torch
  4. import torch.nn as nn
  5. import bitsandbytes as bnb
  6. from datasets import load_dataset
  7. import transformers
  8. assert (
  9. "LlamaTokenizer" in transformers._import_structure["models.llama"]
  10. ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
  11. from transformers import LlamaForCausalLM, LlamaTokenizer
  12. from peft import (
  13. prepare_model_for_int8_training,
  14. LoraConfig,
  15. get_peft_model,
  16. get_peft_model_state_dict,
  17. )
  18. # optimized for RTX 4090. for larger GPUs, increase some of these?
  19. MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
  20. BATCH_SIZE = 128
  21. GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
  22. EPOCHS = 3
  23. LEARNING_RATE = 3e-4
  24. CUTOFF_LEN = 512
  25. LORA_R = 8
  26. LORA_ALPHA = 16
  27. LORA_DROPOUT = 0.05
  28. VAL_SET_SIZE = 2000
  29. TARGET_MODULES = [
  30. "q_proj",
  31. "v_proj",
  32. ]
  33. DATA_PATH = "alpaca_data_cleaned.json"
  34. OUTPUT_DIR = "lora-alpaca"
  35. BASE_MODEL = None
  36. assert (
  37. BASE_MODEL
  38. ), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
  39. TRAIN_ON_INPUTS = True
  40. GROUP_BY_LENGTH = True # faster, but produces an odd training loss curve
  41. device_map = "auto"
  42. world_size = int(os.environ.get("WORLD_SIZE", 1))
  43. ddp = world_size != 1
  44. if ddp:
  45. device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
  46. GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
# Load the base LLaMA weights quantized to 8-bit, placed per `device_map`.
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,
    device_map=device_map,
)
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
tokenizer.padding_side = "left"  # Allow batched inference

# peft helper that readies the int8 model for training (NOTE(review):
# presumably freezes base weights / casts norm layers — see peft docs).
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
# Wrap the base model so only the LoRA adapter weights are trainable.
model = get_peft_model(model, config)

# Alpaca-style records: dicts with "instruction", "input", "output" keys
# (as consumed by generate_prompt below).
data = load_dataset("json", data_files=DATA_PATH)
  66. def generate_prompt(data_point):
  67. # sorry about the formatting disaster gotta move fast
  68. if data_point["input"]:
  69. return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  70. ### Instruction:
  71. {data_point["instruction"]}
  72. ### Input:
  73. {data_point["input"]}
  74. ### Response:
  75. {data_point["output"]}"""
  76. else:
  77. return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
  78. ### Instruction:
  79. {data_point["instruction"]}
  80. ### Response:
  81. {data_point["output"]}"""
  82. def tokenize(prompt, add_eos_token=True):
  83. # there's probably a way to do this with the tokenizer settings
  84. # but again, gotta move fast
  85. result = tokenizer(
  86. prompt,
  87. truncation=True,
  88. max_length=CUTOFF_LEN,
  89. padding=False,
  90. return_tensors=None,
  91. )
  92. if (
  93. result["input_ids"][-1] != tokenizer.eos_token_id
  94. and len(result["input_ids"]) < CUTOFF_LEN
  95. and add_eos_token
  96. ):
  97. result["input_ids"].append(tokenizer.eos_token_id)
  98. result["attention_mask"].append(1)
  99. result["labels"] = result["input_ids"].copy()
  100. return result
  101. def generate_and_tokenize_prompt(data_point):
  102. full_prompt = generate_prompt(data_point)
  103. tokenized_full_prompt = tokenize(full_prompt)
  104. if not TRAIN_ON_INPUTS:
  105. user_prompt = generate_prompt({**data_point, "output": ""})
  106. tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
  107. user_prompt_len = len(tokenized_user_prompt["input_ids"])
  108. tokenized_full_prompt["labels"] = [
  109. -100
  110. ] * user_prompt_len + tokenized_full_prompt["labels"][
  111. user_prompt_len:
  112. ] # could be sped up, probably
  113. return tokenized_full_prompt
# Hold out VAL_SET_SIZE examples for evaluation; the split itself is
# deterministic (seed=42) so train/val membership is reproducible across runs.
if VAL_SET_SIZE > 0:
    train_val = data["train"].train_test_split(
        test_size=VAL_SET_SIZE, shuffle=True, seed=42
    )
    train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
else:
    # No validation set: train on everything, disable eval in the Trainer below.
    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = None
# Standard HF Trainer; all eval-related options are switched off together
# when VAL_SET_SIZE == 0 (val_data is None in that case).
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=10,
        # NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in
        # newer transformers releases — keep in mind when upgrading.
        evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
        save_strategy="steps",
        eval_steps=200 if VAL_SET_SIZE > 0 else None,
        save_steps=200,
        output_dir=OUTPUT_DIR,
        save_total_limit=3,
        load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
        # DDP: adapter-wrapped model has frozen params; skip the unused-param scan.
        ddp_find_unused_parameters=False if ddp else None,
        group_by_length=GROUP_BY_LENGTH,  # faster, but odd-looking loss curve
    ),
    # Seq2Seq collator pads inputs AND labels dynamically per batch.
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)
# KV-cache is only useful for generation; NOTE(review): presumably disabled
# here to avoid warnings/overhead during training — confirm against HF docs.
model.config.use_cache = False

# Monkey-patch state_dict so checkpoints contain ONLY the LoRA adapter
# weights (via get_peft_model_state_dict), not the full base model.
# `__get__` rebinds the lambda as a bound method on this model instance.
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))

# torch.compile is available from PyTorch 2.x and unsupported on Windows.
# NOTE(review): lexicographic version compare is fragile (e.g. "10.0" < "2").
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

trainer.train()
model.save_pretrained(OUTPUT_DIR)
# Missing-key warnings are expected: only adapter weights were saved.
print("\n If there's a warning about missing keys above, please disregard :)")