|
@@ -155,6 +155,8 @@ def train(
|
|
|
print(f"Restarting from {checkpoint_name}")
|
|
print(f"Restarting from {checkpoint_name}")
|
|
|
adapters_weights = torch.load(checkpoint_name)
|
|
adapters_weights = torch.load(checkpoint_name)
|
|
|
model = set_peft_model_state_dict(model, adapters_weights)
|
|
model = set_peft_model_state_dict(model, adapters_weights)
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(f"Checkpoint {checkpoint_name} not found")
|
|
|
|
|
|
|
|
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
|
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
|
|
|
|
|