Eric Wang преди 3 години
родител
ревизия
07f5b68e0f
променени са 1 файла, в които са добавени 9 реда и са изтрити 7 реда
    generate.py — 9 добавени реда, 7 изтрити реда

+ 9 - 7
generate.py

@@ -1,3 +1,4 @@
+import torch
 from peft import PeftModel
 from transformers import LLaMATokenizer, LLaMAForCausalLM, GenerationConfig
 
@@ -41,13 +42,14 @@ def evaluate(instruction, input=None, **kwargs):
         num_beams=4,
         **kwargs,
     )
-    generation_output = model.generate(
-        input_ids=input_ids,
-        generation_config=generation_config,
-        return_dict_in_generate=True,
-        output_scores=True,
-        max_new_tokens=256,
-    )
+    with torch.no_grad():
+        generation_output = model.generate(
+            input_ids=input_ids,
+            generation_config=generation_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=256,
+        )
     s = generation_output.sequences[0]
     output = tokenizer.decode(s)
     return output.split("### Response:")[1].strip()