|
|
@@ -13,14 +13,16 @@ import torch.nn as nn
|
|
|
import bitsandbytes as bnb
|
|
|
"""
|
|
|
|
|
|
-from peft import ( # noqa: E402
|
|
|
+from peft import (
|
|
|
LoraConfig,
|
|
|
get_peft_model,
|
|
|
get_peft_model_state_dict,
|
|
|
prepare_model_for_int8_training,
|
|
|
set_peft_model_state_dict,
|
|
|
)
|
|
|
-from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
|
|
|
+from transformers import LlamaForCausalLM, LlamaTokenizer
|
|
|
+
|
|
|
+from utils.prompter import Prompter
|
|
|
|
|
|
|
|
|
def train(
|
|
|
@@ -52,6 +54,7 @@ def train(
|
|
|
wandb_watch: str = "", # options: false | gradients | all
|
|
|
wandb_log_model: str = "", # options: false | true
|
|
|
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
|
|
+ prompt_template_name: str = "alpaca", # The prompt template to use; defaults to "alpaca".
|
|
|
):
|
|
|
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
|
|
print(
|
|
|
@@ -75,13 +78,16 @@ def train(
|
|
|
f"wandb_run_name: {wandb_run_name}\n"
|
|
|
f"wandb_watch: {wandb_watch}\n"
|
|
|
f"wandb_log_model: {wandb_log_model}\n"
|
|
|
- f"resume_from_checkpoint: {resume_from_checkpoint}\n"
|
|
|
+ f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
|
|
|
+ f"prompt template: {prompt_template_name}\n"
|
|
|
)
|
|
|
assert (
|
|
|
base_model
|
|
|
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
|
|
gradient_accumulation_steps = batch_size // micro_batch_size
|
|
|
|
|
|
+ prompter = Prompter(prompt_template_name)
|
|
|
+
|
|
|
device_map = "auto"
|
|
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
|
|
ddp = world_size != 1
|
|
|
@@ -138,10 +144,16 @@ def train(
|
|
|
return result
|
|
|
|
|
|
def generate_and_tokenize_prompt(data_point):
|
|
|
- full_prompt = generate_prompt(data_point)
|
|
|
+ full_prompt = prompter.generate_prompt(
|
|
|
+ data_point["instruction"],
|
|
|
+ data_point["input"],
|
|
|
+ data_point["output"],
|
|
|
+ )
|
|
|
tokenized_full_prompt = tokenize(full_prompt)
|
|
|
if not train_on_inputs:
|
|
|
- user_prompt = generate_prompt({**data_point, "output": ""})
|
|
|
+ user_prompt = prompter.generate_prompt(
|
|
|
+ data_point["instruction"], data_point["input"]
|
|
|
+ )
|
|
|
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
|
|
|
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
|
|
|
|
|
@@ -260,28 +272,5 @@ def train(
|
|
|
)
|
|
|
|
|
|
|
|
|
-def generate_prompt(data_point):
|
|
|
- # sorry about the formatting disaster gotta move fast
|
|
|
- if data_point["input"]:
|
|
|
- return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
|
|
|
-
|
|
|
-### Instruction:
|
|
|
-{data_point["instruction"]}
|
|
|
-
|
|
|
-### Input:
|
|
|
-{data_point["input"]}
|
|
|
-
|
|
|
-### Response:
|
|
|
-{data_point["output"]}"""
|
|
|
- else:
|
|
|
- return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
|
|
|
-
|
|
|
-### Instruction:
|
|
|
-{data_point["instruction"]}
|
|
|
-
|
|
|
-### Response:
|
|
|
-{data_point["output"]}"""
|
|
|
-
|
|
|
-
|
|
|
if __name__ == "__main__":
|
|
|
fire.Fire(train)
|