|
|
@@ -1,4 +1,5 @@
|
|
|
import os
|
|
|
+import sys
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
@@ -195,7 +196,7 @@ model.state_dict = (
|
|
|
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
|
|
).__get__(model, type(model))
|
|
|
|
|
|
-if torch.__version__ >= "2":
|
|
|
+if torch.__version__ >= "2" and sys.platform != 'win32':
|
|
|
model = torch.compile(model)
|
|
|
|
|
|
trainer.train()
|