@@ -1,3 +1,4 @@
+import sys
 import torch
 from peft import PeftModel
 import transformers
@@ -10,6 +11,7 @@ from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 
+LOAD_8BIT = False
 BASE_MODEL = "decapoda-research/llama-7b-hf"
 LORA_WEIGHTS = "tloen/alpaca-lora-7b"
 
@@ -27,11 +29,15 @@ except:
 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
-        load_in_8bit=True,
+        load_in_8bit=LOAD_8BIT,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
+    model = PeftModel.from_pretrained(
+        model,
+        LORA_WEIGHTS,
+        torch_dtype=torch.float16,
+    )
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,
@@ -74,9 +80,11 @@ def generate_prompt(instruction, input=None):
 
 ### Response:"""
 
 
+if not LOAD_8BIT:
+    model.half()  # seems to fix bugs for some users.
 
 model.eval()
-if torch.__version__ >= "2":
+if torch.__version__ >= "2" and sys.platform != "win32":
     model = torch.compile(model)
 
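Note on the change above: LOAD_8BIT toggles int8 loading of the base weights
(backed by bitsandbytes in transformers); when it is off, the new model.half()
call forces fp16, and torch.compile (introduced in PyTorch 2.0) is skipped on
Windows, where its default backend was not supported at the time. A minimal
sketch of the resulting inference-preparation flow; prepare_for_inference is a
hypothetical helper name, not part of the patch:

import sys

import torch


def prepare_for_inference(model, load_8bit=False):
    # int8-quantized weights must stay as-is; otherwise cast to fp16,
    # which the patch notes "seems to fix bugs for some users".
    if not load_8bit:
        model.half()
    model.eval()
    # torch.compile arrived in PyTorch 2.0; guard on platform as the
    # patch does, since Windows lacked default-backend support.
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
    return model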