Эх сурвалжийг харах

Use CLI arguments (#159)

* CLI args for finetune

* Update README

* CLI args for generate.py

* reqs.txt

* reorder hyperparams

* lora_target_modules

* cleanup
Eric J. Wang 3 жил өмнө
parent
commit
5fa807d106
4 өөрчлөгдсөн 348 нэмэгдсэн , 279 устгасан
  1. 42 6
      README.md
  2. 166 140
      finetune.py
  3. 139 133
      generate.py
  4. 1 0
      requirements.txt

+ 42 - 6
README.md

@@ -2,7 +2,6 @@
 
 - 🤗 **Try the pretrained model out [here](https://huggingface.co/spaces/tloen/alpaca-lora), courtesy of a GPU grant from Huggingface!**
 - Users have created a Discord server for discussion and support [here](https://discord.gg/prbq284xX5)
-- **This repository does not contain code for hosting and/or facilitating the downloading and/or streaming of the LLaMA weights. You will have to specify your own HuggingFace Hub base model to run the code, such as `decapoda-research/llama-7b-hf`.**
 
 This repository contains code for reproducing the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) results using [low-rank adaptation (LoRA)](https://arxiv.org/pdf/2106.09685.pdf).
 We provide an Instruct model of similar quality to `text-davinci-003` that can run [on a Raspberry Pi](https://twitter.com/miolini/status/1634982361757790209) (for research),
@@ -26,17 +25,54 @@ pip install -r requirements.txt
 
 2. If bitsandbytes doesn't work, [install it from source.](https://github.com/TimDettmers/bitsandbytes/blob/main/compile_from_source.md) Windows users can follow [these instructions](https://github.com/tloen/alpaca-lora/issues/17).
 
-### Inference (`generate.py`)
-
-This file reads the foundation model from the Hugging Face model hub and the LoRA weights from `tloen/alpaca-lora-7b`, and runs a Gradio interface for inference on a specified input. Users should treat this as example code for the use of the model, and modify it as needed.
-
 ### Training (`finetune.py`)
 
 This file contains a straightforward application of PEFT to the LLaMA model,
 as well as some code related to prompt construction and tokenization.
-Near the top of this file is a set of hardcoded hyperparameters that you should feel free to modify.
 PRs adapting this code to support larger models are always welcome.
 
+Example usage:
+
+```bash
+python finetune.py \
+    --base_model 'decapoda-research/llama-7b-hf' \
+    --data_path './alpaca_data_cleaned.json' \
+    --output_dir './lora-alpaca'
+```
+
+We can also tweak our hyperparameters:
+```bash
+python finetune.py \
+    --base_model 'decapoda-research/llama-7b-hf' \
+    --data_path 'alpaca_data_cleaned.json' \
+    --output_dir './lora-alpaca' \
+    --batch_size 128 \
+    --micro_batch_size 4 \
+    --num_epochs 3 \
+    --learning_rate 1e-4 \
+    --cutoff_len 512 \
+    --val_set_size 2000 \
+    --lora_r 8 \
+    --lora_alpha 16 \
+    --lora_dropout 0.05 \
+    --lora_target_modules '[q_proj,v_proj]' \
+    --train_on_inputs \
+    --group_by_length
+```
+
+### Inference (`generate.py`)
+
+This file reads the foundation model from the Hugging Face model hub and the LoRA weights from `tloen/alpaca-lora-7b`, and runs a Gradio interface for inference on a specified input. Users should treat this as example code for the use of the model, and modify it as needed.
+
+Example usage:
+
+```bash
+python generate.py \
+    --load_8bit \
+    --base_model 'decapoda-research/llama-7b-hf' \
+    --lora_weights 'tloen/alpaca-lora-7b'
+```
+
 ### Checkpoint export (`export_*_checkpoint.py`)
 
 These files contain scripts that merge the LoRA weights back into the base model

+ 166 - 140
finetune.py

@@ -1,6 +1,8 @@
 import os
 import sys
+from typing import List
 
+import fire
 import torch
 import torch.nn as nn
 import bitsandbytes as bnb
@@ -19,61 +21,173 @@ from peft import (
 )
 
 
-# optimized for RTX 4090. for larger GPUs, increase some of these?
-MICRO_BATCH_SIZE = 4  # this could actually be 5 but i like powers of 2
-BATCH_SIZE = 128
-GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
-EPOCHS = 3
-LEARNING_RATE = 3e-4
-CUTOFF_LEN = 512
-LORA_R = 8
-LORA_ALPHA = 16
-LORA_DROPOUT = 0.05
-VAL_SET_SIZE = 2000
-TARGET_MODULES = [
-    "q_proj",
-    "v_proj",
-]
-DATA_PATH = "alpaca_data_cleaned.json"
-OUTPUT_DIR = "lora-alpaca"
-BASE_MODEL = None
-assert (
-    BASE_MODEL
-), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
-TRAIN_ON_INPUTS = True
-GROUP_BY_LENGTH = True  # faster, but produces an odd training loss curve
-
-device_map = "auto"
-world_size = int(os.environ.get("WORLD_SIZE", 1))
-ddp = world_size != 1
-if ddp:
-    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
-    GRADIENT_ACCUMULATION_STEPS = GRADIENT_ACCUMULATION_STEPS // world_size
-
-model = LlamaForCausalLM.from_pretrained(
-    BASE_MODEL,
-    load_in_8bit=True,
-    device_map=device_map,
-)
+def train(
+    # model/data params
+    base_model: str = "",  # the only required argument
+    data_path: str = "./alpaca_data_cleaned.json",
+    output_dir: str = "./lora-alpaca",
+    # training hyperparams
+    batch_size: int = 128,
+    micro_batch_size: int = 4,
+    num_epochs: int = 3,
+    learning_rate: float = 3e-4,
+    cutoff_len: int = 512,
+    val_set_size: int = 2000,
+    # lora hyperparams
+    lora_r: int = 8,
+    lora_alpha: int = 16,
+    lora_dropout: float = 0.05,
+    lora_target_modules: List[str] = [
+        "q_proj",
+        "v_proj",
+    ],
+    # llm hyperparams
+    train_on_inputs: bool = True,  # if False, masks out inputs in loss
+    group_by_length: bool = True,  # faster, but produces an odd training loss curve
+):
+    print(
+        f"Training Alpaca-LoRA model with params:\n"
+        f"base_model: {base_model}\n"
+        f"data_path: {data_path}\n"
+        f"output_dir: {output_dir}\n"
+        f"batch_size: {batch_size}\n"
+        f"micro_batch_size: {micro_batch_size}\n"
+        f"num_epochs: {num_epochs}\n"
+        f"learning_rate: {learning_rate}\n"
+        f"cutoff_len: {cutoff_len}\n"
+        f"val_set_size: {val_set_size}\n"
+        f"lora_r: {lora_r}\n"
+        f"lora_alpha: {lora_alpha}\n"
+        f"lora_dropout: {lora_dropout}\n"
+        f"lora_target_modules: {lora_target_modules}\n"
+        f"train_on_inputs: {train_on_inputs}\n"
+        f"group_by_length: {group_by_length}\n"
+    )
+    assert (
+        base_model
+    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
+    gradient_accumulation_steps = batch_size // micro_batch_size
+
+    device_map = "auto"
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    ddp = world_size != 1
+    if ddp:
+        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+        gradient_accumulation_steps = gradient_accumulation_steps // world_size
+
+    model = LlamaForCausalLM.from_pretrained(
+        base_model,
+        load_in_8bit=True,
+        device_map=device_map,
+    )
+
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+
+    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
+    tokenizer.padding_side = "left"  # Allow batched inference
+
+    def tokenize(prompt, add_eos_token=True):
+        # there's probably a way to do this with the tokenizer settings
+        # but again, gotta move fast
+        result = tokenizer(
+            prompt,
+            truncation=True,
+            max_length=cutoff_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != tokenizer.eos_token_id
+            and len(result["input_ids"]) < cutoff_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        result["labels"] = result["input_ids"].copy()
+
+        return result
+
+    def generate_and_tokenize_prompt(data_point):
+        full_prompt = generate_prompt(data_point)
+        tokenized_full_prompt = tokenize(full_prompt)
+        if not train_on_inputs:
+            user_prompt = generate_prompt({**data_point, "output": ""})
+            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
+            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+
+            tokenized_full_prompt["labels"] = [
+                -100
+            ] * user_prompt_len + tokenized_full_prompt["labels"][
+                user_prompt_len:
+            ]  # could be sped up, probably
+        return tokenized_full_prompt
+
+    model = prepare_model_for_int8_training(model)
+
+    config = LoraConfig(
+        r=lora_r,
+        lora_alpha=lora_alpha,
+        target_modules=lora_target_modules,
+        lora_dropout=lora_dropout,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(model, config)
 
-tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
+    data = load_dataset("json", data_files=data_path)
 
-tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
-tokenizer.padding_side = "left"  # Allow batched inference
+    if val_set_size > 0:
+        train_val = data["train"].train_test_split(
+            test_size=val_set_size, shuffle=True, seed=42
+        )
+        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+    else:
+        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
+        val_data = None
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_data,
+        eval_dataset=val_data,
+        args=transformers.TrainingArguments(
+            per_device_train_batch_size=micro_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            warmup_steps=100,
+            num_train_epochs=num_epochs,
+            learning_rate=learning_rate,
+            fp16=True,
+            logging_steps=10,
+            evaluation_strategy="steps" if val_set_size > 0 else "no",
+            save_strategy="steps",
+            eval_steps=200 if val_set_size > 0 else None,
+            save_steps=200,
+            output_dir=output_dir,
+            save_total_limit=3,
+            load_best_model_at_end=True if val_set_size > 0 else False,
+            ddp_find_unused_parameters=False if ddp else None,
+            group_by_length=group_by_length,
+        ),
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+    )
+    model.config.use_cache = False
 
-model = prepare_model_for_int8_training(model)
+    old_state_dict = model.state_dict
+    model.state_dict = (
+        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
+    ).__get__(model, type(model))
 
-config = LoraConfig(
-    r=LORA_R,
-    lora_alpha=LORA_ALPHA,
-    target_modules=TARGET_MODULES,
-    lora_dropout=LORA_DROPOUT,
-    bias="none",
-    task_type="CAUSAL_LM",
-)
-model = get_peft_model(model, config)
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
+    trainer.train()
 
-data = load_dataset("json", data_files=DATA_PATH)
+    model.save_pretrained(output_dir)
+
+    print("\n If there's a warning about missing keys above, please disregard :)")
 
 
 def generate_prompt(data_point):
@@ -99,93 +213,5 @@ def generate_prompt(data_point):
 {data_point["output"]}"""
 
 
-def tokenize(prompt, add_eos_token=True):
-    # there's probably a way to do this with the tokenizer settings
-    # but again, gotta move fast
-    result = tokenizer(
-        prompt,
-        truncation=True,
-        max_length=CUTOFF_LEN,
-        padding=False,
-        return_tensors=None,
-    )
-    if (
-        result["input_ids"][-1] != tokenizer.eos_token_id
-        and len(result["input_ids"]) < CUTOFF_LEN
-        and add_eos_token
-    ):
-        result["input_ids"].append(tokenizer.eos_token_id)
-        result["attention_mask"].append(1)
-
-    result["labels"] = result["input_ids"].copy()
-
-    return result
-
-
-def generate_and_tokenize_prompt(data_point):
-    full_prompt = generate_prompt(data_point)
-    tokenized_full_prompt = tokenize(full_prompt)
-    if not TRAIN_ON_INPUTS:
-        user_prompt = generate_prompt({**data_point, "output": ""})
-        tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
-        user_prompt_len = len(tokenized_user_prompt["input_ids"])
-
-        tokenized_full_prompt["labels"] = [
-            -100
-        ] * user_prompt_len + tokenized_full_prompt["labels"][
-            user_prompt_len:
-        ]  # could be sped up, probably
-    return tokenized_full_prompt
-
-
-if VAL_SET_SIZE > 0:
-    train_val = data["train"].train_test_split(
-        test_size=VAL_SET_SIZE, shuffle=True, seed=42
-    )
-    train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
-    val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
-else:
-    train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
-    val_data = None
-
-trainer = transformers.Trainer(
-    model=model,
-    train_dataset=train_data,
-    eval_dataset=val_data,
-    args=transformers.TrainingArguments(
-        per_device_train_batch_size=MICRO_BATCH_SIZE,
-        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
-        warmup_steps=100,
-        num_train_epochs=EPOCHS,
-        learning_rate=LEARNING_RATE,
-        fp16=True,
-        logging_steps=10,
-        evaluation_strategy="steps" if VAL_SET_SIZE > 0 else "no",
-        save_strategy="steps",
-        eval_steps=200 if VAL_SET_SIZE > 0 else None,
-        save_steps=200,
-        output_dir=OUTPUT_DIR,
-        save_total_limit=3,
-        load_best_model_at_end=True if VAL_SET_SIZE > 0 else False,
-        ddp_find_unused_parameters=False if ddp else None,
-        group_by_length=GROUP_BY_LENGTH,
-    ),
-    data_collator=transformers.DataCollatorForSeq2Seq(
-        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
-    ),
-)
-model.config.use_cache = False
-
-old_state_dict = model.state_dict
-model.state_dict = (
-    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
-).__get__(model, type(model))
-
-if torch.__version__ >= "2" and sys.platform != "win32":
-    model = torch.compile(model)
-
-trainer.train()
-
-model.save_pretrained(OUTPUT_DIR)
-
-print("\n If there's a warning about missing keys above, please disregard :)")
+if __name__ == "__main__":
+    fire.Fire(train)

+ 139 - 133
generate.py

@@ -1,4 +1,6 @@
 import sys
+
+import fire
 import torch
 from peft import PeftModel
 import transformers
@@ -9,16 +11,6 @@ assert (
 ), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
 from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
 
-LOAD_8BIT = False
-BASE_MODEL = None
-LORA_WEIGHTS = "tloen/alpaca-lora-7b"
-
-tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
-
-assert (
-    BASE_MODEL
-), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
-
 if torch.cuda.is_available():
     device = "cuda"
 else:
@@ -30,44 +22,140 @@ try:
 except:
     pass
 
-if device == "cuda":
-    model = LlamaForCausalLM.from_pretrained(
-        BASE_MODEL,
-        load_in_8bit=LOAD_8BIT,
-        torch_dtype=torch.float16,
-        device_map="auto",
-    )
-    model = PeftModel.from_pretrained(
-        model,
-        LORA_WEIGHTS,
-        torch_dtype=torch.float16,
-    )
-elif device == "mps":
-    model = LlamaForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        torch_dtype=torch.float16,
-    )
-    model = PeftModel.from_pretrained(
-        model,
-        LORA_WEIGHTS,
-        device_map={"": device},
-        torch_dtype=torch.float16,
-    )
-else:
-    model = LlamaForCausalLM.from_pretrained(
-        BASE_MODEL, device_map={"": device}, low_cpu_mem_usage=True
-    )
-    model = PeftModel.from_pretrained(
-        model,
-        LORA_WEIGHTS,
-        device_map={"": device},
+
+def main(
+    load_8bit: bool = False,
+    base_model: str = "",
+    lora_weights: str = "tloen/alpaca-lora-7b",
+):
+    assert base_model, (
+        "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
     )
 
-# unwind broken decapoda-research config
-model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
-model.config.bos_token_id = 1
-model.config.eos_token_id = 2
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
+    if device == "cuda":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            load_in_8bit=load_8bit,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            torch_dtype=torch.float16,
+        )
+    elif device == "mps":
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            device_map={"": device},
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            base_model, device_map={"": device}, low_cpu_mem_usage=True
+        )
+        model = PeftModel.from_pretrained(
+            model,
+            lora_weights,
+            device_map={"": device},
+        )
+
+    # unwind broken decapoda-research config
+    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
+    model.config.bos_token_id = 1
+    model.config.eos_token_id = 2
+
+    if not load_8bit:
+        model.half()  # seems to fix bugs for some users.
+
+    model.eval()
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
+    def evaluate(
+        instruction,
+        input=None,
+        temperature=0.1,
+        top_p=0.75,
+        top_k=40,
+        num_beams=4,
+        max_new_tokens=128,
+        **kwargs,
+    ):
+        prompt = generate_prompt(instruction, input)
+        inputs = tokenizer(prompt, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            num_beams=num_beams,
+            **kwargs,
+        )
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_new_tokens,
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s)
+        return output.split("### Response:")[1].strip()
+
+    gr.Interface(
+        fn=evaluate,
+        inputs=[
+            gr.components.Textbox(
+                lines=2, label="Instruction", placeholder="Tell me about alpacas."
+            ),
+            gr.components.Textbox(lines=2, label="Input", placeholder="none"),
+            gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
+            gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
+            gr.components.Slider(
+                minimum=0, maximum=100, step=1, value=40, label="Top k"
+            ),
+            gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
+            gr.components.Slider(
+                minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
+            ),
+        ],
+        outputs=[
+            gr.inputs.Textbox(
+                lines=5,
+                label="Output",
+            )
+        ],
+        title="🦙🌲 Alpaca-LoRA",
+        description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
+    ).launch()
+    # Old testing code follows.
+
+    """
+    # testing code for readme
+    for instruction in [
+        "Tell me about alpacas.",
+        "Tell me about the president of Mexico in 2019.",
+        "Tell me about the king of France in 2019.",
+        "List all Canadian provinces in alphabetical order.",
+        "Write a Python program that prints the first 10 Fibonacci numbers.",
+        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
+        "Tell me five words that rhyme with 'shock'.",
+        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
+        "Count up from 1 to 500.",
+    ]:
+        print("Instruction:", instruction)
+        print("Response:", evaluate(instruction))
+        print()
+    """
 
 
 def generate_prompt(instruction, input=None):
@@ -80,99 +168,17 @@ def generate_prompt(instruction, input=None):
 ### Input:
 {input}
 
-### Response:"""
+### Response:
+"""
     else:
         return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
 
 ### Instruction:
 {instruction}
 
-### Response:"""
-
-
-if not LOAD_8BIT:
-    model.half()  # seems to fix bugs for some users.
-
-model.eval()
-if torch.__version__ >= "2" and sys.platform != "win32":
-    model = torch.compile(model)
-
-
-def evaluate(
-    instruction,
-    input=None,
-    temperature=0.1,
-    top_p=0.75,
-    top_k=40,
-    num_beams=4,
-    max_new_tokens=128,
-    **kwargs,
-):
-    prompt = generate_prompt(instruction, input)
-    inputs = tokenizer(prompt, return_tensors="pt")
-    input_ids = inputs["input_ids"].to(device)
-    generation_config = GenerationConfig(
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        num_beams=num_beams,
-        **kwargs,
-    )
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_new_tokens,
-        )
-    s = generation_output.sequences[0]
-    output = tokenizer.decode(s)
-    return output.split("### Response:")[1].strip()
-
-
-gr.Interface(
-    fn=evaluate,
-    inputs=[
-        gr.components.Textbox(
-            lines=2, label="Instruction", placeholder="Tell me about alpacas."
-        ),
-        gr.components.Textbox(lines=2, label="Input", placeholder="none"),
-        gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
-        gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
-        gr.components.Slider(minimum=0, maximum=100, step=1, value=40, label="Top k"),
-        gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
-        gr.components.Slider(
-            minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
-        ),
-    ],
-    outputs=[
-        gr.inputs.Textbox(
-            lines=5,
-            label="Output",
-        )
-    ],
-    title="🦙🌲 Alpaca-LoRA",
-    description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
-).launch()
+### Response:
+"""
 
-# Old testing code follows.
 
-"""
 if __name__ == "__main__":
-    # testing code for readme
-    for instruction in [
-        "Tell me about alpacas.",
-        "Tell me about the president of Mexico in 2019.",
-        "Tell me about the king of France in 2019.",
-        "List all Canadian provinces in alphabetical order.",
-        "Write a Python program that prints the first 10 Fibonacci numbers.",
-        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
-        "Tell me five words that rhyme with 'shock'.",
-        "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
-        "Count up from 1 to 500.",
-    ]:
-        print("Instruction:", instruction)
-        print("Response:", evaluate(instruction))
-        print()
-"""
+    fire.Fire(main)

+ 1 - 0
requirements.txt

@@ -7,3 +7,4 @@ bitsandbytes
 git+https://github.com/huggingface/peft.git
 gradio
 appdirs
+fire