@@ -51,6 +51,11 @@ def train(
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
+    # wandb params
+    wandb_project: str = "",
+    wandb_run_name: str = "",
+    wandb_watch: str = "",  # options: false | gradients | all
+    wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
@@ -70,6 +75,10 @@ def train(
         f"lora_target_modules: {lora_target_modules}\n"
         f"train_on_inputs: {train_on_inputs}\n"
         f"group_by_length: {group_by_length}\n"
+        f"wandb_project: {wandb_project}\n"
+        f"wandb_run_name: {wandb_run_name}\n"
+        f"wandb_watch: {wandb_watch}\n"
+        f"wandb_log_model: {wandb_log_model}\n"
         f"resume_from_checkpoint: {resume_from_checkpoint}\n"
     )
     assert (
@@ -84,6 +93,18 @@ def train(
         device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
         gradient_accumulation_steps = gradient_accumulation_steps // world_size
 
+    # Check if parameter passed or if set within environ
+    use_wandb = len(wandb_project) > 0 or \
+        ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
+    # Only overwrite environ if wandb param passed
+    if len(wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = wandb_project
+    if len(wandb_watch) > 0:
+        os.environ["WANDB_WATCH"] = wandb_watch
+    if len(wandb_log_model) > 0:
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
+
+
     model = LlamaForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
@@ -209,6 +230,8 @@ def train(
             load_best_model_at_end=True if val_set_size > 0 else False,
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
+            report_to="wandb" if use_wandb else None,
+            run_name=wandb_run_name if use_wandb else None,
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
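
Usage note (not part of the patch): a minimal sketch of how the new wandb
parameters might be exercised. It assumes the patched module is named
finetune and that train() is importable; the model id, project, and run
names below are placeholders, not values from this patch.

    from finetune import train

    # Explicit arguments take precedence: train() copies non-empty
    # wandb_project / wandb_watch / wandb_log_model values into os.environ,
    # and passes wandb_run_name to TrainingArguments as run_name.
    train(
        base_model="path/or/hub-id-of-base-model",  # placeholder
        wandb_project="alpaca-lora",                # placeholder project name
        wandb_run_name="lora-run-1",                # placeholder run name
        wandb_watch="gradients",                    # options: false | gradients | all
        wandb_log_model="true",                     # options: false | true
    )

Leaving all four parameters at their empty-string defaults and exporting
WANDB_PROJECT in the environment also enables logging, since use_wandb
falls back to checking os.environ.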