@@ -0,0 +1,71 @@
+import torch
+from peft import PeftModel
+
+import transformers
+
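+# guard: the LLaMA classes only ship in recent transformers builds, so fail
+# fast with install instructions when this environment predates them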
+assert (
+    "LlamaTokenizer" in transformers._import_structure["models.llama"]
+), "LLaMA is only available in HuggingFace's main branch for now.\nPlease reinstall transformers from source: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
+from transformers import LlamaTokenizer, LlamaForCausalLM
+
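+# the tokenizer is loaded but unused below; the export only rewrites weights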
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+
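+# load the base weights in fp16 on the CPU so the merge needs no GPU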
+base_model = LlamaForCausalLM.from_pretrained(
+    "decapoda-research/llama-7b-hf",
+    load_in_8bit=False,
+    torch_dtype=torch.float16,
+    device_map={"": "cpu"},
+)
+
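+# snapshot one attention weight so we can verify later that the merge
+# really modified the base parameters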
+first_weight = base_model.model.layers[0].self_attn.q_proj.weight
+first_weight_old = first_weight.clone()
+
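+# apply the published Alpaca LoRA adapter on top of the base model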
+lora_model = PeftModel.from_pretrained(
+    base_model,
+    "tloen/alpaca-lora-7b",
+    device_map={"": "cpu"},
+    torch_dtype=torch.float16,
+)
+
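+# handle to the wrapped q_proj weight; it shares storage with first_weight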
+lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight
+
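+# loading the adapter alone must not have touched the base weights yet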
+assert torch.allclose(first_weight_old, first_weight)
+
+# merge weights: in the peft release this script targets, flagging
+# merge_weights on each LoRA-wrapped projection and then entering eval
+# mode folds the low-rank deltas into the base weights in place
+for layer in lora_model.base_model.model.model.layers:
+    layer.self_attn.q_proj.merge_weights = True
+    layer.self_attn.v_proj.merge_weights = True
+
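+# switching to eval mode is what actually performs the in-place merge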
+lora_model.train(False)
+
+# did we do anything?
+assert not torch.allclose(first_weight_old, first_weight)
+
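+# build a state dict that looks like a vanilla LlamaForCausalLM checkpoint:
+# strip the peft wrapper prefix from every key and drop the (now merged)
+# lora_A/lora_B matrices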
+lora_model_sd = lora_model.state_dict()
+deloreanized_sd = {
+    k.replace("base_model.model.model", "model"): v
+    for k, v in lora_model_sd.items()
+    if "lora" not in k
+}
+
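+# save a standard sharded HF checkpoint (no shard larger than 400MB)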
+base_model.save_pretrained(
+    "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
+)
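+# (optional sanity check: the exported folder should now load as a plain
+# HF checkpoint, e.g. LlamaForCausalLM.from_pretrained("./hf_ckpt"))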