@@ -204,6 +204,11 @@ def train(
     train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
     val_data = None
 
+    if not ddp and torch.cuda.device_count() > 1:
+        # keep the Trainer from applying its own DataParallel when more than one GPU is available
+        model.is_parallelizable = True
+        model.model_parallel = True
+
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_data,
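For context, the ddp flag tested in the new branch is defined elsewhere in the file. A minimal sketch of how such a flag is typically derived, assuming the standard WORLD_SIZE environment variable set by torchrun / torch.distributed.launch (the actual detection in this codebase may differ):

    import os

    # torchrun sets WORLD_SIZE for each spawned process; a plain
    # single-process run leaves it unset, so default to 1.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1  # True only under a multi-process distributed launch

The guard matters because transformers.Trainer wraps the model in torch.nn.DataParallel whenever it sees more than one GPU in a single-process run; marking the model as parallelizable and model-parallel tells the Trainer the weights are already placed across devices, so it skips that wrapping.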