Răsfoiți Sursa

HF export script

Eric Wang 3 ani în urmă
părinte
comite
3b160d745b
2 a modificat fișierele cu 61 adăugiri și 4 ștergeri
  1. 5 4
      README.md
  2. 56 0
      export_hf_checkpoint.py

+ 5 - 4
README.md

@@ -35,10 +35,12 @@ as well as some code related to prompt construction and tokenization.
 Near the top of this file is a set of hardcoded hyperparameters that you should feel free to modify.
 PRs adapting this code to multi-GPU setups and larger models are always welcome.
 
-### Checkpoint export (`export_state_dict_checkpoint.py`)
+### Checkpoint export (`export_*_checkpoint.py`)
 
-This file contains a script to convert the LoRA back into a standard PyTorch model checkpoint,
-which should help users who want to use the model with projects like [llama.cpp](https://github.com/ggerganov/llama.cpp).
+These files contain scripts that merge the LoRA weights back into the base model
+for export to Hugging Face format and to PyTorch `state_dicts`,
+which should help users who want to export LlamaModel-shaped weights or
+use the model with projects like [llama.cpp](https://github.com/ggerganov/llama.cpp).
 
 ### Dataset
 
@@ -56,7 +58,6 @@ as well as [clusters of bad examples](https://atlas.nomic.ai/map/d2139cc3-bc1c-4
 - We can likely improve our model performance significantly if we combed through the data and fixed bad examples; in fact, dataset quality might be our bottleneck.
 - We're continually fixing bugs and conducting training runs, and the weights on the Hugging Face Hub are being updated accordingly. In particular, those facing issues with response lengths should make sure that they have the latest version of the weights and code.
 
-
 ### Example outputs
 
 **Instruction**: Tell me about alpacas.

+ 56 - 0
export_hf_checkpoint.py

@@ -0,0 +1,56 @@
+import os
+import json
+
+import torch
+from peft import PeftModel, LoraConfig
+
+import transformers
+
+assert (
+    "LlamaTokenizer" in transformers._import_structure["models.llama"]
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
+from transformers import LlamaTokenizer, LlamaForCausalLM
+
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+
+base_model = LlamaForCausalLM.from_pretrained(
+    "decapoda-research/llama-7b-hf",
+    load_in_8bit=False,
+    torch_dtype=torch.float16,
+    device_map={"": "cpu"},
+)
+
+first_weight = base_model.model.layers[0].self_attn.q_proj.weight
+first_weight_old = first_weight.clone()
+
+lora_model = PeftModel.from_pretrained(
+    base_model,
+    "tloen/alpaca-lora-7b",
+    device_map={"": "cpu"},
+    torch_dtype=torch.float16,
+)
+
+lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight
+
+assert torch.allclose(first_weight_old, first_weight)
+
+# merge weights
+for layer in lora_model.base_model.model.model.layers:
+    layer.self_attn.q_proj.merge_weights = True
+    layer.self_attn.v_proj.merge_weights = True
+
+lora_model.train(False)
+
+# did we do anything?
+assert not torch.allclose(first_weight_old, first_weight)
+
+lora_model_sd = lora_model.state_dict()
+deloreanized_sd = {
+    k.replace("base_model.model.model", "model"): v
+    for k, v in lora_model_sd.items()
+    if "lora" not in k
+}
+
+LlamaForCausalLM.save_pretrained(
+    base_model, "./hf_ckpt", state_dict=deloreanized_sd, max_shard_size="400MB"
+)