Переглянути джерело

Add support for a validation set size of 0 (#83)

* Add support for a validation set size of 0

* Reset validation-related training arguments to their defaults when the validation set size is 0
Kohaku-Blueleaf 3 роки тому
батько
коміт
b5a1a0bca7
1 змінених файлів з 12 додано та 11 видалено
  1. 12 11
      finetune.py

+ 12 - 11
finetune.py

@@ -67,12 +67,6 @@ model = get_peft_model(model, config)
 tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
 data = load_dataset("json", data_files=DATA_PATH)
 
-train_val = data["train"].train_test_split(
-    test_size=VAL_SET_SIZE, shuffle=True, seed=42
-)
-train_data = train_val["train"]
-val_data = train_val["test"]
-
 
 def generate_prompt(data_point):
     # sorry about the formatting disaster gotta move fast
@@ -164,8 +158,15 @@ def generate_and_tokenize_prompt(data_point):
     }
 
 
-train_data = train_data.shuffle().map(generate_and_tokenize_prompt)
-val_data = val_data.shuffle().map(generate_and_tokenize_prompt)
+if VAL_SET_SIZE > 0:
+    train_val = data["train"].train_test_split(
+        test_size=VAL_SET_SIZE, shuffle=True, seed=42
+    )
+    train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+    val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+else:
+    train_data = data['train'].shuffle().map(generate_and_tokenize_prompt)
+    val_data = None
 
 trainer = transformers.Trainer(
     model=model,
@@ -179,13 +180,13 @@ trainer = transformers.Trainer(
         learning_rate=LEARNING_RATE,
         fp16=True,
         logging_steps=20,
-        evaluation_strategy="steps",
+        evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
         save_strategy="steps",
-        eval_steps=200,
+        eval_steps=200 if VAL_SET_SIZE > 0 else None,
         save_steps=200,
         output_dir=OUTPUT_DIR,
         save_total_limit=3,
-        load_best_model_at_end=True,
+        load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
         ddp_find_unused_parameters=False if ddp else None,
     ),
     data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),