@@ -4,22 +4,28 @@ from typing import List
 
 import fire
 import torch
+import transformers
+from datasets import load_dataset
+
+"""
+Unused imports:
 import torch.nn as nn
 import bitsandbytes as bnb
-from datasets import load_dataset
-import transformers
+"""
 
+# Catch when user should re-install transformers library
 assert (
     "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaForCausalLM, LlamaTokenizer
-from peft import (
-    prepare_model_for_int8_training,
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
+
+from peft import (  # noqa: E402
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
+    prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
+from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 
 
 def train(
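
The assert above fails fast when the installed transformers build predates LLaMA support, probing the private _import_structure table. As a rough sketch (not part of the diff), the same guard could lean on the public attribute instead:

    import transformers

    if not hasattr(transformers, "LlamaTokenizer"):
        raise ImportError(
            "This transformers build lacks LLaMA support; reinstall it: "
            "pip install git+https://github.com/huggingface/transformers.git"
        )
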
@@ -44,7 +50,7 @@ def train(
     ],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
-    group_by_length: bool = False,  # faster, but produces an odd training loss curve,
+    group_by_length: bool = False,  # faster, but produces an odd training loss curve
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
@@ -86,7 +92,9 @@ def train(
 
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
 
-    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
+    tokenizer.pad_token_id = (
+        0  # unk. we want this to be different from the eos token
+    )
     tokenizer.padding_side = "left"  # Allow batched inference
 
     def tokenize(prompt, add_eos_token=True):
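
The hunk above pins pad_token_id to 0 (unk) so padding never collides with the eos token, and pads on the left so every prompt ends flush against the position where generation starts. A minimal sketch of what that buys for batched inference, assuming the tokenizer configured above (prompts hypothetical):

    batch = tokenizer(
        ["short prompt", "a much longer prompt"],
        padding=True,  # the shorter prompt is padded on the left
        return_tensors="pt",
    )
    # attention_mask zeroes out the leading pads, and both sequences end at
    # the same position, which is what model.generate() continues from.
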
@@ -138,7 +146,10 @@ def train(
     )
     model = get_peft_model(model, config)
 
-    data = load_dataset("json", data_files=data_path)
+    if data_path.endswith(".json"):  # todo: support jsonl
+        data = load_dataset("json", data_files=data_path)
+    else:
+        data = load_dataset(data_path)
 
     if resume_from_checkpoint:
         # Check the available weights and load them
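
The dataset branch added earlier in this hunk keeps local JSON files working while also accepting a Hugging Face Hub dataset name. Illustrative calls (the dataset names are examples, not taken from the diff):

    data = load_dataset("json", data_files="alpaca_data.json")  # local file
    data = load_dataset("yahma/alpaca-cleaned")  # dataset repo on the Hub
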
@@ -149,7 +160,9 @@ def train(
             checkpoint_name = os.path.join(
                 resume_from_checkpoint, "adapter_model.bin"
             )  # only LoRA model - LoRA config above has to fit
-            resume_from_checkpoint = False  # So the trainer won't try loading its state
+            resume_from_checkpoint = (
+                False  # So the trainer won't try loading its state
+            )
         # The two files above have a different name depending on how they were saved, but are actually the same.
         if os.path.exists(checkpoint_name):
             print(f"Restarting from {checkpoint_name}")
@@ -164,8 +177,12 @@ def train(
         train_val = data["train"].train_test_split(
             test_size=val_set_size, shuffle=True, seed=42
         )
-        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
-        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        train_data = (
+            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+        )
+        val_data = (
+            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        )
     else:
         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
         val_data = None
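
The split above is deterministic: shuffle=True with seed=42 means every run holds out the same validation rows. A small usage sketch (the val_set_size of 2000 is hypothetical):

    train_val = data["train"].train_test_split(
        test_size=2000, shuffle=True, seed=42
    )
    train_val["train"]  # remaining rows, mapped to tokenized training data
    train_val["test"]   # 2000 held-out rows, mapped to val_data
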
@@ -201,7 +218,9 @@ def train(
 
     old_state_dict = model.state_dict
     model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
+        lambda self, *_, **__: get_peft_model_state_dict(
+            self, old_state_dict()
+        )
     ).__get__(model, type(model))
 
     if torch.__version__ >= "2" and sys.platform != "win32":
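
The reflowed lambda above is a monkey-patch: model.state_dict is rebound so that saving routes through get_peft_model_state_dict, and checkpoints therefore contain only the small LoRA adapter tensors rather than the full base model. The trailing .__get__(model, type(model)) is plain descriptor protocol; a toy demonstration with hypothetical names:

    def only_adapters(self, *_, **__):
        return {"lora_demo": 0}  # stand-in for the filtered state dict

    class Model:
        pass

    m = Model()
    m.state_dict = only_adapters.__get__(m, Model)  # bind as a method of m
    assert m.state_dict() == {"lora_demo": 0}       # m is passed as self
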
@@ -211,13 +230,15 @@ def train(
 
     model.save_pretrained(output_dir)
 
-    print("\n If there's a warning about missing keys above, please disregard :)")
+    print(
+        "\n If there's a warning about missing keys above, please disregard :)"
+    )
 
 
 def generate_prompt(data_point):
     # sorry about the formatting disaster gotta move fast
     if data_point["input"]:
         return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 
 ### Instruction:
 {data_point["instruction"]}
@@ -228,7 +249,7 @@ def generate_prompt(data_point):
 ### Response:
 {data_point["output"]}"""
     else:
         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
 
 ### Instruction:
 {data_point["instruction"]}