
fix fp16 inference

Eric Wang, 3 years ago
Commit e04897baae
1 changed file with 11 additions and 3 deletions

+ 11 - 3
generate.py

@@ -1,3 +1,4 @@
+import sys
 import torch
 from peft import PeftModel
 import transformers
@@ -10,6 +11,7 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 
+LOAD_8BIT = False
 BASE_MODEL = "decapoda-research/llama-7b-hf"
 LORA_WEIGHTS = "tloen/alpaca-lora-7b"
 
@@ -27,11 +29,15 @@ except:
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
-        load_in_8bit=True,
+        load_in_8bit=LOAD_8BIT,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
+    model = PeftModel.from_pretrained(
+        model,
+        LORA_WEIGHTS,
+        torch_dtype=torch.float16,
+    )
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
@@ -74,9 +80,11 @@ def generate_prompt(instruction, input=None):
 
 ### Response:"""
 
+if not LOAD_8BIT:
+    model.half()  # seems to fix bugs for some users.
 
 model.eval()
-if torch.__version__ >= "2":
+if torch.__version__ >= "2" and sys.platform != "win32":
     model = torch.compile(model)
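
For reference, below is a minimal sketch of how the CUDA loading path in generate.py reads once this change is applied (same model and LoRA checkpoints as in the script; the MPS/CPU branches and the prompt/generation code are omitted):

import sys

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

LOAD_8BIT = False
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# 8-bit quantization is now opt-in via LOAD_8BIT; by default the base model
# is loaded in fp16.
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=LOAD_8BIT,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model,
    LORA_WEIGHTS,
    torch_dtype=torch.float16,
)

# The core of the fix: when not running in 8-bit, cast the model to fp16 so
# no submodule is left in fp32 during inference.
if not LOAD_8BIT:
    model.half()

model.eval()

# torch.compile is only applied on PyTorch 2.x and is skipped on Windows,
# where it is not supported.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)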