Przeglądaj źródła

Templated prompter (#184)

* Templated prompter

* fix dup import

* Set Verbose False by default

I forgot to disable after testing.

* Fix imports order

* Use Black Formatting

* lint

* Re-introduce lost line

* Cleanup

* template default

* isort

---------

Co-authored-by: Eric Wang <[email protected]>
Angainor Development 3 lat temu
rodzic
commit
8d58d37b65

+ 2 - 1
.gitignore

@@ -10,4 +10,5 @@ lora-**
 wandb
 evaluate.py
 test_data.json
-todo.txt
+todo.txt
+.vscode/

+ 17 - 28
finetune.py

@@ -13,14 +13,16 @@ import torch.nn as nn
 import bitsandbytes as bnb
 """
 
-from peft import (  # noqa: E402
+from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
     prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
-from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+from utils.prompter import Prompter
 
 
 def train(
@@ -52,6 +54,7 @@ def train(
     wandb_watch: str = "",  # options: false | gradients | all
     wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
+    prompt_template_name: str = "alpaca",  # The prompt template to use, will default to alpaca.
 ):
     if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         print(
@@ -75,13 +78,16 @@ def train(
             f"wandb_run_name: {wandb_run_name}\n"
             f"wandb_watch: {wandb_watch}\n"
             f"wandb_log_model: {wandb_log_model}\n"
-            f"resume_from_checkpoint: {resume_from_checkpoint}\n"
+            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
+            f"prompt template: {prompt_template_name}\n"
         )
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
     gradient_accumulation_steps = batch_size // micro_batch_size
 
+    prompter = Prompter(prompt_template_name)
+
     device_map = "auto"
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     ddp = world_size != 1
@@ -138,10 +144,16 @@ def train(
         return result
 
     def generate_and_tokenize_prompt(data_point):
-        full_prompt = generate_prompt(data_point)
+        full_prompt = prompter.generate_prompt(
+            data_point["instruction"],
+            data_point["input"],
+            data_point["output"],
+        )
         tokenized_full_prompt = tokenize(full_prompt)
         if not train_on_inputs:
-            user_prompt = generate_prompt({**data_point, "output": ""})
+            user_prompt = prompter.generate_prompt(
+                data_point["instruction"], data_point["input"]
+            )
             tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
             user_prompt_len = len(tokenized_user_prompt["input_ids"])
 
@@ -260,28 +272,5 @@ def train(
     )
 
 
-def generate_prompt(data_point):
-    # sorry about the formatting disaster gotta move fast
-    if data_point["input"]:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{data_point["instruction"]}
-
-### Input:
-{data_point["input"]}
-
-### Response:
-{data_point["output"]}"""
-    else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{data_point["instruction"]}
-
-### Response:
-{data_point["output"]}"""
-
-
 if __name__ == "__main__":
     fire.Fire(train)

+ 8 - 25
generate.py

@@ -7,6 +7,8 @@ import transformers
 from peft import PeftModel
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
+from utils.prompter import Prompter
+
 if torch.cuda.is_available():
     device = "cuda"
 else:
@@ -23,12 +25,15 @@ def main(
     load_8bit: bool = False,
     base_model: str = "",
     lora_weights: str = "tloen/alpaca-lora-7b",
+    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
+    server_name: str = "127.0.0.1",  # Allows to listen on all interfaces by providing '0.0.0.0'
     share_gradio: bool = False,
 ):
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
 
+    prompter = Prompter(prompt_template)
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
         model = LlamaForCausalLM.from_pretrained(
@@ -86,7 +91,7 @@ def main(
         max_new_tokens=128,
         **kwargs,
     ):
-        prompt = generate_prompt(instruction, input)
+        prompt = prompter.generate_prompt(instruction, input)
         inputs = tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(device)
         generation_config = GenerationConfig(
@@ -106,7 +111,7 @@ def main(
             )
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
-        return output.split("### Response:")[1].strip()
+        return prompter.get_response(output)
 
     gr.Interface(
         fn=evaluate,
@@ -141,7 +146,7 @@ def main(
         ],
         title="🦙🌲 Alpaca-LoRA",
         description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",  # noqa: E501
-    ).launch(share=share_gradio)
+    ).launch(server_name=server_name, share=share_gradio)
     # Old testing code follows.
 
     """
@@ -163,27 +168,5 @@ def main(
     """
 
 
-def generate_prompt(instruction, input=None):
-    if input:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{instruction}
-
-### Input:
-{input}
-
-### Response:
-"""
-    else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
-
-### Instruction:
-{instruction}
-
-### Response:
-"""
-
-
 if __name__ == "__main__":
     fire.Fire(main)

+ 46 - 0
templates/README.md

@@ -0,0 +1,46 @@
+# Prompt templates
+
+This directory contains template styles for the prompts used to finetune LoRA models.
+
+## Format
+
+A template is described via a JSON file with the following keys:
+
+- `prompt_input`: The template to use when input is not None. Uses `{instruction}` and `{input}` placeholders.
+- `prompt_no_input`: The template to use when input is None. Uses `{instruction}` placeholders.
+- `description`: A short description of the template, with possible use cases.
+- `response_split`: The text to use as separator when cutting real response from the model output.
+
+No `{response}` placeholder was used, since the response is always the last element of the template and is just to be concatenated to the rest.
+
+## Example template
+
+The default template, used unless otherwise specified, is `alpaca.json`
+
+```json
+{
+    "description": "Template used by Alpaca-LoRA.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"    
+}
+
+```
+
+## Current templates
+
+### alpaca
+
+Default template used for generic LoRA fine tunes so far.
+
+### alpaca_legacy
+
+Legacy template used by the original alpaca repo, with no `\n` after the response field. Kept for reference and experiments.
+
+### alpaca_short
+
+A trimmed down alpaca template which seems to perform just as well and spare some tokens. Models created with the default template seem to be queryable by the short tempalte as well. More experiments are welcome.
+
+### vigogne
+
+The default alpaca template, translated to french. This template was used to train the "Vigogne" LoRA and is to be used to query it, or for extra fine tuning.

+ 6 - 0
templates/alpaca.json

@@ -0,0 +1,6 @@
+{
+    "description": "Template used by Alpaca-LoRA.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"    
+}

+ 6 - 0
templates/alpaca_legacy.json

@@ -0,0 +1,6 @@
+{
+    "description": "Legacy template, used by Original Alpaca repository.",
+    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:",
+    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:",
+    "response_split": "### Response:"    
+}

+ 6 - 0
templates/alpaca_short.json

@@ -0,0 +1,6 @@
+{
+    "description": "A shorter template to experiment with.",
+    "prompt_input": "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
+    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
+    "response_split": "### Response:"    
+}

+ 6 - 0
templates/vigogne.json

@@ -0,0 +1,6 @@
+{
+    "description": "French template, used by Vigogne for finetuning.",
+    "prompt_input": "Ci-dessous se trouve une instruction qui décrit une tâche, associée à une entrée qui fournit un contexte supplémentaire. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Entrée:\n{input}\n\n### Réponse:\n",
+    "prompt_no_input": "Ci-dessous se trouve une instruction qui décrit une tâche. Écrivez une réponse qui complète correctement la demande.\n\n### Instruction:\n{instruction}\n\n### Réponse:\n",
+    "response_split": "### Réponse:"
+}

+ 7 - 0
utils/README.md

@@ -0,0 +1,7 @@
+# Directory for helpers modules
+
+## prompter.py
+
+Prompter class, a template manager.
+
+`from utils.prompter import Prompter`

+ 0 - 0
utils/__init__.py


+ 51 - 0
utils/prompter.py

@@ -0,0 +1,51 @@
+"""
+A dedicated helper to manage templates and prompt building.
+"""
+
+import json
+import os.path as osp
+from typing import Union
+
+
+class Prompter(object):
+    __slots__ = ("template", "_verbose")
+
+    def __init__(self, template_name: str = "", verbose: bool = False):
+        self._verbose = verbose
+        if not template_name:
+            # Enforce the default here, so the constructor can be called with '' and will not break.
+            template_name = "alpaca"
+        file_name = osp.join("templates", f"{template_name}.json")
+        if not osp.exists(file_name):
+            raise ValueError(f"Can't read {file_name}")
+        with open(file_name) as fp:
+            self.template = json.load(fp)
+        if self._verbose:
+            print(
+                f"Using prompt template {template_name}: {self.template['description']}"
+            )
+
+    def generate_prompt(
+        self,
+        instruction: str,
+        input: Union[None, str] = None,
+        label: Union[None, str] = None,
+    ) -> str:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = self.template["prompt_input"].format(
+                instruction=instruction, input=input
+            )
+        else:
+            res = self.template["prompt_no_input"].format(
+                instruction=instruction
+            )
+        if label:
+            res = f"{res}{label}"
+        if self._verbose:
+            print(res)
+        return res
+
+    def get_response(self, output: str) -> str:
+        return output.split(self.template["response_split"])[1].strip()