# finetune.py — LoRA fine-tuning of LLaMA-7B on the Alpaca dataset.
  1. import os
  2. # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  3. import torch
  4. import torch.nn as nn
  5. import bitsandbytes as bnb
  6. from datasets import load_dataset
  7. import transformers
  8. from transformers import AutoTokenizer, AutoConfig, LLaMAForCausalLM, LLaMATokenizer
  9. from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model
  10. # optimized for RTX 4090. for larger GPUs, increase some of these?
  11. MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
  12. BATCH_SIZE = 128
  13. GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
  14. EPOCHS = 3 # we don't need 3 tbh
  15. LEARNING_RATE = 3e-4 # the Karpathy constant
  16. CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
  17. LORA_R = 8
  18. LORA_ALPHA = 16
  19. LORA_DROPOUT = 0.05
# Load the base LLaMA-7B checkpoint quantized to int8, sharded automatically
# across available devices. NOTE(review): presumably load_in_8bit is what lets
# the 7B model fit on a single 24 GB GPU — confirm against bitsandbytes docs.
model = LLaMAForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    device_map="auto",
)
# add_eos_token=True so every tokenized example ends with EOS.
tokenizer = LLaMATokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf", add_eos_token=True
)
# peft helper that readies the int8 model for training; must run BEFORE
# get_peft_model. NOTE(review): presumably this freezes the base weights and
# upcasts layer norms for stability — verify against the peft documentation.
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    # Adapters are attached only to the attention query/value projections.
    target_modules=["q_proj", "v_proj"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
# Wrap the base model so only the LoRA adapter weights are trainable.
model = get_peft_model(model, config)
tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
# Expects alpaca_data.json in the working directory; yields a "train" split.
data = load_dataset("json", data_files="alpaca_data.json")
  40. def generate_prompt(data_point):
  41. # sorry about the formatting disaster gotta move fast
  42. if data_point["input"]:
  43. return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  44. ### Instruction:
  45. {data_point["instruction"]}
  46. ### Input:
  47. {data_point["input"]}
  48. ### Response:
  49. {data_point["output"]}"""
  50. else:
  51. return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
  52. ### Instruction:
  53. {data_point["instruction"]}
  54. ### Response:
  55. {data_point["output"]}"""
# Shuffle, then tokenize the rendered prompt of every record. Each example is
# truncated/padded to exactly CUTOFF_LEN tokens so batches are uniform.
data = data.shuffle().map(
    lambda data_point: tokenizer(
        generate_prompt(data_point),
        truncation=True,
        max_length=CUTOFF_LEN,
        padding="max_length",
    )
)
trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,  # mixed-precision training
        logging_steps=1,  # log every optimizer step
        output_dir="lora-alpaca",
        save_total_limit=3,  # keep only the 3 newest checkpoints
    ),
    # mlm=False selects causal-LM collation. NOTE(review): presumably the
    # collator derives labels from input_ids — confirm against transformers docs.
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# Disable the generation KV cache during training; re-enable for inference.
model.config.use_cache = False
trainer.train(resume_from_checkpoint=False)
# NOTE(review): on a peft-wrapped model this presumably saves only the LoRA
# adapter weights, not the full base model — verify before deploying.
model.save_pretrained("lora-alpaca")