# export_hf_checkpoint.py

import os

import torch
import transformers
from peft import PeftModel

# Unused imports
# import json
# from peft import LoraConfig
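
# Guard: LlamaTokenizer only exists in transformers builds that ship LLaMA
# support; abort with reinstall instructions otherwise.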
assert (
    "LlamaTokenizer" in transformers._import_structure["models.llama"]
), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402

BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert (
    BASE_MODEL
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
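
# Note: the tokenizer is loaded here but never saved below; export it
# separately if you need a self-contained checkpoint directory.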
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
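
# Load the base weights in fp16 on CPU. 8-bit loading stays off because the
# merge below rewrites weight tensors in place, which needs plain
# (non-quantized) fp16 parameters.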
base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)
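
# Snapshot one base weight so we can verify later that merging changed it.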
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()
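
# Wrap the base model with the published Alpaca-LoRA adapter weights.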
lora_model = PeftModel.from_pretrained(
    base_model,
    "tloen/alpaca-lora-7b",
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)
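
# Attaching the adapter must leave the base weights untouched: LoRA keeps its
# low-rank delta in separate matrices until it is explicitly merged.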
lora_weight = lora_model.base_model.model.model.layers[
    0
].self_attn.q_proj.weight
assert torch.allclose(first_weight_old, first_weight)

# merge weights
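# (With the legacy peft API assumed here, setting merge_weights and then
# switching to eval mode folds each LoRA delta into q_proj/v_proj in place;
# newer peft releases replace this pattern with lora_model.merge_and_unload().)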
for layer in lora_model.base_model.model.model.layers:
    layer.self_attn.q_proj.merge_weights = True
    layer.self_attn.v_proj.merge_weights = True

lora_model.train(False)

# did we do anything? (the merge should have modified the base weights)
assert not torch.allclose(first_weight_old, first_weight)

lora_model_sd = lora_model.state_dict()
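
# Build a vanilla LLaMA state dict: strip peft's "base_model.model." prefix
# and drop the lora_A/lora_B adapter tensors, whose effect is now merged in.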
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}

base_model.save_pretrained(
    "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
)
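
# Usage (assumed invocation; BASE_MODEL must name a LLaMA checkpoint):
#   BASE_MODEL=decapoda-research/llama-7b-hf python export_hf_checkpoint.py
#
# Sanity-check sketch: the exported directory should reload as a plain
# LlamaForCausalLM with no peft dependency:
#   reloaded = LlamaForCausalLM.from_pretrained("./hf_ckpt")
# To ship the tokenizer alongside the weights, also run:
#   tokenizer.save_pretrained("./hf_ckpt")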