export_hf_checkpoint.py

import torch
import transformers
from peft import PeftModel

# Older transformers releases shipped without the LLaMA classes; fail fast
# with install instructions if they are missing.
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"

from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

# Load the fp16 base weights on CPU; 8-bit loading is disabled because the
# merged weights must be exported in a standard dtype.
base_model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)

# Snapshot one base weight so we can later verify that the merge actually
# modified the model in place.
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()

lora_model = PeftModel.from_pretrained(
    base_model,
    "tloen/alpaca-lora-7b",
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)

# Kept for inspection; unused below.
lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight

# Loading the adapter must not have touched the base weights yet.
assert torch.allclose(first_weight_old, first_weight)
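
# Background on what the merge below computes: LoRA folds its low-rank update
# into the base weight, roughly W_merged = W + (lora_alpha / r) * B @ A.
# A minimal sketch with made-up shapes (r = 8 and lora_alpha = 16 are
# assumptions, not read from the adapter config):
#
#   W = torch.randn(4096, 4096, dtype=torch.float16)
#   A = torch.randn(8, 4096, dtype=torch.float16)   # lora_A
#   B = torch.randn(4096, 8, dtype=torch.float16)   # lora_B
#   W_merged = W + (16 / 8) * (B @ A)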

# Merge the LoRA deltas into the base weights. With the peft version this
# script targets, flagging merge_weights and then switching to eval mode
# triggers the merge inside the adapted q_proj/v_proj modules.
for layer in lora_model.base_model.model.model.layers:
    layer.self_attn.q_proj.merge_weights = True
    layer.self_attn.v_proj.merge_weights = True

lora_model.train(False)

# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)
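
# Note: newer peft releases expose an explicit merge API; depending on the
# installed version, the loop above could instead be (untested sketch):
#
#   merged_model = lora_model.merge_and_unload()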

lora_model_sd = lora_model.state_dict()

# Strip the peft wrapper prefix and drop the (now merged) LoRA tensors so the
# state dict matches a vanilla LlamaForCausalLM.
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}
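
# Example of the key rewrite above (illustrative):
#   "base_model.model.model.layers.0.self_attn.q_proj.weight"
#       -> "model.layers.0.self_attn.q_proj.weight"
# Keys containing "lora" (e.g. "...q_proj.lora_A.weight") are dropped, since
# their contribution now lives in the merged weights.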

# Export the merged model as sharded checkpoint files under ./hf_ckpt.
# (Calling the class method with base_model as the first argument is
# equivalent to base_model.save_pretrained(...).)
LlamaForCausalLM.save_pretrained(
    base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
)
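
# Usage sketch (not part of the original script): the exported directory can
# be reloaded like any other Hugging Face checkpoint, e.g.
#
#   from transformers import LlamaForCausalLM
#   model = LlamaForCausalLM.from_pretrained("./hf_ckpt", torch_dtype=torch.float16)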