
Add HF dataset loading, add linters, pyproject.toml (#175)

* add HF dataset loading, add linters, pyproject.toml

- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md

* restore default settings

* resume_from_checkpoint

Co-authored-by: AngainorDev <[email protected]>

* Print warning on checkpoint not found

* add HF dataset loading, add linters, pyproject.toml

- applied markdownlint
- add black, black[jupyter], isort
- fix noqa codes
- add .github workflow linting
- update README.md

* Default to local copy and update it

* Typo

* Remove duplicate code block

---------

Co-authored-by: Eric Wang <[email protected]>
Co-authored-by: AngainorDev <[email protected]>
claysauruswrecks, 3 years ago
commit 1310547f9f
10 changed files with 294 additions and 284 deletions
  1. .github/workflows/lint.yml  +33 -0
  2. README.md  +39 -37
  3. alpaca_data_cleaned.json  +108 -188
  4. export_hf_checkpoint.py  +13 -8
  5. export_state_dict_checkpoint.py  +20 -11
  6. finetune.py  +37 -16
  7. generate.py  +24 -16
  8. lengths.ipynb  +6 -2
  9. pyproject.toml  +8 -0
  10. requirements.txt  +6 -6

+ 33 - 0
.github/workflows/lint.yml

@@ -0,0 +1,33 @@
+name: Lint
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches: [main]
+
+jobs:
+  run-linters:
+    name: Run linters
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Check out Git repository
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.8
+
+      - name: Install Python dependencies
+        run: pip install black black[jupyter] flake8 isort
+
+      - name: lint isort
+        run: isort --check --diff .
+
+      - name: lint black
+        run: black --check --diff .

+ 39 - 37
README.md

@@ -1,4 +1,4 @@
-## 🦙🌲🤏 Alpaca-LoRA: Low-Rank LLaMA Instruct-Tuning
+# 🦙🌲🤏 Alpaca-LoRA: Low-Rank LLaMA Instruct-Tuning
 
 - 🤗 **Try the pretrained model out [here](https://huggingface.co/spaces/tloen/alpaca-lora), courtesy of a GPU grant from Huggingface!**
 - Users have created a Discord server for discussion and support [here](https://discord.gg/prbq284xX5)
@@ -15,15 +15,27 @@ as well as Tim Dettmers' [bitsandbytes](https://github.com/TimDettmers/bitsandby
 
 Without hyperparameter tuning, the LoRA model produces outputs comparable to the Stanford Alpaca model. (Please see the outputs included below.) Further tuning might be able to achieve better performance; I invite interested users to give it a try and report their results.
 
-### Setup
+## Setup
 
 1. Install dependencies
 
-```
-pip install -r requirements.txt
-```
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+1. Set environment variables, or modify the files referencing `BASE_MODEL`:
+
+    ```bash
+    # Files referencing `BASE_MODEL`
+    # export_hf_checkpoint.py
+    # export_state_dict_checkpoint.py
+
+    export BASE_MODEL=decapoda-research/llama-7b-hf
+    ```
 
-2. If bitsandbytes doesn't work, [install it from source.](https://github.com/TimDettmers/bitsandbytes/blob/main/compile_from_source.md) Windows users can follow [these instructions](https://github.com/tloen/alpaca-lora/issues/17).
+    Both `finetune.py` and `generate.py` use the `--base_model` flag, as shown further below.
+
+1. If bitsandbytes doesn't work, [install it from source.](https://github.com/TimDettmers/bitsandbytes/blob/main/compile_from_source.md) Windows users can follow [these instructions](https://github.com/tloen/alpaca-lora/issues/17).
 
 ### Training (`finetune.py`)
 
@@ -36,15 +48,16 @@ Example usage:
 ```bash
 python finetune.py \
     --base_model 'decapoda-research/llama-7b-hf' \
-    --data_path './alpaca_data_cleaned.json' \
+    --data_path 'yahma/alpaca-cleaned' \
     --output_dir './lora-alpaca'
 ```
 
 We can also tweak our hyperparameters:
+
 ```bash
 python finetune.py \
     --base_model 'decapoda-research/llama-7b-hf' \
-    --data_path './alpaca_data_cleaned.json' \
+    --data_path 'yahma/alpaca-cleaned' \
     --output_dir './lora-alpaca' \
     --batch_size 128 \
     --micro_batch_size 4 \
@@ -81,17 +94,6 @@ They should help users
 who want to run inference in projects like [llama.cpp](https://github.com/ggerganov/llama.cpp)
 or [alpaca.cpp](https://github.com/antimatter15/alpaca.cpp).
 
-### Dataset
-
-In addition to `alpaca_data.json`, which contains the original Stanford Alpaca dataset,
-we also include `alpaca_data_cleaned.json`, which has been [stripped of various tokenization artifacts](https://github.com/tloen/alpaca-lora/pull/32)
-with the help of @gururise.
-This file is now used by default in the training script.
-
-@AndriyMulyar has also provided interactive, embedding-based visualizations of the original dataset's [instructions](https://atlas.nomic.ai/map/alpaca_instructions)
-and [outputs](https://atlas.nomic.ai/map/alpaca_outputs),
-as well as [clusters of bad examples](https://atlas.nomic.ai/map/d2139cc3-bc1c-441c-8d6f-3e6ffbbc2eda/838019ff-8fe2-42ba-809a-d86d2b98cd50/-18.11668742841587/-11.348087116836096/-20.88850316347706/-17.680468640801223/774455612).
-
 ### Notes
 
 - We can likely improve our model performance significantly if we had a better dataset. Consider supporting the [LAION Open Assistant](https://open-assistant.io/) effort to produce a high-quality dataset for supervised fine-tuning (or bugging them to release their data).
@@ -105,26 +107,26 @@ as well as [clusters of bad examples](https://atlas.nomic.ai/map/d2139cc3-bc1c-4
 - [AlpacaDataCleaned](https://github.com/gururise/AlpacaDataCleaned), a project to improve the quality of the Alpaca dataset
 - Various adapter weights (download at own risk):
   - 7B:
-    - https://huggingface.co/tloen/alpaca-lora-7b
-    - https://huggingface.co/samwit/alpaca7B-lora
-    - 🇧🇷 https://huggingface.co/22h/cabrita-lora-v0-1
-    - 🇨🇳 https://huggingface.co/qychen/luotuo-lora-7b-0.1
-    - 🇯🇵 https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-7b-v0
-    - 🇫🇷 https://huggingface.co/bofenghuang/vigogne-lora-7b
-    - 🇹🇭 https://huggingface.co/Thaweewat/thai-buffala-lora-7b-v0-1
-    - 🇩🇪 https://huggingface.co/thisserand/alpaca_lora_german
-    - 🇮🇹 https://huggingface.co/teelinsan/camoscio-7b-llama
+    - <https://huggingface.co/tloen/alpaca-lora-7b>
+    - <https://huggingface.co/samwit/alpaca7B-lora>
+    - 🇧🇷 <https://huggingface.co/22h/cabrita-lora-v0-1>
+    - 🇨🇳 <https://huggingface.co/qychen/luotuo-lora-7b-0.1>
+    - 🇯🇵 <https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-7b-v0>
+    - 🇫🇷 <https://huggingface.co/bofenghuang/vigogne-lora-7b>
+    - 🇹🇭 <https://huggingface.co/Thaweewat/thai-buffala-lora-7b-v0-1>
+    - 🇩🇪 <https://huggingface.co/thisserand/alpaca_lora_german>
+    - 🇮🇹 <https://huggingface.co/teelinsan/camoscio-7b-llama>
   - 13B:
-    - https://huggingface.co/chansung/alpaca-lora-13b
-    - https://huggingface.co/mattreid/alpaca-lora-13b
-    - https://huggingface.co/samwit/alpaca13B-lora
-    - 🇯🇵 https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-13b-v0
-    - 🇰🇷 https://huggingface.co/chansung/koalpaca-lora-13b
-    - 🇨🇳 https://huggingface.co/facat/alpaca-lora-cn-13b
+    - <https://huggingface.co/chansung/alpaca-lora-13b>
+    - <https://huggingface.co/mattreid/alpaca-lora-13b>
+    - <https://huggingface.co/samwit/alpaca13B-lora>
+    - 🇯🇵 <https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-13b-v0>
+    - 🇰🇷 <https://huggingface.co/chansung/koalpaca-lora-13b>
+    - 🇨🇳 <https://huggingface.co/facat/alpaca-lora-cn-13b>
   - 30B:
-    - https://huggingface.co/baseten/alpaca-30b
-    - https://huggingface.co/chansung/alpaca-lora-30b
-    - 🇯🇵 https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-30b-v0
+    - <https://huggingface.co/baseten/alpaca-30b>
+    - <https://huggingface.co/chansung/alpaca-lora-30b>
+    - 🇯🇵 <https://huggingface.co/kunishou/Japanese-Alapaca-LoRA-30b-v0>
 - [alpaca-native](https://huggingface.co/chavinlo/alpaca-native), a replication using the original Alpaca code
 
 ### Example outputs

File diff suppressed because it is too large
+ 108 - 188
alpaca_data_cleaned.json


+ 13 - 8
export_hf_checkpoint.py

@@ -1,20 +1,23 @@
 import os
-import json
 
 import torch
-from peft import PeftModel, LoraConfig
-
 import transformers
+from peft import PeftModel
+
+# Unused imports
+# import json
+# from peft import LoraConfig
 
 assert (
     "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaTokenizer, LlamaForCausalLM
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
+
+from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 
-BASE_MODEL = None
+BASE_MODEL = os.environ.get("BASE_MODEL", None)
 assert (
     BASE_MODEL
-), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
+), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
 
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
@@ -35,7 +38,9 @@ lora_model = PeftModel.from_pretrained(
     torch_dtype=torch.float16,
 )
 
-lora_weight = lora_model.base_model.model.model.layers[0].self_attn.q_proj.weight
+lora_weight = lora_model.base_model.model.model.layers[
+    0
+].self_attn.q_proj.weight
 
 assert torch.allclose(first_weight_old, first_weight)
 

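For reference, both export scripts now take the base model name from the environment instead of a hard-coded constant. A minimal sketch of that pattern (the model name in the message is only an example):

```python
import os

# Read the base model name from the environment and fail early with a clear
# message if it is missing, mirroring the assert used in the export scripts.
BASE_MODEL = os.environ.get("BASE_MODEL", None)
assert BASE_MODEL, (
    "Please specify a value for the BASE_MODEL environment variable, "
    "e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"
)
print(f"Exporting checkpoints for base model: {BASE_MODEL}")
```
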
+ 20 - 11
export_state_dict_checkpoint.py

@@ -1,20 +1,23 @@
-import os
 import json
+import os
 
 import torch
-from peft import PeftModel, LoraConfig
-
 import transformers
 
+# Unused imports
+# from peft import LoraConfig
+from peft import PeftModel
+
 assert (
     "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaTokenizer, LlamaForCausalLM
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
+
+from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 
-BASE_MODEL = None
+BASE_MODEL = os.environ.get("BASE_MODEL", None)
 assert (
     BASE_MODEL
-), "Please specify a BASE_MODEL in the script, e.g. 'decapoda-research/llama-7b-hf'"
+), "Please specify a value for BASE_MODEL environment variable, e.g. `export BASE_MODEL=decapoda-research/llama-7b-hf`"  # noqa: E501
 
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
@@ -54,22 +57,28 @@ n_heads = params["n_heads"]
 dim = params["dim"]
 dims_per_head = dim // n_heads
 base = 10000.0
-inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+inv_freq = 1.0 / (
+    base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
+)
 
 
 def permute(w):
     return (
-        w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+        w.view(n_heads, dim // n_heads // 2, 2, dim)
+        .transpose(1, 2)
+        .reshape(dim, dim)
     )
 
 
 def unpermute(w):
     return (
-        w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim)
+        w.view(n_heads, 2, dim // n_heads // 2, dim)
+        .transpose(1, 2)
+        .reshape(dim, dim)
     )
 
 
-def translate_state_dict_key(k):
+def translate_state_dict_key(k):  # noqa: C901
     k = k.replace("base_model.model.", "")
     if k == "model.embed_tokens.weight":
         return "tok_embeddings.weight"

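The `permute`/`unpermute` pair above is only re-wrapped for the new 79-character line length; functionally they remain mutual inverses that re-order the rows of a square attention projection weight. A small self-check sketch (sizes chosen arbitrarily for illustration):

```python
import torch

n_heads, dim = 4, 16  # illustrative sizes; LLaMA-7B uses 32 heads and dim 4096


def permute(w):
    # Re-order the rows of a (dim, dim) projection weight.
    return (
        w.view(n_heads, dim // n_heads // 2, 2, dim)
        .transpose(1, 2)
        .reshape(dim, dim)
    )


def unpermute(w):
    # Undo the re-ordering performed by permute().
    return (
        w.view(n_heads, 2, dim // n_heads // 2, dim)
        .transpose(1, 2)
        .reshape(dim, dim)
    )


w = torch.randn(dim, dim)
assert torch.equal(unpermute(permute(w)), w)  # the two transforms are inverses
```
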
+ 37 - 16
finetune.py

@@ -4,22 +4,28 @@ from typing import List
 
 import fire
 import torch
+import transformers
+from datasets import load_dataset
+
+"""
+Unused imports:
 import torch.nn as nn
 import bitsandbytes as bnb
-from datasets import load_dataset
-import transformers
+"""
 
+# Catch when user should re-install transformers library
 assert (
     "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaForCausalLM, LlamaTokenizer
-from peft import (
-    prepare_model_for_int8_training,
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
+
+from peft import (  # noqa: E402
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
+    prepare_model_for_int8_training,
     set_peft_model_state_dict,
 )
+from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 
 
 def train(
@@ -44,7 +50,7 @@ def train(
     ],
     # llm hyperparams
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
-    group_by_length: bool = False,  # faster, but produces an odd training loss curve,
+    group_by_length: bool = False,  # faster, but produces an odd training loss curve
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
@@ -86,7 +92,9 @@ def train(
 
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
 
-    tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
+    tokenizer.pad_token_id = (
+        0  # unk. we want this to be different from the eos token
+    )
     tokenizer.padding_side = "left"  # Allow batched inference
 
     def tokenize(prompt, add_eos_token=True):
@@ -138,7 +146,10 @@ def train(
     )
     model = get_peft_model(model, config)
 
-    data = load_dataset("json", data_files=data_path)
+    if data_path.endswith(".json"):  # todo: support jsonl
+        data = load_dataset("json", data_files=data_path)
+    else:
+        data = load_dataset(data_path)
 
     if resume_from_checkpoint:
         # Check the available weights and load them
@@ -149,7 +160,9 @@ def train(
             checkpoint_name = os.path.join(
                 resume_from_checkpoint, "adapter_model.bin"
             )  # only LoRA model - LoRA config above has to fit
-            resume_from_checkpoint = False  # So the trainer won't try loading its state
+            resume_from_checkpoint = (
+                False  # So the trainer won't try loading its state
+            )
         # The two files above have a different name depending on how they were saved, but are actually the same.
         if os.path.exists(checkpoint_name):
             print(f"Restarting from {checkpoint_name}")
@@ -164,8 +177,12 @@ def train(
         train_val = data["train"].train_test_split(
             test_size=val_set_size, shuffle=True, seed=42
         )
-        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
-        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        train_data = (
+            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
+        )
+        val_data = (
+            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
+        )
     else:
         train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
         val_data = None
@@ -201,7 +218,9 @@ def train(
 
     old_state_dict = model.state_dict
     model.state_dict = (
-        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
+        lambda self, *_, **__: get_peft_model_state_dict(
+            self, old_state_dict()
+        )
     ).__get__(model, type(model))
 
     if torch.__version__ >= "2" and sys.platform != "win32":
@@ -211,13 +230,15 @@ def train(
 
     model.save_pretrained(output_dir)
 
-    print("\n If there's a warning about missing keys above, please disregard :)")
+    print(
+        "\n If there's a warning about missing keys above, please disregard :)"
+    )
 
 
 def generate_prompt(data_point):
     # sorry about the formatting disaster gotta move fast
     if data_point["input"]:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
 
 ### Instruction:
 {data_point["instruction"]}
@@ -228,7 +249,7 @@ def generate_prompt(data_point):
 ### Response:
 {data_point["output"]}"""
     else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
 
 ### Instruction:
 {data_point["instruction"]}

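The dataset-loading hunk above is the core of this change: `--data_path` can now be either a local JSON file or a dataset name on the Hugging Face Hub. A minimal sketch of that behaviour (dataset names shown only as examples):

```python
from datasets import load_dataset


def load_training_data(data_path: str):
    # Local .json files keep the previous behaviour; anything else is treated
    # as a Hugging Face Hub dataset identifier (jsonl is still unsupported).
    if data_path.endswith(".json"):
        return load_dataset("json", data_files=data_path)
    return load_dataset(data_path)


# load_training_data("./alpaca_data_cleaned.json")   # local file
# load_training_data("yahma/alpaca-cleaned")         # Hub dataset
```
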
+ 24 - 16
generate.py

@@ -1,15 +1,15 @@
 import sys
 
 import fire
+import gradio as gr
 import torch
-from peft import PeftModel
 import transformers
-import gradio as gr
+from peft import PeftModel
 
 assert (
     "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"
-from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
+), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
+from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
 if torch.cuda.is_available():
     device = "cuda"
@@ -19,7 +19,7 @@ else:
 try:
     if torch.backends.mps.is_available():
         device = "mps"
-except:
+except:  # noqa: E722
     pass
 
 
@@ -28,9 +28,9 @@ def main(
     base_model: str = "",
     lora_weights: str = "tloen/alpaca-lora-7b",
 ):
-    assert base_model, (
-        "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
-    )
+    assert (
+        base_model
+    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
 
     tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
@@ -115,15 +115,23 @@ def main(
         fn=evaluate,
         inputs=[
             gr.components.Textbox(
-                lines=2, label="Instruction", placeholder="Tell me about alpacas."
+                lines=2,
+                label="Instruction",
+                placeholder="Tell me about alpacas.",
             ),
             gr.components.Textbox(lines=2, label="Input", placeholder="none"),
-            gr.components.Slider(minimum=0, maximum=1, value=0.1, label="Temperature"),
-            gr.components.Slider(minimum=0, maximum=1, value=0.75, label="Top p"),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=0.1, label="Temperature"
+            ),
+            gr.components.Slider(
+                minimum=0, maximum=1, value=0.75, label="Top p"
+            ),
             gr.components.Slider(
                 minimum=0, maximum=100, step=1, value=40, label="Top k"
             ),
-            gr.components.Slider(minimum=1, maximum=4, step=1, value=4, label="Beams"),
+            gr.components.Slider(
+                minimum=1, maximum=4, step=1, value=4, label="Beams"
+            ),
             gr.components.Slider(
                 minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
             ),
@@ -135,7 +143,7 @@ def main(
             )
         ],
         title="🦙🌲 Alpaca-LoRA",
-        description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",
+        description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",  # noqa: E501
     ).launch()
     # Old testing code follows.
 
@@ -147,7 +155,7 @@ def main(
         "Tell me about the king of France in 2019.",
         "List all Canadian provinces in alphabetical order.",
         "Write a Python program that prints the first 10 Fibonacci numbers.",
-        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
+        "Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",  # noqa: E501
         "Tell me five words that rhyme with 'shock'.",
         "Translate the sentence 'I have no mouth but I must scream' into Spanish.",
         "Count up from 1 to 500.",
@@ -160,7 +168,7 @@ def main(
 
 def generate_prompt(instruction, input=None):
     if input:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
 
 ### Instruction:
 {instruction}
@@ -171,7 +179,7 @@ def generate_prompt(instruction, input=None):
 ### Response:
 """
     else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501
 
 ### Instruction:
 {instruction}

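For context, the bare `except:` silenced with `# noqa: E722` above belongs to the script's device-selection fallback. A sketch of the same logic with a narrower exception (an editorial variant, not the committed code):

```python
import torch

# Prefer CUDA, then Apple's MPS backend, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # Older torch builds do not expose torch.backends.mps at all.
    pass

print(f"Selected device: {device}")
```
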
+ 6 - 2
lengths.ipynb

@@ -22,7 +22,9 @@
     "from transformers import LlamaTokenizer\n",
     "\n",
     "\n",
-    "tokenizer = LlamaTokenizer.from_pretrained(\"decapoda-research/llama-7b-hf\", add_eos_token=True)\n",
+    "tokenizer = LlamaTokenizer.from_pretrained(\n",
+    "    \"decapoda-research/llama-7b-hf\", add_eos_token=True\n",
+    ")\n",
     "tokenizer.pad_token = tokenizer.eos_token\n",
     "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
     "\n",
@@ -52,7 +54,9 @@
     "{data_point[\"output\"]}\"\"\"\n",
     "\n",
     "\n",
-    "data = data.map(lambda data_point: {\"prompt\": tokenizer(generate_prompt(data_point))})"
+    "data = data.map(\n",
+    "    lambda data_point: {\"prompt\": tokenizer(generate_prompt(data_point))}\n",
+    ")"
    ]
   },
   {

+ 8 - 0
pyproject.toml

@@ -0,0 +1,8 @@
+[tool.black]
+line-length = 79
+
+[tool.isort]
+include_trailing_comma = true
+line_length = 79
+multi_line_output = 3
+profile = "black"

+ 6 - 6
requirements.txt

@@ -1,10 +1,10 @@
-datasets
-loralib
-sentencepiece
-git+https://github.com/huggingface/transformers.git
 accelerate
+appdirs
 bitsandbytes
+black
+black[jupyter]
+datasets
+fire
 git+https://github.com/huggingface/peft.git
+git+https://github.com/huggingface/transformers.git
 gradio
-appdirs
-fire
