Quellcode durchsuchen

resume_from_checkpoint

Co-authored-by: AngainorDev <[email protected]>
Eric Wang vor 3 Jahren
Ursprung
Commit
da6b427a08
1 geänderte Dateien mit 23 neuen und 2 gelöschten Zeilen
  1. 23 2
      finetune.py

+ 23 - 2
finetune.py

@@ -18,6 +18,7 @@ from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
+    set_peft_model_state_dict,
 )
 
 
@@ -43,7 +44,8 @@ def train(
     ],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
-    group_by_length: bool = False,  # faster, but produces an odd training loss curve
+    group_by_length: bool = False,  # faster, but produces an odd training loss curve,
+    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
         f"Training Alpaca-LoRA model with params:\n"
@@ -62,6 +64,7 @@ def train(
         f"lora_target_modules: {lora_target_modules}\n"
         f"train_on_inputs: {train_on_inputs}\n"
         f"group_by_length: {group_by_length}\n"
+        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
     )
     assert (
         base_model
@@ -137,6 +140,24 @@ def train(
 
     data = load_dataset("json", data_files=data_path)
 
+    if resume_from_checkpoint:
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            resume_from_checkpoint, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                resume_from_checkpoint, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+            resume_from_checkpoint = False  # So the trainer won't try loading its state
+        # The two files above have a different name depending on how they were saved, but are actually the same.
+        if os.path.exists(checkpoint_name):
+            print(f"Restarting from {checkpoint_name}")
+            adapters_weights = torch.load(checkpoint_name)
+            model = set_peft_model_state_dict(model, adapters_weights)
+
+    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
+
     if val_set_size > 0:
         train_val = data["train"].train_test_split(
             test_size=val_set_size, shuffle=True, seed=42
@@ -183,7 +204,7 @@ def train(
     if torch.__version__ >= "2" and sys.platform != "win32":
         model = torch.compile(model)
 
-    trainer.train()
+    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     model.save_pretrained(output_dir)