|
|
@@ -36,10 +36,17 @@ TARGET_MODULES = [
|
|
|
]
|
|
|
DATA_PATH = "alpaca_data_cleaned.json"
|
|
|
|
|
|
+device_map = "auto"
|
|
|
+world_size = int(os.environ.get('WORLD_SIZE', 1))
|
|
|
+ddp = world_size != 1
|
|
|
+if ddp:
|
|
|
+ device_map = {'':int(os.environ.get('LOCAL_RANK') or 0)}
|
|
|
+ GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
|
|
|
+
|
|
|
model = LlamaForCausalLM.from_pretrained(
|
|
|
"decapoda-research/llama-7b-hf",
|
|
|
load_in_8bit=True,
|
|
|
- device_map="auto",
|
|
|
+ device_map=device_map,
|
|
|
)
|
|
|
tokenizer = LlamaTokenizer.from_pretrained(
|
|
|
"decapoda-research/llama-7b-hf", add_eos_token=True
|
|
|
@@ -126,6 +133,7 @@ trainer = transformers.Trainer(
|
|
|
output_dir="lora-alpaca",
|
|
|
save_total_limit=3,
|
|
|
load_best_model_at_end=True,
|
|
|
+ ddp_find_unused_parameters=False if ddp else None,
|
|
|
),
|
|
|
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
|
|
|
)
|