@@ -26,7 +26,7 @@ from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
 def train(
     # model/data params
     base_model: str = "",  # the only required argument
-    data_path: str = "./alpaca_data_cleaned.json",
+    data_path: str = "yahma/alpaca-cleaned",
     output_dir: str = "./lora-alpaca",
     # training hyperparams
     batch_size: int = 128,