Fix LoRA weight merging

Kevin Kwok, 3 years ago
Parent
Current commit
dde89950f3
1 changed file, with 6 insertions and 1 deletion
  1. +6 −1 export_state_dict_checkpoint.py

+ 6 - 1
export_state_dict_checkpoint.py

@@ -21,7 +21,12 @@ lora_model = PeftModel.from_pretrained(
     torch_dtype=torch.float16,
 )
 
-lora_model.eval()  # merge weights
+# merge weights
+for layer in lora_model.base_model.model.model.layers:
+    layer.self_attn.q_proj.merge_weights = True
+    layer.self_attn.v_proj.merge_weights = True
+    
+lora_model.train(False)
 
 lora_model_sd = lora_model.state_dict()
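The diff works because PEFT's LoRA layers fold the low-rank update into the frozen base weight when `merge_weights` is set and the module leaves training mode: the merged matrix becomes W + (B @ A) scaled by alpha / r. A minimal NumPy sketch of that arithmetic (the helper `merge_lora_weights` and the toy shapes are illustrative, not part of the PEFT API):

```python
import numpy as np

def merge_lora_weights(W, A, B, alpha, r):
    """Fold a LoRA update into the base weight: W' = W + (B @ A) * (alpha / r)."""
    return W + (B @ A) * (alpha / r)

# Toy example: a 4x4 base weight with rank-2 LoRA factors.
rng = np.random.default_rng(0)
W = rng.standard_normal((4, 4))
A = rng.standard_normal((2, 4))   # down-projection (r x in_features)
B = rng.standard_normal((4, 2))   # up-projection (out_features x r)
W_merged = merge_lora_weights(W, A, B, alpha=16, r=2)

# The merged matrix reproduces base path + scaled low-rank path on any input.
x = rng.standard_normal(4)
assert np.allclose(W_merged @ x, W @ x + (16 / 2) * (B @ (A @ x)))
```

After this merge, the exported `state_dict` contains plain dense weights with the adapter already baked in, which is why the script can hand them to a vanilla checkpoint loader with no LoRA machinery.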