Streaming LLM example for both GPU and CPU (#9390)

Guoqiong Song 2024-02-27 15:54:47 -08:00 committed by GitHub
parent c581c6db30
commit f4a2e32106
5 changed files with 221 additions and 8 deletions


@@ -44,9 +44,12 @@
import warnings
import torch
import argparse
import os
from streaming_llm.utils import load, download_url, load_jsonl
from streaming_llm.enable_streaming_llm import enable_streaming_llm
import os, sys
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
stream_llm_src = CURRENT_DIR + "/streaming_llm/"
sys.path.append(stream_llm_src)
from utils import load, download_url, load_jsonl
from enable_streaming_llm import enable_streaming_llm
warnings.filterwarnings("ignore")
@@ -99,7 +102,6 @@ def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=10
prompt = "USER: " + prompt + "\n\nASSISTANT: "
print("\n" + prompt, end="")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
seq_len = input_ids.shape[1]
if kv_cache is not None:
space_needed = seq_len + max_gen_len
@@ -107,11 +109,13 @@ def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=10
past_key_values = greedy_generate(
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
)
def main(args):
model, tokenizer = load(args.repo_id_or_model_path)
test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
print(f"Loading data from {test_filepath} ...")


@@ -41,13 +41,13 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from .kv_cache import StartRecentKVCache
from kv_cache import StartRecentKVCache
def enable_streaming_llm(model, start_size, recent_size):
if "llama" in model.config.model_type:
k_seq_dim = v_seq_dim = 2
from .modify_llama import (
from modify_llama import (
enable_llama_pos_shift_attention,
)
@@ -57,7 +57,7 @@ def enable_streaming_llm(model, start_size, recent_size):
k_seq_dim = 3
elif "gpt_neox" in model.config.model_type:
k_seq_dim = v_seq_dim = 2
from .modify_gpt_neox import (
from modify_gpt_neox import (
enable_gpt_neox_pos_shift_attention,
)
@@ -65,7 +65,7 @@ def enable_streaming_llm(model, start_size, recent_size):
elif "falcon" in model.config.model_type:
v_seq_dim = 1
k_seq_dim = 1
from .modify_falcon import (
from modify_falcon import (
enable_falcon_pos_shift_attention,
)


@@ -72,6 +72,7 @@ def llama_pos_shift_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,


@@ -0,0 +1,45 @@
# Low-Bit Streaming LLM using BigDL-LLM
In this example, we apply low-bit optimizations to [Streaming-LLM](https://github.com/mit-han-lab/streaming-llm/tree/main#efficient-streaming-language-models-with-attention-sinks) using BigDL-LLM, which makes it possible to deploy low-bit (including FP4/INT4/FP8/INT8) LLMs on infinite-length inputs.
Only one code change is needed to load the model using bigdl-llm as follows:
```python
from bigdl.llm.transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_4bit=True, trust_remote_code=True, optimize_model=False)
```
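For reference, here is a minimal sketch of how such a low-bit model is then prepared for streaming generation in this example. It mirrors `run_streaming_llama.py` below, but loads the tokenizer directly for illustration (the script itself uses the `load` helper from the vendored `streaming_llm` sources):
```python
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # default model of this example

# The one-line change: load the model through bigdl-llm with INT4 optimizations.
model = AutoModelForCausalLM.from_pretrained(
    model_path, load_in_4bit=True, trust_remote_code=True, optimize_model=False
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# GPU variant: move the low-bit model to Intel GPU; the CPU variant keeps it on CPU.
model = model.to("xpu")
```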
## Prepare Environment
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.9
conda activate llm
pip install -U transformers==4.34.0
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
```
## Configure OneAPI Environment Variables
```bash
source /opt/intel/oneapi/setvars.sh
```
## Run Example
```bash
python ./run_streaming_llama.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --enable-streaming
```
Arguments info:
- `--repo-id-or-model-path`: str value, argument defining the huggingface repo id for the large language model to be downloaded, or the path to the huggingface checkpoint folder. The value is 'meta-llama/Llama-2-7b-chat-hf' by default.
- `--data-root`: str value, the directory where the downloaded question data is saved.
- `--enable-streaming`: flag to enable the streaming KV cache (attention-sink) optimization during generation; without it, the full KV cache is kept.
- `--start-size`: int value, the number of initial tokens (attention sinks) kept at the start of the KV cache. It is 4 by default.
- `--recent-size`: int value, the number of recent tokens kept in the KV cache. It is 2000 by default. See the sketch below for how the two sizes interact.
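To make the `--start-size` / `--recent-size` semantics concrete, below is a simplified sketch of the eviction rule behind `--enable-streaming` (Streaming-LLM's `StartRecentKVCache`, vendored under `streaming_llm/`). It works on a plain list of cached token positions rather than the real `past_key_values` tensors, so the helper here is illustrative only:
```python
def evict_for_space(cached, space_needed, start_size=4, recent_size=2000):
    # Keep `start_size` "attention sink" entries at the front plus the most
    # recent entries, dropping the middle, so that `space_needed` upcoming
    # tokens still fit within a total budget of start_size + recent_size.
    cache_size = start_size + recent_size
    seq_len = len(cached)
    if seq_len + space_needed <= cache_size:
        return cached  # everything still fits, nothing to evict
    return cached[:start_size] + cached[seq_len - recent_size + space_needed:]

cache = list(range(3000))                    # positions currently held in the cache
cache = evict_for_space(cache, space_needed=100)
print(len(cache), cache[:4], cache[-2:])     # 1904 [0, 1, 2, 3] [2998, 2999]
```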
## Sample Output for Inference
### 'decapoda-research/llama-7b-hf' Model
```log
USER: Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.
ASSISTANT: Dear Mr. Smith,
I am writing to seek your feedback on the 'Quarterly Financial Report' I prepared for the company. I have attached the report for your reference.
The report contains data analysis of the company's performance during the quarter ending 31st March 2019...
```


@@ -0,0 +1,163 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/mit-han-lab/streaming-llm/blob/main/examples/run_streaming_llama.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import warnings
import torch
import argparse
import os
import sys

# The streaming_llm helper modules (utils, enable_streaming_llm, ...) are shipped
# with the matching CPU example, so point sys.path at that directory.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
stream_llm_src = CURRENT_DIR.replace("GPU", "CPU") + "/streaming_llm/"
sys.path.append(stream_llm_src)
from utils import load, download_url, load_jsonl
from enable_streaming_llm import enable_streaming_llm
warnings.filterwarnings("ignore")


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    # Greedy decoding that reuses the given KV cache and streams the decoded
    # text to stdout word by word as tokens are generated.
outputs = model(
input_ids=input_ids,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids = [pred_token_idx.item()]
pos = 0
for _ in range(max_gen_len - 1):
outputs = model(
input_ids=pred_token_idx,
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids.append(pred_token_idx.item())
generated_text = (
tokenizer.decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
spaces_between_special_tokens=False,
)
.strip()
.split(" ")
)
now = len(generated_text) - 1
if now > pos:
print(" ".join(generated_text[pos:now]), end=" ", flush=True)
pos = now
if pred_token_idx == tokenizer.eos_token_id:
break
print(" ".join(generated_text[pos:]), flush=True)
return past_key_values


@torch.no_grad()
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
    # Feed the prompts one after another while keeping a single KV cache; before
    # each prompt, evict old entries so the prompt plus max_gen_len tokens fit.
past_key_values = None
for idx, prompt in enumerate(prompts):
prompt = "USER: " + prompt + "\n\nASSISTANT: "
print("\n" + prompt, end="")
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)
seq_len = input_ids.shape[1]
if kv_cache is not None:
space_needed = seq_len + max_gen_len
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
past_key_values = greedy_generate(
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
)


def main(args):
    model, tokenizer = load(args.repo_id_or_model_path)
    # This is the GPU variant of the example: move the model to Intel GPU ('xpu').
    device = 'xpu'
    model = model.to(device)
test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
print(f"Loading data from {test_filepath} ...")
if not os.path.exists(test_filepath):
download_url(
"https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
args.data_root,
)
os.rename(os.path.join(args.data_root, "question.jsonl"), test_filepath)
list_data = load_jsonl(test_filepath)
prompts = []
for sample in list_data[1:5]:
prompts += sample["turns"]
if args.enable_streaming:
kv_cache = enable_streaming_llm(
model, start_size=args.start_size, recent_size=args.recent_size
)
else:
kv_cache = None
streaming_inference(
model,
tokenizer,
prompts,
kv_cache,
)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id-or-model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
)
parser.add_argument("--data-root", type=str, default="data/")
parser.add_argument("--enable-streaming", action="store_true")
parser.add_argument("--start-size", type=int, default=4)
parser.add_argument("--recent-size", type=int, default=2000)
args = parser.parse_args()
main(args)