Forráskód Böngészése

Fix a warning (#186)

Avoids the 
"Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning." 
warning
Angainor Development 3 éve
szülő
commit
69b9d9ea8b
1 módosított fájl, 1 hozzáadás és 0 törlés
  1. 1 0
      finetune.py

+ 1 - 0
finetune.py

@@ -108,6 +108,7 @@ def train(
     model = LlamaForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
+        torch_dtype=torch.float16,
         device_map=device_map,
     )