Ver código fonte

Feat: Add wandb (#168)

* Add wandb

* Fix KeyError

* Add WANDB_WATCH and WANDB_LOG_MODEL

* run_name -> wandb_run_name

* ,

* fix TrainingArgs

---------

Co-authored-by: Eric J. Wang <[email protected]>
NanoCode012 3 anos atrás
pai
commit
69b31e0fed
1 arquivo alterado com 23 adições e 0 exclusões
  1. 23 0
      finetune.py

+ 23 - 0
finetune.py

@@ -51,6 +51,11 @@ def train(
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
+    # wandb params
+    wandb_project: str = "",
+    wandb_run_name: str = "",
+    wandb_watch: str = "", # options: false | gradients | all
+    wandb_log_model: str = "", # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
@@ -70,6 +75,10 @@ def train(
         f"lora_target_modules: {lora_target_modules}\n"
         f"train_on_inputs: {train_on_inputs}\n"
         f"group_by_length: {group_by_length}\n"
+        f"wandb_project: {wandb_project}\n"
+        f"wandb_run_name: {wandb_run_name}\n"
+        f"wandb_watch: {wandb_watch}\n"
+        f"wandb_log_model: {wandb_log_model}\n"
         f"resume_from_checkpoint: {resume_from_checkpoint}\n"
     )
     assert (
@@ -84,6 +93,18 @@ def train(
         device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
         gradient_accumulation_steps = gradient_accumulation_steps // world_size
 
+    # Check if parameter passed or if set within environ
+    use_wandb = len(wandb_project) > 0 or \
+                ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
+    # Only overwrite environ if wandb param passed
+    if len(wandb_project) > 0: 
+        os.environ['WANDB_PROJECT'] = wandb_project
+    if len(wandb_watch) > 0:
+        os.environ['WANDB_WATCH'] = wandb_watch
+    if len(wandb_log_model) > 0:
+        os.environ['WANDB_LOG_MODEL'] = wandb_log_model
+
+
     model = LlamaForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
@@ -209,6 +230,8 @@ def train(
             load_best_model_at_end=True if val_set_size > 0 else False,
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
+            report_to="wandb" if use_wandb else None,
+            run_name=wandb_run_name if use_wandb else None
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True