浏览代码

remove asserts

Eric Wang 3 年之前
父节点
当前提交
804d22ad43
共有 4 个文件被更改,包括 1 次插入和 27 次删除
  1. export_hf_checkpoint.py (+0 −9)
  2. export_state_dict_checkpoint.py (+0 −8)
  3. finetune.py (+1 −6)
  4. generate.py (+0 −4)

+ 0 - 9
export_hf_checkpoint.py

@@ -3,15 +3,6 @@ import os
 import torch
 import transformers
 from peft import PeftModel
-
-# Unused imports
-# import json
-# from peft import LoraConfig
-
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
-
 from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
 
 BASE_MODEL = os.environ.get("BASE_MODEL", None)

+ 0 - 8
export_state_dict_checkpoint.py

@@ -3,15 +3,7 @@ import os
 
 import torch
 import transformers
-
-# Unused imports
-# from peft import LoraConfig
 from peft import PeftModel
-
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
-
 from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 
 BASE_MODEL = os.environ.get("BASE_MODEL", None)

+ 1 - 6
finetune.py

@@ -13,11 +13,6 @@ import torch.nn as nn
 import bitsandbytes as bnb
 """
 
-# Catch when user should re-install transformers library
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
-
 from peft import (  # noqa: E402
     LoraConfig,
     get_peft_model,
@@ -38,7 +33,7 @@ def train(
     micro_batch_size: int = 4,
     num_epochs: int = 3,
     learning_rate: float = 3e-4,
-    cutoff_len: int = 256,
+    cutoff_len: int = 512,
     val_set_size: int = 2000,
     # lora hyperparams
     lora_r: int = 8,

+ 0 - 4
generate.py

@@ -5,10 +5,6 @@ import gradio as gr
 import torch
 import transformers
 from peft import PeftModel
-
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
 if torch.cuda.is_available():