@@ -36,6 +36,10 @@ TARGET_MODULES = [
 ]
 DATA_PATH = "alpaca_data_cleaned.json"
 OUTPUT_DIR = "lora-alpaca"
+BASE_MODEL = None
+assert (
+    BASE_MODEL
+), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
 
 device_map = "auto"
 world_size = int(os.environ.get("WORLD_SIZE", 1))
@@ -45,13 +49,11 @@ if ddp:
     GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
 
 model = LlamaForCausalLM.from_pretrained(
-    "decapoda-research/llama-7b-hf",
+    BASE_MODEL,
     load_in_8bit=True,
     device_map=device_map,
 )
-tokenizer = LlamaTokenizer.from_pretrained(
-    "decapoda-research/llama-7b-hf", add_eos_token=True
-)
+tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL, add_eos_token=True)
 
 model = prepare_model_for_int8_training(model)