|
@@ -202,7 +202,7 @@ def train(
|
|
|
if os.path.exists(checkpoint_name):
|
|
if os.path.exists(checkpoint_name):
|
|
|
print(f"Restarting from {checkpoint_name}")
|
|
print(f"Restarting from {checkpoint_name}")
|
|
|
adapters_weights = torch.load(checkpoint_name)
|
|
adapters_weights = torch.load(checkpoint_name)
|
|
|
- model = set_peft_model_state_dict(model, adapters_weights)
|
|
|
|
|
|
|
+ set_peft_model_state_dict(model, adapters_weights)
|
|
|
else:
|
|
else:
|
|
|
print(f"Checkpoint {checkpoint_name} not found")
|
|
print(f"Checkpoint {checkpoint_name} not found")
|
|
|
|
|
|