Parcourir la source

Enable inference on CPU and Mac GPU using pytorch support for MPS (#48)

Peter Marelas il y a 3 ans
Parent
commit
db4af6a7ff
1 fichier modifié avec 46 ajouts et 11 suppressions
  1. generate.py +46 -11

+ 46 - 11
generate.py

@@ -10,16 +10,51 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 
-model = LlamaForCausalLM.from_pretrained(
-    "decapoda-research/llama-7b-hf",
-    load_in_8bit=True,
-    torch_dtype=torch.float16,
-    device_map="auto",
-)
-model = PeftModel.from_pretrained(
-    model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
-)
-
# Pick the best available accelerator: CUDA > MPS (Apple Silicon) > CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    # torch.backends.mps only exists on newer PyTorch builds; probing it on
    # an older version raises AttributeError, which we treat as "no MPS".
    # Catch only AttributeError — a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit and hide unrelated bugs.
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    pass
+
# Load the base LLaMA-7B weights and the Alpaca-LoRA adapter on top of them,
# with per-device settings: 8-bit quantization + fp16 on CUDA, fp16 on MPS,
# and full precision (with low_cpu_mem_usage) on plain CPU.
BASE_MODEL = "decapoda-research/llama-7b-hf"
LORA_WEIGHTS = "tloen/alpaca-lora-7b"

if device == "cuda":
    # bitsandbytes 8-bit loading is CUDA-only, hence this dedicated branch.
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(
        model, LORA_WEIGHTS, torch_dtype=torch.float16
    )
elif device == "mps":
    # Pin every module to the MPS device; "auto" dispatch is CUDA-oriented.
    placement = {"": device}
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map=placement,
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        device_map=placement,
        torch_dtype=torch.float16,
    )
else:
    # CPU fallback: default dtype, streamed weight loading to limit peak RAM.
    placement = {"": device}
    model = LlamaForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map=placement,
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(
        model,
        LORA_WEIGHTS,
        device_map=placement,
    )
 
 def generate_prompt(instruction, input=None):
     if input:
@@ -55,7 +90,7 @@ def evaluate(
 ):
     prompt = generate_prompt(instruction, input)
     inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids = inputs["input_ids"].cuda()
+    input_ids = inputs["input_ids"].to(device)
     generation_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,