@@ -8,7 +8,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402

 BASE_MODEL = os.environ.get("BASE_MODEL", None)
 assert (
     BASE_MODEL
-), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`" # noqa: E501
+), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=huggyllama/llama-7b`" # noqa: E501

 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
@@ -35,10 +35,8 @@ lora_weight = lora_model.base_model.model.model.layers[

 assert torch.allclose(first_weight_old, first_weight)

-# merge weights
-for layer in lora_model.base_model.model.model.layers:
-    layer.self_attn.q_proj.merge_weights = True
-    layer.self_attn.v_proj.merge_weights = True
+# merge weights - new merging method from peft
+lora_model = lora_model.merge_and_unload()

 lora_model.train(False)

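For context, here is a minimal, self-contained sketch of the peft merging flow this change switches to. It is an illustration, not part of the patch: the fp16 dtype and the "tloen/alpaca-lora-7b" adapter id are assumptions for the example. The diff confirms that merge_and_unload() is the peft API that folds the trained LoRA deltas into the base weights, replacing the manual per-layer merge_weights loop removed above.

import os

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

# Same env-var pattern as the script: require BASE_MODEL to be set.
BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert BASE_MODEL, "Please set BASE_MODEL, e.g. `export BASE_MODEL=huggyllama/llama-7b`"

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
base_model = LlamaForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16)

# Attach the trained LoRA adapters (adapter repo id is an assumed placeholder).
lora_model = PeftModel.from_pretrained(base_model, "tloen/alpaca-lora-7b")

# merge_and_unload() returns the underlying transformers model with the LoRA
# updates added into the adapted projection weights; no per-layer
# merge_weights flags are needed anymore.
lora_model = lora_model.merge_and_unload()
lora_model.train(False)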