# finetune.py — LoRA fine-tuning of LLaMA-7B on the Alpaca dataset
  1. import os
  2. # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  3. import torch
  4. import torch.nn as nn
  5. import bitsandbytes as bnb
  6. from datasets import load_dataset
  7. import transformers
  8. from transformers import AutoTokenizer, AutoConfig, LLaMAForCausalLM, LLaMATokenizer
  9. from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model
  10. model = LLaMAForCausalLM.from_pretrained(
  11. "decapoda-research/llama-7b-hf",
  12. load_in_8bit=True,
  13. device_map="auto",
  14. )
  15. tokenizer = LLaMATokenizer.from_pretrained(
  16. "decapoda-research/llama-7b-hf", add_eos_token=True
  17. )
  18. model = prepare_model_for_int8_training(model)
  19. config = LoraConfig(
  20. r=4,
  21. lora_alpha=16,
  22. target_modules=["q_proj", "v_proj"],
  23. lora_dropout=0.05,
  24. bias="none",
  25. task_type="CAUSAL_LM",
  26. )
  27. model = get_peft_model(model, config)
  28. tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
  29. data = load_dataset("json", data_files="alpaca_data.json")
  30. def generate_prompt(data_point):
  31. # sorry about the formatting disaster gotta move fast
  32. if data_point["instruction"]:
  33. return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
  34. ### Instruction:
  35. {data_point["instruction"]}
  36. ### Input:
  37. {data_point["input"]}
  38. ### Response:
  39. {data_point["output"]}"""
  40. else:
  41. return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
  42. ### Instruction:
  43. {data_point["instruction"]}
  44. ### Response:
  45. {data_point["output"]}"""
  46. # optimized for RTX 4090. for larger GPUs, increase some of these?
  47. MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
  48. BATCH_SIZE = 128
  49. GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
  50. EPOCHS = 3 # from the result
  51. LEARNING_RATE = 3e-4 # the karpathy constant
  52. CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
  53. data = data.shuffle().map(
  54. lambda data_point: tokenizer(
  55. generate_prompt(data_point),
  56. truncation=True,
  57. max_length=CUTOFF_LEN,
  58. padding="max_length",
  59. )
  60. )
  61. trainer = transformers.Trainer(
  62. model=model,
  63. train_dataset=data["train"],
  64. args=transformers.TrainingArguments(
  65. per_device_train_batch_size=MICRO_BATCH_SIZE,
  66. gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
  67. warmup_steps=100,
  68. num_train_epochs=EPOCHS,
  69. learning_rate=LEARNING_RATE,
  70. fp16=True,
  71. logging_steps=1,
  72. output_dir="lora-alpaca",
  73. save_total_limit=3,
  74. ),
  75. data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
  76. )
  77. model.config.use_cache = False
  78. trainer.train(resume_from_checkpoint=False)
  79. model.save_pretrained("lora-alpaca")