瀏覽代碼

fix finetuning code :(

Eric Wang 3 年之前
父節點
當前提交
a2607faff0
共有 1 個文件被更改,包括 3 次插入和 11 次刪除
  1. 3 11
      finetune.py

+ 3 - 11
finetune.py

@@ -68,7 +68,7 @@ def generate_prompt(data_point):
 {data_point["output"]}"""
 
 
-data = data.map(
+data = data.shuffle().map(
     lambda data_point: tokenizer(
         generate_prompt(data_point),
         truncation=True,
@@ -77,17 +77,9 @@ data = data.map(
     )
 )
 
-
-train_testvalid = data.train_test_split(test_size=2000, shuffle=True, seed=42)
-test_valid = train_testvalid["test"].train_test_split(test_size=1000)
-train_data = train_testvalid["train"]
-valid_data = test_valid["train"]
-test_data = test_valid["test"]
-
 trainer = transformers.Trainer(
     model=model,
-    train_dataset=train_data,
-    eval_dataset=valid_data,
+    train_dataset=data["train"],
     args=transformers.TrainingArguments(
         per_device_train_batch_size=MICRO_BATCH_SIZE,
         gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
@@ -95,7 +87,7 @@ trainer = transformers.Trainer(
         num_train_epochs=EPOCHS,
         learning_rate=LEARNING_RATE,
         fp16=True,
-        logging_steps=10,
+        logging_steps=1,
         output_dir="lora-alpaca",
         save_total_limit=3,
     ),