@@ -26,12 +26,12 @@ except:

 if device == "cuda":
     model = LlamaForCausalLM.from_pretrained(
-        "chavinlo/alpaca-native",
+        BASE_MODEL,
         load_in_8bit=True,
         torch_dtype=torch.float16,
         device_map="auto",
     )
-    # model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
+    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, torch_dtype=torch.float16)
 elif device == "mps":
     model = LlamaForCausalLM.from_pretrained(
         BASE_MODEL,