@@ -18,6 +18,7 @@ from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
+    set_peft_model_state_dict,
 )
@@ -43,7 +44,8 @@ def train(
     ],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
+    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
         f"Training Alpaca-LoRA model with params:\n"
@@ -62,6 +64,7 @@ def train(
         f"lora_target_modules: {lora_target_modules}\n"
         f"train_on_inputs: {train_on_inputs}\n"
         f"group_by_length: {group_by_length}\n"
+        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
     )
     assert (
         base_model
@@ -137,6 +140,24 @@ def train(
 
     data = load_dataset("json", data_files=data_path)
 
+    if resume_from_checkpoint:
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - the LoRA config above has to fit
+            resume_from_checkpoint = False  # So the trainer won't try loading its state
+        # The two files above have different names depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            model = set_peft_model_state_dict(model, adapters_weights)
+
+    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
+
     if val_set_size > 0:
         train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
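The fallback order above matters: pytorch_model.bin is the name the Hugging Face Trainer gives the weights inside a checkpoint-* directory, while adapter_model.bin is the name PEFT's save_pretrained uses for a final adapter; as the comment notes, both hold the same tensors. In the adapter-only case, resume_from_checkpoint is reset to False so that trainer.train() further down does not also look for optimizer and scheduler state that was never saved. For inference rather than resuming, the same final adapter can be loaded through the standard peft API; a minimal sketch, with an illustrative model id and adapter path:

    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained(
        "decapoda-research/llama-7b-hf",  # illustrative base model id
        torch_dtype=torch.float16,
    )
    # "lora-alpaca" is a hypothetical directory containing adapter_model.bin and
    # adapter_config.json, as written by model.save_pretrained(output_dir) below.
    model = PeftModel.from_pretrained(base, "lora-alpaca")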
@@ -183,7 +204,7 @@ def train(
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
-    trainer.train()
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     model.save_pretrained(output_dir)
 
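Forwarding the value to trainer.train() covers both cases: given a checkpoint path, the Hugging Face Trainer restores the optimizer, scheduler, and step counters on top of the adapter weights loaded earlier, while the False set in the adapter-only branch starts a fresh run from those weights. The two layouts the flag distinguishes, sketched with the usual Trainer and PEFT filenames (directory names are illustrative):

    # lora-alpaca/checkpoint-200/    full Trainer checkpoint: training state resumes
    #     pytorch_model.bin          adapter weights under the Trainer's filename
    #     optimizer.pt               optimizer state restored by trainer.train(...)
    #     scheduler.pt               learning-rate schedule state
    #     trainer_state.json         global step, best metric, etc.
    #
    # lora-alpaca/                   final adapter from model.save_pretrained(output_dir)
    #     adapter_model.bin          the same tensors under PEFT's filename
    #     adapter_config.json        must match the LoraConfig used in this script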