export_hf_checkpoint.py 1.4 KB

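# Merge the tloen/alpaca-lora-7b LoRA adapter into the base LLaMA model and
# export the result as a standalone Hugging Face checkpoint in ./hf_ckpt.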
import os

import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402

BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert (
    BASE_MODEL
), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`"  # noqa: E501

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
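
# Load the base model in fp16 on CPU; 8-bit loading is disabled so the
# weights can be merged and saved directly.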
base_model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)
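
# Snapshot one attention weight so we can verify later that merging the
# adapter actually changed the base weights.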
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()
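
# Wrap the base model with the LoRA adapter weights from the Hub.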
lora_model = PeftModel.from_pretrained(
    base_model,
    "tloen/alpaca-lora-7b",
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)

lora_weight = lora_model.base_model.model.model.layers[
    0
].self_attn.q_proj.weight
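
# Loading the adapter alone must not have modified the base weights yet.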
assert torch.allclose(first_weight_old, first_weight)

# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()

lora_model.train(False)

# did we do anything?
assert not torch.allclose(first_weight_old, first_weight)
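
# Rename keys to drop the PEFT "base_model.model." prefix and discard any
# remaining LoRA tensors so the state dict matches a plain LlamaForCausalLM.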
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {
    k.replace("base_model.model.", ""): v
    for k, v in lora_model_sd.items()
    if "lora" not in k
}

LlamaForCausalLM.save_pretrained(
    base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
)
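
# Example usage (the model id is just the example from the assert message
# above; any compatible LLaMA base model id should work):
#   BASE_MODEL=huggyllama/llama-7b python export_hf_checkpoint.py
# The merged checkpoint is written to ./hf_ckpt and can be reloaded with
# LlamaForCausalLM.from_pretrained("./hf_ckpt").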