|
@@ -13,11 +13,6 @@ import torch.nn as nn
|
|
|
import bitsandbytes as bnb
|
|
import bitsandbytes as bnb
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
-# Catch when user should re-install transformers library
|
|
|
|
|
-assert (
|
|
|
|
|
- "LlamaTokenizer" in transformers._import_structure["models.llama"]
|
|
|
|
|
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git" # noqa: E501
|
|
|
|
|
-
|
|
|
|
|
from peft import ( # noqa: E402
|
|
from peft import ( # noqa: E402
|
|
|
LoraConfig,
|
|
LoraConfig,
|
|
|
get_peft_model,
|
|
get_peft_model,
|
|
@@ -38,7 +33,7 @@ def train(
|
|
|
micro_batch_size: int = 4,
|
|
micro_batch_size: int = 4,
|
|
|
num_epochs: int = 3,
|
|
num_epochs: int = 3,
|
|
|
learning_rate: float = 3e-4,
|
|
learning_rate: float = 3e-4,
|
|
|
- cutoff_len: int = 256,
|
|
|
|
|
|
|
+ cutoff_len: int = 512,
|
|
|
val_set_size: int = 2000,
|
|
val_set_size: int = 2000,
|
|
|
# lora hyperparams
|
|
# lora hyperparams
|
|
|
lora_r: int = 8,
|
|
lora_r: int = 8,
|