export_hf_checkpoint.py

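"""Merge the tloen/alpaca-lora-7b LoRA adapter into its LLaMA base model and
export the merged weights as a standalone HuggingFace checkpoint in ./hf_ckpt."""
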
import torch
import transformers
from peft import PeftModel

assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
from transformers import LlamaTokenizer, LlamaForCausalLM  # noqa: E402

BASE_MODEL = None
assert (
    BASE_MODEL
), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"

# note: the tokenizer is loaded here but never exported by this script
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},  # merging happens on CPU; no GPU required
)
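
# Snapshot one attention weight so we can check below that (a) loading the
# adapter leaves the base weights untouched and (b) merging changes them.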
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()

lora_model = PeftModel.from_pretrained(
    base_model,
    "tloen/alpaca-lora-7b",
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)

# the LoRA-wrapped q_proj still shares the original weight tensor
lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight

# loading the adapter must not have modified the base weights yet
assert torch.allclose(first_weight_old, first_weight)
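
# With the (pre-0.3.0) peft API assumed here, setting merge_weights and then
# switching to eval mode makes each LoRA layer fold its low-rank update into
# the base weight matrix; newer peft versions expose merge_and_unload() instead.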
# merge weights
for layer in lora_model.base_model.model.model.layers:
    layer.self_attn.q_proj.merge_weights = True
    layer.self_attn.v_proj.merge_weights = True

lora_model.train(False)

# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)
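
# Keep only the merged base weights: drop the separate lora_A/lora_B tensors
# and strip the "base_model.model." prefix so the keys match a vanilla
# LlamaForCausalLM state dict.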
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}

LlamaForCausalLM.save_pretrained(
    base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
)
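
# A minimal usage sketch (the dtype choice is an assumption): the exported
# directory loads like any stock HuggingFace checkpoint, e.g.
#   model = LlamaForCausalLM.from_pretrained("./hf_ckpt", torch_dtype=torch.float16)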