@@ -7,9 +7,12 @@ tokenizer = LLaMATokenizer.from_pretrained("decapoda-research/llama-7b-hf")
 model = LLaMAForCausalLM.from_pretrained(
     "decapoda-research/llama-7b-hf",
     load_in_8bit=True,
+    torch_dtype=torch.float16,
     device_map="auto",
 )
-model = PeftModel.from_pretrained(model, "tloen/alpaca-lora-7b")
+model = PeftModel.from_pretrained(
+    model, "tloen/alpaca-lora-7b", torch_dtype=torch.float16
+)
 
 
 def generate_prompt(instruction, input=None):
@@ -32,6 +35,9 @@ def generate_prompt(instruction, input=None):
 ### Response:"""
 
 
+model.eval()
+
+
 def evaluate(instruction, input=None, **kwargs):
     prompt = generate_prompt(instruction, input)
     inputs = tokenizer(prompt, return_tensors="pt")
@@ -42,14 +48,13 @@ def evaluate(instruction, input=None, **kwargs):
         num_beams=4,
         **kwargs,
     )
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=256,
-        )
+    generation_output = model.generate(
+        input_ids=input_ids,
+        generation_config=generation_config,
+        return_dict_in_generate=True,
+        output_scores=True,
+        max_new_tokens=256,
+    )
     s = generation_output.sequences[0]
     output = tokenizer.decode(s)
     return output.split("### Response:")[1].strip()
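
Net effect of the patch: the base weights stay in 8-bit, the LoRA adapter is loaded in float16, dropout is disabled via `model.eval()`, and the explicit `torch.no_grad()` wrapper is dropped (in the `transformers` versions this script targets, `generate()` is itself decorated with `@torch.no_grad()`, so the wrapper was redundant). A minimal smoke test of the patched `evaluate()` might look like the sketch below; the instruction strings are illustrative and not part of the patch:

```python
# Minimal sketch exercising the patched evaluate(); assumes the rest of
# generate.py (imports, tokenizer, model, GenerationConfig) is set up as above.
if __name__ == "__main__":
    for instruction in [
        "Tell me about alpacas.",  # hypothetical example prompts
        "Count up from 1 to 10.",
    ]:
        print("Instruction:", instruction)
        print("Response:", evaluate(instruction))
        print()
```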