Procházet zdrojové kódy

Print only on Rank 0 (#187)

* Print only on Rank 0

When training on multiple GPU, the settings are printed once per gpu.
This only prints from rank 0

See https://github.com/tloen/alpaca-lora/issues/182#issuecomment-1485550636
for a sample output.

Could apply to a few other prints further down as well.

* Typo

* Added failsafe

So this works whether or not LOCAL_RANK is defined.
Angainor Development před 3 roky
rodič
revize
fcbc45e4c0
1 změnil soubory, kde provedl 24 přidání a 23 odebrání
  1. 24 23
      finetune.py

+ 24 - 23
finetune.py

@@ -53,29 +53,30 @@ def train(
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
-    print(
-        f"Training Alpaca-LoRA model with params:\n"
-        f"base_model: {base_model}\n"
-        f"data_path: {data_path}\n"
-        f"output_dir: {output_dir}\n"
-        f"batch_size: {batch_size}\n"
-        f"micro_batch_size: {micro_batch_size}\n"
-        f"num_epochs: {num_epochs}\n"
-        f"learning_rate: {learning_rate}\n"
-        f"cutoff_len: {cutoff_len}\n"
-        f"val_set_size: {val_set_size}\n"
-        f"lora_r: {lora_r}\n"
-        f"lora_alpha: {lora_alpha}\n"
-        f"lora_dropout: {lora_dropout}\n"
-        f"lora_target_modules: {lora_target_modules}\n"
-        f"train_on_inputs: {train_on_inputs}\n"
-        f"group_by_length: {group_by_length}\n"
-        f"wandb_project: {wandb_project}\n"
-        f"wandb_run_name: {wandb_run_name}\n"
-        f"wandb_watch: {wandb_watch}\n"
-        f"wandb_log_model: {wandb_log_model}\n"
-        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
-    )
+    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+        print(
+            f"Training Alpaca-LoRA model with params:\n"
+            f"base_model: {base_model}\n"
+            f"data_path: {data_path}\n"
+            f"output_dir: {output_dir}\n"
+            f"batch_size: {batch_size}\n"
+            f"micro_batch_size: {micro_batch_size}\n"
+            f"num_epochs: {num_epochs}\n"
+            f"learning_rate: {learning_rate}\n"
+            f"cutoff_len: {cutoff_len}\n"
+            f"val_set_size: {val_set_size}\n"
+            f"lora_r: {lora_r}\n"
+            f"lora_alpha: {lora_alpha}\n"
+            f"lora_dropout: {lora_dropout}\n"
+            f"lora_target_modules: {lora_target_modules}\n"
+            f"train_on_inputs: {train_on_inputs}\n"
+            f"group_by_length: {group_by_length}\n"
+            f"wandb_project: {wandb_project}\n"
+            f"wandb_run_name: {wandb_run_name}\n"
+            f"wandb_watch: {wandb_watch}\n"
+            f"wandb_log_model: {wandb_log_model}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint}\n"
+        )
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"