Răsfoiți Sursa

Catch outdated installs

Eric Wang 3 ani în urmă
părinte
comite
5f6614e6fc
3 a modificat fișierele cu 16 adăugiri și 1 ștergeri
  1. 7 1
      export_state_dict_checkpoint.py
  2. 4 0
      finetune.py
  3. 5 0
      generate.py

+ 7 - 1
export_state_dict_checkpoint.py

@@ -3,6 +3,12 @@ import json
 
 import torch
 from peft import PeftModel, LoraConfig
+
+import transformers
+
+assert (
+    "LlamaTokenizer" in transformers._import_structure["models.llama"]
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
 from transformers import LlamaTokenizer, LlamaForCausalLM
 
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
@@ -25,7 +31,7 @@ lora_model = PeftModel.from_pretrained(
 for layer in lora_model.base_model.model.model.layers:
     layer.self_attn.q_proj.merge_weights = True
     layer.self_attn.v_proj.merge_weights = True
-    
+
 lora_model.train(False)
 
 lora_model_sd = lora_model.state_dict()

+ 4 - 0
finetune.py

@@ -6,6 +6,10 @@ import torch.nn as nn
 import bitsandbytes as bnb
 from datasets import load_dataset
 import transformers
+
+assert (
+    "LlamaTokenizer" in transformers._import_structure["models.llama"]
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
 from transformers import AutoTokenizer, AutoConfig, LlamaForCausalLM, LlamaTokenizer
 from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model
 

+ 5 - 0
generate.py

@@ -1,5 +1,10 @@
 import torch
 from peft import PeftModel
+import transformers
+
+assert (
+    "LlamaTokenizer" in transformers._import_structure["models.llama"]
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
 from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 
 tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")