소스 검색

Support streaming output on generate (#263)

Pokai Chang 3 년 전
부모
커밋
e2ed209d3b
3개의 변경된 파일128개의 추가작업 그리고 3개의 파일을 삭제
  1. 46 2
      generate.py
  2. 7 1
      utils/README.md
  3. 75 0
      utils/callbacks.py

+ 46 - 2
generate.py

@@ -8,6 +8,7 @@ import transformers
 from peft import PeftModel
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
+from utils.callbacks import Iteratorize, Stream
 from utils.prompter import Prompter
 
 if torch.cuda.is_available():
@@ -91,6 +92,7 @@ def main(
         top_k=40,
         num_beams=4,
         max_new_tokens=128,
+        stream_output=False,
         **kwargs,
     ):
         prompt = prompter.generate_prompt(instruction, input)
@@ -103,6 +105,47 @@ def main(
             num_beams=num_beams,
             **kwargs,
         )
+
+        generate_params = {
+            "input_ids": input_ids,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "output_scores": True,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        if stream_output:
+            # Stream the reply 1 token at a time.
+            # This is based on the trick of using 'stopping_criteria' to create an iterator,
+            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs.setdefault(
+                    "stopping_criteria", transformers.StoppingCriteriaList()
+                )
+                kwargs["stopping_criteria"].append(
+                    Stream(callback_func=callback)
+                )
+                with torch.no_grad():
+                    model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(
+                    generate_with_callback, kwargs, callback=None
+                )
+
+            with generate_with_streaming(**generate_params) as generator:
+                for output in generator:
+                    # new_tokens = len(output) - len(input_ids[0])
+                    decoded_output = tokenizer.decode(output)
+
+                    if output[-1] in [tokenizer.eos_token_id]:
+                        break
+
+                    yield prompter.get_response(decoded_output)
+            return  # early return for stream_output
+
+        # Without streaming
         with torch.no_grad():
             generation_output = model.generate(
                 input_ids=input_ids,
@@ -113,7 +156,7 @@ def main(
             )
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
-        return prompter.get_response(output)
+        yield prompter.get_response(output)
 
     gr.Interface(
         fn=evaluate,
@@ -139,6 +182,7 @@ def main(
             gr.components.Slider(
                 minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
             ),
+            gr.components.Checkbox(label="Stream output"),
         ],
         outputs=[
             gr.inputs.Textbox(
@@ -148,7 +192,7 @@ def main(
         ],
         title="🦙🌲 Alpaca-LoRA",
         description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",  # noqa: E501
-    ).launch(server_name="0.0.0.0", share=share_gradio)
+    ).queue().launch(server_name="0.0.0.0", share=share_gradio)
     # Old testing code follows.
 
     """

+ 7 - 1
utils/README.md

@@ -4,4 +4,10 @@
 
 Prompter class, a template manager.
 
-`from utils.prompter import Prompter`
+`from utils.prompter import Prompter`
+
+## callbacks.py
+
+Helpers to support streaming generate output.
+
+`from utils.callbacks import Iteratorize, Stream`

+ 75 - 0
utils/callbacks.py

@@ -0,0 +1,75 @@
+"""
+Helpers to support streaming generate output.
+Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
+"""
+
+import gc
+import traceback
+from queue import Queue
+from threading import Thread
+
+import torch
+import transformers
+
+
+class Stream(transformers.StoppingCriteria):
+    def __init__(self, callback_func=None):
+        self.callback_func = callback_func
+
+    def __call__(self, input_ids, scores) -> bool:
+        if self.callback_func is not None:
+            self.callback_func(input_ids[0])
+        return False
+
+
+class Iteratorize:
+
+    """
+    Transforms a function that takes a callback
+    into a lazy iterator (generator).
+    """
+
+    def __init__(self, func, kwargs={}, callback=None):
+        self.mfunc = func
+        self.c_callback = callback
+        self.q = Queue()
+        self.sentinel = object()
+        self.kwargs = kwargs
+        self.stop_now = False
+
+        def _callback(val):
+            if self.stop_now:
+                raise ValueError
+            self.q.put(val)
+
+        def gentask():
+            try:
+                ret = self.mfunc(callback=_callback, **self.kwargs)
+            except ValueError:
+                pass
+            except:
+                traceback.print_exc()
+                pass
+
+            self.q.put(self.sentinel)
+            if self.c_callback:
+                self.c_callback(ret)
+
+        self.thread = Thread(target=gentask)
+        self.thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        obj = self.q.get(True, None)
+        if obj is self.sentinel:
+            raise StopIteration
+        else:
+            return obj
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.stop_now = True