
Merge pull request #19 from antimatter15/patch-1

Fix LoRa weight merging in export

It can't hurt
Eric J. Wang, 3 years ago
parent
commit
6681523bbe

1 changed file with 6 additions and 1 deletion:
    export_state_dict_checkpoint.py

export_state_dict_checkpoint.py (+6 −1)

@@ -21,7 +21,12 @@ lora_model = PeftModel.from_pretrained(
     torch_dtype=torch.float16,
 )
 
-lora_model.eval()  # merge weights
+# merge weights
+for layer in lora_model.base_model.model.model.layers:
+    layer.self_attn.q_proj.merge_weights = True
+    layer.self_attn.v_proj.merge_weights = True
+    
+lora_model.train(False)
 
 lora_model_sd = lora_model.state_dict()
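For context on why the patch sets `merge_weights = True` before calling `train(False)`: in the loralib-style layers that early PEFT versions used, switching to eval mode folds the low-rank update `B @ A * scaling` into the frozen base weight only when that flag is set, so calling `eval()` alone left the state dict unmerged. A minimal stdlib-only sketch of that mechanism (class and variable names are illustrative, not the real PEFT API):

```python
# Illustrative sketch of loralib-style weight merging (NOT the real PEFT code).
# Effective weight is W + (B @ A) * scaling; entering eval mode folds the
# low-rank update into W only when merge_weights is True.

def matmul(X, Y):
    """Multiply two matrices given as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)]
            for row in X]

class LoRALinear:
    def __init__(self, W, A, B, scaling=1.0):
        self.W = W                 # frozen base weight
        self.A, self.B = A, B      # low-rank adapter factors
        self.scaling = scaling
        self.merge_weights = False
        self.merged = False

    def train(self, mode=True):
        if not mode and self.merge_weights and not self.merged:
            # Eval mode with the flag set: fold B @ A into the base weight.
            delta = matmul(self.B, self.A)
            self.W = [[w + self.scaling * d for w, d in zip(wr, dr)]
                      for wr, dr in zip(self.W, delta)]
            self.merged = True
        return self

layer = LoRALinear(W=[[1.0, 0.0], [0.0, 1.0]],
                   A=[[1.0, 2.0]],           # rank-1 update: B @ A
                   B=[[0.5], [0.25]])

layer.train(False)                 # flag unset: nothing is merged
assert layer.W == [[1.0, 0.0], [0.0, 1.0]]

layer.merge_weights = True         # what the patch sets on q_proj / v_proj
layer.train(False)                 # now eval mode folds B @ A into W
assert layer.W == [[1.5, 1.0], [0.25, 1.5]]
```

In current PEFT versions this manual flag-flipping is no longer needed; `merge_and_unload()` performs the merge directly.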