
Fix a warning (#186)

Avoids the warning:

"Overriding torch_dtype=None with `torch_dtype=torch.float16` due to requirements of `bitsandbytes` to enable model loading in mixed int8. Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
Angainor Development, 3 years ago
parent commit 69b9d9ea8b
1 file changed, 1 insertion, 0 deletions

finetune.py (+1, -0)

@@ -108,6 +108,7 @@ def train(
     model = LlamaForCausalLM.from_pretrained(
         base_model,
         load_in_8bit=True,
+        torch_dtype=torch.float16,
         device_map=device_map,
     )
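For context, after this patch the load call inside train() reads roughly as below. This is a sketch, not the full finetune.py: the function wrapper, the deferred import, and the device_map default are assumptions added here for illustration; base_model and device_map are parameters of the surrounding train() function in the original file.

```python
import torch

def load_base_model(base_model: str, device_map="auto"):
    # Deferred import so the sketch can be read without a full GPU setup;
    # the original finetune.py imports LlamaForCausalLM at module level.
    from transformers import LlamaForCausalLM

    return LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,          # quantize weights to int8 via bitsandbytes
        torch_dtype=torch.float16,  # stated explicitly, so transformers no
                                    # longer warns about overriding torch_dtype=None
        device_map=device_map,
    )
```

Since bitsandbytes forces float16 for mixed-int8 loading anyway, passing torch_dtype=torch.float16 explicitly does not change behavior; it only tells transformers the caller already agrees with the dtype, which is what silences the warning.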