Stream llm example for both GPU and CPU (#9390)
This commit is contained in:
parent
c581c6db30
commit
f4a2e32106
5 changed files with 221 additions and 8 deletions
|
|
@ -44,9 +44,12 @@
|
|||
import warnings
|
||||
import torch
|
||||
import argparse
|
||||
import os
|
||||
from streaming_llm.utils import load, download_url, load_jsonl
|
||||
from streaming_llm.enable_streaming_llm import enable_streaming_llm
|
||||
import os, sys
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
stream_llm_src = CURRENT_DIR + "/streaming_llm/"
|
||||
sys.path.append(stream_llm_src)
|
||||
from utils import load, download_url, load_jsonl
|
||||
from enable_streaming_llm import enable_streaming_llm
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
|
|
@ -99,7 +102,6 @@ def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=10
|
|||
prompt = "USER: " + prompt + "\n\nASSISTANT: "
|
||||
print("\n" + prompt, end="")
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
seq_len = input_ids.shape[1]
|
||||
if kv_cache is not None:
|
||||
space_needed = seq_len + max_gen_len
|
||||
|
|
@ -107,11 +109,13 @@ def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=10
|
|||
|
||||
past_key_values = greedy_generate(
|
||||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
|
||||
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
model, tokenizer = load(args.repo_id_or_model_path)
|
||||
|
||||
test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
|
||||
print(f"Loading data from {test_filepath} ...")
|
||||
|
||||
|
|
|
|||
|
|
@ -41,13 +41,13 @@
|
|||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from .kv_cache import StartRecentKVCache
|
||||
from kv_cache import StartRecentKVCache
|
||||
|
||||
|
||||
def enable_streaming_llm(model, start_size, recent_size):
|
||||
if "llama" in model.config.model_type:
|
||||
k_seq_dim = v_seq_dim = 2
|
||||
from .modify_llama import (
|
||||
from modify_llama import (
|
||||
enable_llama_pos_shift_attention,
|
||||
)
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ def enable_streaming_llm(model, start_size, recent_size):
|
|||
k_seq_dim = 3
|
||||
elif "gpt_neox" in model.config.model_type:
|
||||
k_seq_dim = v_seq_dim = 2
|
||||
from .modify_gpt_neox import (
|
||||
from modify_gpt_neox import (
|
||||
enable_gpt_neox_pos_shift_attention,
|
||||
)
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ def enable_streaming_llm(model, start_size, recent_size):
|
|||
elif "falcon" in model.config.model_type:
|
||||
v_seq_dim = 1
|
||||
k_seq_dim = 1
|
||||
from .modify_falcon import (
|
||||
from modify_falcon import (
|
||||
enable_falcon_pos_shift_attention,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ def llama_pos_shift_attention_forward(
|
|||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
|
|
|
|||
45
python/llm/example/GPU/Applications/streaming-llm/README.md
Normal file
45
python/llm/example/GPU/Applications/streaming-llm/README.md
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
# Low-Bit Streaming LLM using BigDL-LLM
|
||||
|
||||
In this example, we apply low-bit optimizations to [Streaming-LLM](https://github.com/mit-han-lab/streaming-llm/tree/main#efficient-streaming-language-models-with-attention-sinks) using BigDL-LLM, which can deploy low-bit(including FP4/INT4/FP8/INT8) LLMs for infinite-length inputs.
|
||||
Only one code change is needed to load the model using bigdl-llm as follows:
|
||||
```python
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_4bit=True, trust_remote_code=True, optimize_model=False)
|
||||
```
|
||||
|
||||
## Prepare Environment
|
||||
We suggest using conda to manage environment:
|
||||
```bash
|
||||
conda create -n llm python=3.9
|
||||
conda activate llm
|
||||
pip install -U transformers==4.34.0
|
||||
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
|
||||
```
|
||||
|
||||
## Configures OneAPI environment variables
|
||||
```bash
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
```
|
||||
|
||||
## Run Example
|
||||
```bash
|
||||
python ./run_streaming_llama.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --enable-streaming
|
||||
```
|
||||
arguments info:
|
||||
- `--repo-id-or-model-path`: str value, argument defining the huggingface repo id for the large language model to be downloaded, or the path to the huggingface checkpoint folder, the value is 'meta-llama/Llama-2-7b-chat-hf' by default.
|
||||
- `--data-root`: str value, the directory to save downloaded questions data.
|
||||
- `--enable-streaming`: to enable efficient streaming while computing.
|
||||
- `--start-size`: int value, the start size of recent KV cache.
|
||||
- `--recent-size`: optional str value. The path to load low-bit model.
|
||||
|
||||
|
||||
## Sample Output for Inference
|
||||
### 'decapoda-research/llama-7b-hf' Model
|
||||
```log
|
||||
USER: Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.
|
||||
|
||||
ASSISTANT: Dear Mr. Smith,
|
||||
|
||||
I am writing to seek your feedback on the 'Quarterly Financial Report' I prepared for the company. I have attached the report for your reference.
|
||||
The report contains data analysis of the company's performance during the quarter ending 31st March 2019...
|
||||
```
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/examples/run_streaming_llama.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import warnings
|
||||
import torch
|
||||
import argparse
|
||||
import os
|
||||
import os, sys
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
stream_llm_src = CURRENT_DIR.replace("GPU", "CPU") + "/streaming_llm/"
|
||||
sys.path.append(stream_llm_src)
|
||||
from utils import load, download_url, load_jsonl
|
||||
from enable_streaming_llm import enable_streaming_llm
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
|
||||
outputs = model(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
|
||||
generated_ids = [pred_token_idx.item()]
|
||||
pos = 0
|
||||
for _ in range(max_gen_len - 1):
|
||||
outputs = model(
|
||||
input_ids=pred_token_idx,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
|
||||
generated_ids.append(pred_token_idx.item())
|
||||
generated_text = (
|
||||
tokenizer.decode(
|
||||
generated_ids,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
spaces_between_special_tokens=False,
|
||||
)
|
||||
.strip()
|
||||
.split(" ")
|
||||
)
|
||||
|
||||
now = len(generated_text) - 1
|
||||
if now > pos:
|
||||
print(" ".join(generated_text[pos:now]), end=" ", flush=True)
|
||||
pos = now
|
||||
|
||||
if pred_token_idx == tokenizer.eos_token_id:
|
||||
break
|
||||
print(" ".join(generated_text[pos:]), flush=True)
|
||||
return past_key_values
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
|
||||
past_key_values = None
|
||||
for idx, prompt in enumerate(prompts):
|
||||
prompt = "USER: " + prompt + "\n\nASSISTANT: "
|
||||
print("\n" + prompt, end="")
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
seq_len = input_ids.shape[1]
|
||||
if kv_cache is not None:
|
||||
space_needed = seq_len + max_gen_len
|
||||
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
|
||||
|
||||
past_key_values = greedy_generate(
|
||||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
|
||||
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
model, tokenizer = load(args.repo_id_or_model_path)
|
||||
device = 'xpu'
|
||||
model = model.to(device)
|
||||
test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
|
||||
print(f"Loading data from {test_filepath} ...")
|
||||
|
||||
if not os.path.exists(test_filepath):
|
||||
download_url(
|
||||
"https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
|
||||
args.data_root,
|
||||
)
|
||||
os.rename(os.path.join(args.data_root, "question.jsonl"), test_filepath)
|
||||
|
||||
list_data = load_jsonl(test_filepath)
|
||||
prompts = []
|
||||
for sample in list_data[1:5]:
|
||||
prompts += sample["turns"]
|
||||
|
||||
if args.enable_streaming:
|
||||
kv_cache = enable_streaming_llm(
|
||||
model, start_size=args.start_size, recent_size=args.recent_size
|
||||
)
|
||||
else:
|
||||
kv_cache = None
|
||||
|
||||
streaming_inference(
|
||||
model,
|
||||
tokenizer,
|
||||
prompts,
|
||||
kv_cache,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id-or-model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
|
||||
)
|
||||
parser.add_argument("--data-root", type=str, default="data/")
|
||||
parser.add_argument("--enable-streaming", action="store_true")
|
||||
parser.add_argument("--start-size", type=int, default=4)
|
||||
parser.add_argument("--recent-size", type=int, default=2000)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
Loading…
Reference in a new issue