@@ -54,8 +54,8 @@ def train(
     # wandb params
     wandb_project: str = "",
     wandb_run_name: str = "",
-    wandb_watch: str = "", # options: false | gradients | all
-    wandb_log_model: str = "", # options: false | true
+    wandb_watch: str = "",  # options: false | gradients | all
+    wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
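For reference, a minimal sketch of how the new keyword parameters above would be passed when calling train() directly; the module name (finetune) and every argument value here are assumptions for illustration, not part of the diff:

    from finetune import train  # assuming the diffed script is finetune.py

    train(
        base_model="huggyllama/llama-7b",  # hypothetical model id
        wandb_project="my-project",        # any non-empty value turns wandb on
        wandb_run_name="lora-run-1",       # hypothetical run name
        wandb_watch="gradients",           # options: false | gradients | all
        wandb_log_model="true",            # options: false | true
    )

Leaving the wandb_* arguments at their empty-string defaults keeps the existing WANDB_* environment variables untouched, so wandb stays off unless WANDB_PROJECT is already set in the environment, as the next hunk shows.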
@@ -94,16 +94,16 @@ def train(
         gradient_accumulation_steps = gradient_accumulation_steps // world_size
 
     # Check if parameter passed or if set within environ
-    use_wandb = len(wandb_project) > 0 or \
-        ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
+    use_wandb = len(wandb_project) > 0 or (
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
     # Only overwrite environ if wandb param passed
-    if len(wandb_project) > 0:
-        os.environ['WANDB_PROJECT'] = wandb_project
+    if len(wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = wandb_project
     if len(wandb_watch) > 0:
-        os.environ['WANDB_WATCH'] = wandb_watch
+        os.environ["WANDB_WATCH"] = wandb_watch
     if len(wandb_log_model) > 0:
-        os.environ['WANDB_LOG_MODEL'] = wandb_log_model
-
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
 
     model = LlamaForCausalLM.from_pretrained(
         base_model,
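The detection above gives an explicit wandb_project argument precedence over the environment. A standalone, runnable sketch of that logic (the helper name wandb_enabled is mine, not the script's):

    import os

    def wandb_enabled(wandb_project: str = "") -> bool:
        # Explicit argument wins; otherwise fall back to a non-empty
        # WANDB_PROJECT environment variable.
        return len(wandb_project) > 0 or (
            "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
        )

    os.environ.pop("WANDB_PROJECT", None)
    assert not wandb_enabled()            # nothing passed, nothing in environ
    os.environ["WANDB_PROJECT"] = "demo"  # hypothetical project name
    assert wandb_enabled()                # picked up from the environment
    assert wandb_enabled("my-project")    # explicit argument also enables it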
@@ -231,7 +231,7 @@ def train(
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
             report_to="wandb" if use_wandb else None,
-            run_name=wandb_run_name if use_wandb else None
+            run_name=wandb_run_name if use_wandb else None,
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
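The run_name change only adds black's trailing comma; functionally, report_to and run_name are set to wandb values only when use_wandb is true. A hedged sketch of that gating in isolation (output_dir and the run name are placeholders, and what a report_to of None falls back to depends on the installed transformers version):

    import transformers

    use_wandb = False  # e.g., no wandb_project passed and WANDB_PROJECT unset

    args = transformers.TrainingArguments(
        output_dir="./lora-out",  # hypothetical output directory
        report_to="wandb" if use_wandb else None,
        run_name="lora-run-1" if use_wandb else None,
    )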