Parcourir la source

torch.no_grad

Eric Wang il y a 3 ans
Parent
commit
07f5b68e0f
1 fichier modifié avec 9 ajouts et 7 suppressions
  1. 9 7
      generate.py

+ 9 - 7
generate.py

@@ -1,3 +1,4 @@
+import torch
 from peft import PeftModel
 from transformers import LLaMATokenizer, LLaMAForCausalLM, GenerationConfig
 
@@ -41,13 +42,14 @@ def evaluate(instruction, input=None, **kwargs):
         num_beams=4,
         **kwargs,
     )
-    generation_output = model.generate(
-        input_ids=input_ids,
-        generation_config=generation_config,
-        return_dict_in_generate=True,
-        output_scores=True,
-        max_new_tokens=256,
-    )
+    with torch.no_grad():
+        generation_output = model.generate(
+            input_ids=input_ids,
+            generation_config=generation_config,
+            return_dict_in_generate=True,
+            output_scores=True,
+            max_new_tokens=256,
+        )
     s = generation_output.sequences[0]
     output = tokenizer.decode(s)
     return output.split("### Response:")[1].strip()