Explorar el Código

generate.py tweaks

Eric Wang hace 3 años
padre
commit
c83e30ab78
Se han modificado 1 ficheros con 11 adiciones y 8 borrados
  1. 11 8
      generate.py

+ 11 - 8
generate.py

@@ -10,6 +10,9 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 
+BASE_MODEL = "decapoda-research/llama-7b-hf"
+LORA_WEIGHTS = "tloen/alpaca-lora-7b"
+
 if torch.cuda.is_available():
     device = "cuda"
 else:
@@ -23,33 +26,31 @@ except:
 
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
-        "decapoda-research/llama-7b-hf",
+        "chavinlo/alpaca-native",
         load_in_8bit=True,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(
-        model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
-    )
+    # model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
-        "decapoda-research/llama-7b-hf",
+        BASE_MODEL,
         device_map={"": device},
         torch_dtype=torch.float16,
     )
     model = PeftModel.from_pretrained(
         model,
-        "tloen/alpaca-lora-7b",
+        LORA_WEIGHTS,
         device_map={"": device},
         torch_dtype=torch.float16,
     )
 else:
     model = LlamaForCausalLM.from_pretrained(
-        "decapoda-research/llama-7b-hf", device_map={"": device}, low_cpu_mem_usage=True
+        BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
     )
     model = PeftModel.from_pretrained(
         model,
-        "tloen/alpaca-lora-7b",
+        LORA_WEIGHTS,
         device_map={"": device},
     )
 
@@ -75,6 +76,8 @@ def generate_prompt(instruction, input=None):
 
 
 model.eval()
+if torch.__version__ >= "2":
+    model = torch.compile(model)
 
 
 def evaluate(