@@ -35,6 +35,7 @@ TARGET_MODULES = [
     "v_proj",
 ]
 DATA_PATH = "alpaca_data_cleaned.json"
+OUTPUT_DIR = "lora-alpaca"
 
 device_map = "auto"
 world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -182,7 +183,7 @@ trainer = transformers.Trainer(
         save_strategy="steps",
         eval_steps=200,
         save_steps=200,
-        output_dir="lora-alpaca",
+        output_dir=OUTPUT_DIR,
         save_total_limit=3,
         load_best_model_at_end=True,
         ddp_find_unused_parameters=False if ddp else None,
@@ -201,6 +202,6 @@ if torch.__version__ >= "2" and sys.platform != 'win32':
 
 trainer.train()
 
-model.save_pretrained("lora-alpaca")
+model.save_pretrained(OUTPUT_DIR)
 
 print("\n If there's a warning about missing keys above, please disregard :)")