@@ -53,29 +53,30 @@ def train(
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
-    print(
-        f"Training Alpaca-LoRA model with params:\n"
-        f"base_model: {base_model}\n"
-        f"data_path: {data_path}\n"
-        f"output_dir: {output_dir}\n"
-        f"batch_size: {batch_size}\n"
-        f"micro_batch_size: {micro_batch_size}\n"
-        f"num_epochs: {num_epochs}\n"
-        f"learning_rate: {learning_rate}\n"
-        f"cutoff_len: {cutoff_len}\n"
-        f"val_set_size: {val_set_size}\n"
-        f"lora_r: {lora_r}\n"
-        f"lora_alpha: {lora_alpha}\n"
-        f"lora_dropout: {lora_dropout}\n"
-        f"lora_target_modules: {lora_target_modules}\n"
-        f"train_on_inputs: {train_on_inputs}\n"
-        f"group_by_length: {group_by_length}\n"
-        f"wandb_project: {wandb_project}\n"
-        f"wandb_run_name: {wandb_run_name}\n"
-        f"wandb_watch: {wandb_watch}\n"
-        f"wandb_log_model: {wandb_log_model}\n"
-        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
-    )
+    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
+        print(
+            f"Training Alpaca-LoRA model with params:\n"
+            f"base_model: {base_model}\n"
+            f"data_path: {data_path}\n"
+            f"output_dir: {output_dir}\n"
+            f"batch_size: {batch_size}\n"
+            f"micro_batch_size: {micro_batch_size}\n"
+            f"num_epochs: {num_epochs}\n"
+            f"learning_rate: {learning_rate}\n"
+            f"cutoff_len: {cutoff_len}\n"
+            f"val_set_size: {val_set_size}\n"
+            f"lora_r: {lora_r}\n"
+            f"lora_alpha: {lora_alpha}\n"
+            f"lora_dropout: {lora_dropout}\n"
+            f"lora_target_modules: {lora_target_modules}\n"
+            f"train_on_inputs: {train_on_inputs}\n"
+            f"group_by_length: {group_by_length}\n"
+            f"wandb_project: {wandb_project}\n"
+            f"wandb_run_name: {wandb_run_name}\n"
+            f"wandb_watch: {wandb_watch}\n"
+            f"wandb_log_model: {wandb_log_model}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint}\n"
+        )
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
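For context, the new guard is the standard rank-zero logging pattern for multi-GPU runs: launchers such as torchrun start one process per GPU and set LOCAL_RANK in each worker's environment, so without the check the parameter dump is printed once per GPU. A minimal sketch of the pattern in isolation (the log_once helper is illustrative, not part of this change):

    import os

    def log_once(message: str) -> None:
        # LOCAL_RANK is unset in single-process runs; defaulting to 0
        # keeps the guard a no-op outside distributed training.
        if int(os.environ.get("LOCAL_RANK", 0)) == 0:
            print(message)

    log_once("Training Alpaca-LoRA model with params: ...")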