
Enabling model parallelism (training 30b on 2x 3090s and beyond) (#131)

* override broken data parallelism with model parallelism

* formatting

* formatting, again

---------

Co-authored-by: Eric Wang <[email protected]>
кѳѳsнī, 3 years ago
Parent commit: 55b664f46f
1 file changed with 5 additions and 0 deletions

finetune.py (+5 -0)

@@ -204,6 +204,11 @@ def train(
         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
         val_data = None
 
+    if not ddp and torch.cuda.device_count() > 1:
+        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
+        model.is_parallelizable = True
+        model.model_parallel = True
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
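
For context, here is a minimal sketch of where this hunk sits and why it works. The surrounding setup (the `WORLD_SIZE` check, the model class, the `device_map` value, and the placeholder checkpoint name) is an assumption based on the upstream alpaca-lora `finetune.py` and is not part of this diff; the Trainer-side behavior described in the comments is paraphrased and can differ between transformers versions.

```python
import os

import torch
from transformers import AutoModelForCausalLM  # finetune.py uses a LLaMA class; this is a stand-in

# torchrun sets WORLD_SIZE for each process; ddp is true only under a
# distributed launch (assumed context, mirroring upstream finetune.py).
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1

# With device_map="auto", the model's layers are sharded across all visible
# GPUs (naive model parallelism): each layer lives on exactly one device,
# which is what lets a 30B model's memory footprint be pooled across cards.
model = AutoModelForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",  # placeholder checkpoint for illustration
    device_map="auto",
)

if not ddp and torch.cuda.device_count() > 1:
    # Without these flags, Trainer sees more than one GPU and wraps the
    # model in torch.nn.DataParallel, which tries to replicate the whole
    # model onto each device -- broken for a model whose layers are already
    # spread across GPUs. With both flags set, Trainer treats the model as
    # model-parallel and trains it as-is on a single process.
    model.is_parallelizable = True
    model.model_parallel = True
```

Note the asymmetry: under torchrun (`ddp` true) each process owns one GPU and DDP handles replication itself, so the flags stay unset; only the single-process, multi-GPU case needs them to keep Trainer's automatic DataParallel wrapping out of the way.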