فهرست منبع

generate.py memory, perf updates

Eric Wang 3 سال پیش
والد
کامیت
d68ff15ceb
1فایلهای تغییر یافته به همراه14 افزوده شده و 9 حذف شده
  1. 14 9
      generate.py

+ 14 - 9
generate.py

@@ -7,9 +7,12 @@ tokenizer = LLaMATokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 model = LLaMAForCausalLM.from_pretrained(
     "decapoda-research/llama-7b-hf",
     load_in_8bit=True,
+    torch_dtype=torch.float16,
     device_map="auto",
 )
-model = PeftModel.from_pretrained(model, "tloen/alpaca-lora-7b")
+model = PeftModel.from_pretrained(
+    model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
+)
 
 
 def generate_prompt(instruction, input=None):
@@ -32,6 +35,9 @@ def generate_prompt(instruction, input=None):
 ### Response:"""
 
 
+model.eval()
+
+
 def evaluate(instruction, input=None, **kwargs):
     prompt = generate_prompt(instruction, input)
     inputs = tokenizer(prompt, return_tensors="pt")
@@ -42,14 +48,13 @@ def evaluate(instruction, input=None, **kwargs):
         num_beams=4,
         **kwargs,
     )
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=256,
-        )
+    generation_output = model.generate(
+        input_ids=input_ids,
+        generation_config=generation_config,
+        return_dict_in_generate=True,
+        output_scores=True,
+        max_new_tokens=256,
+    )
     s = generation_output.sequences[0]
     output = tokenizer.decode(s)
     return output.split("### Response:")[1].strip()