Add streaming-llm using llama2 on CPU (#9265)

Enable streaming-llm to let model take infinite inputs, tested on desktop and SPR10
This commit is contained in:
Guoqiong Song 2023-10-27 01:30:39 -07:00 committed by GitHub
parent 21631209a9
commit aa319de5e8
9 changed files with 1144 additions and 0 deletions

View file

@ -0,0 +1,40 @@
# Low-Bit Streaming LLM using BigDL-LLM
In this example, we apply low-bit optimizations to [Streaming-LLM](https://github.com/mit-han-lab/streaming-llm/tree/main#efficient-streaming-language-models-with-attention-sinks) using BigDL-LLM, which can deploy low-bit (including FP4/INT4/FP8/INT8) LLMs for infinite-length inputs.
Only one code change is needed to load the model using bigdl-llm as follows:
```python
from bigdl.llm.transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_4bit=True, trust_remote_code=True, optimize_model=False)
```
## Prepare Environment
We suggest using conda to manage environment:
```bash
conda create -n llm python=3.9
conda activate llm
pip install --pre --upgrade bigdl-llm[all]
```
## Run Example
```bash
python ./run_streaming_llama.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --enable-streaming
```
arguments info:
- `--repo-id-or-model-path`: str value, argument defining the huggingface repo id for the large language model to be downloaded, or the path to the huggingface checkpoint folder, the value is 'meta-llama/Llama-2-7b-chat-hf' by default.
- `--data-root`: str value, the directory to save downloaded questions data.
- `--enable-streaming`: to enable efficient streaming while computing.
- `--start-size`: int value, the number of initial tokens always kept at the start of the KV cache, the value is 4 by default.
- `--recent-size`: int value, the number of most recent tokens kept in the KV cache, the value is 2000 by default.
## Sample Output for Inference
### 'decapoda-research/llama-7b-hf' Model
```log
USER: Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.
ASSISTANT: Dear Mr. Smith,
I am writing to seek your feedback on the 'Quarterly Financial Report' I prepared for the company. I have attached the report for your reference.
The report contains data analysis of the company's performance during the quarter ending 31st March 2019...
```

View file

@ -0,0 +1,156 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/mit-han-lab/streaming-llm/blob/main/examples/run_streaming_llama.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import warnings
import torch
import argparse
import os
from streaming_llm.utils import load, download_url, load_jsonl
from streaming_llm.enable_streaming_llm import enable_streaming_llm
warnings.filterwarnings("ignore")
@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    """Greedily decode up to ``max_gen_len`` tokens, printing words as they complete.

    Args:
        model: callable accepting ``input_ids``, ``past_key_values`` and
            ``use_cache`` keyword arguments and returning an object with
            ``logits`` and ``past_key_values`` attributes.
        tokenizer: object providing ``decode(...)`` and ``eos_token_id``.
        input_ids: prompt token ids fed to the first forward pass.
        past_key_values: KV cache carried over from earlier calls, or ``None``.
        max_gen_len: maximum number of tokens to generate.

    Returns:
        The ``past_key_values`` returned by the last forward pass.

    Note:
        ``.item()`` is called on each predicted token, so exactly one token
        must be predicted per step (i.e. a single sequence).
    """

    def _decode_words(token_ids):
        # Decode everything generated so far and split it into words.
        text = tokenizer.decode(
            token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            spaces_between_special_tokens=False,
        )
        return text.strip().split(" ")

    # First pass consumes the whole prompt and yields the first new token.
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]

    # Decode up front so the final print below still has text to emit when
    # the loop body never runs (max_gen_len <= 1). Previously this case hit
    # an unbound local variable.
    generated_text = _decode_words(generated_ids)
    pos = 0
    for _ in range(max_gen_len - 1):
        # Subsequent passes feed only the previously predicted token.
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = _decode_words(generated_ids)

        # Hold back the last word: later tokens may still extend it.
        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now

        if pred_token_idx == tokenizer.eos_token_id:
            break
    # Flush whatever has not been printed yet, including the held-back word.
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values
@torch.no_grad()
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
    """Answer each prompt in turn while carrying one KV cache across prompts.

    Args:
        model: model passed through to ``greedy_generate``; its ``device``
            attribute is used to place the input ids.
        tokenizer: callable tokenizer, invoked with ``return_tensors="pt"``.
        prompts: iterable of user prompt strings.
        kv_cache: optional object exposing ``evict_for_space``; when given,
            the cache is trimmed before every prompt.
        max_gen_len: maximum number of tokens generated per prompt.
    """
    cache = None
    for question in prompts:
        text = "USER: " + question + "\n\nASSISTANT: "
        print("\n" + text, end="")

        encoded = tokenizer(text, return_tensors="pt")
        token_ids = encoded.input_ids.to(model.device)

        if kv_cache is not None:
            # Reserve room for the prompt plus the longest possible answer.
            incoming = token_ids.shape[1] + max_gen_len
            cache = kv_cache.evict_for_space(cache, incoming)

        cache = greedy_generate(
            model, tokenizer, token_ids, cache, max_gen_len=max_gen_len
        )
def main(args):
    """Load the model, fetch the MT-Bench questions and run streaming inference.

    Args:
        args: parsed command-line namespace providing ``repo_id_or_model_path``,
            ``data_root``, ``enable_streaming``, ``start_size`` and
            ``recent_size``.
    """
    model, tokenizer = load(args.repo_id_or_model_path)

    test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
    print(f"Loading data from {test_filepath} ...")

    have_questions = os.path.exists(test_filepath)
    if not have_questions:
        download_url(
            "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
            args.data_root,
        )
        # The download lands as question.jsonl; move it to the expected name.
        downloaded = os.path.join(args.data_root, "question.jsonl")
        os.rename(downloaded, test_filepath)

    # Only samples 1..4 are used; every turn of a sample becomes one prompt.
    samples = load_jsonl(test_filepath)[1:5]
    prompts = [turn for sample in samples for turn in sample["turns"]]

    kv_cache = None
    if args.enable_streaming:
        kv_cache = enable_streaming_llm(
            model, start_size=args.start_size, recent_size=args.recent_size
        )

    streaming_inference(model, tokenizer, prompts, kv_cache)
if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--repo-id-or-model-path",
        default="meta-llama/Llama-2-7b-chat-hf",
        type=str,
    )
    cli.add_argument("--data-root", default="data/", type=str)
    cli.add_argument("--enable-streaming", action="store_true")
    # Cache budget: tokens kept at the start and most recent tokens kept.
    cli.add_argument("--start-size", default=4, type=int)
    cli.add_argument("--recent-size", default=2000, type=int)
    main(cli.parse_args())

View file

@ -0,0 +1,22 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be module not found error in non-pip's setting as Python would
# only search the first bigdl package and end up finding only one sub-package.

View file

@ -0,0 +1,81 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is copied from
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/enable_streaming_llm.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from .kv_cache import StartRecentKVCache
def enable_streaming_llm(model, start_size, recent_size):
    """Patch ``model`` for position-shifted attention and build its KV cache.

    The model family is picked by substring match on
    ``model.config.model_type``. For llama, gpt_neox and falcon the matching
    ``enable_*_pos_shift_attention`` patch is applied; mpt only selects the
    sequence dimensions.

    Args:
        model: model exposing ``config.model_type``.
        start_size: forwarded to ``StartRecentKVCache``.
        recent_size: forwarded to ``StartRecentKVCache``.

    Returns:
        A ``StartRecentKVCache`` configured with the key/value sequence
        dimensions of the detected model family.

    Raises:
        ValueError: if the model type matches none of the known families.
    """
    model_type = model.config.model_type

    if "llama" in model_type:
        from .modify_llama import enable_llama_pos_shift_attention

        enable_llama_pos_shift_attention(model)
        k_seq_dim, v_seq_dim = 2, 2
    elif "mpt" in model_type:
        # mpt uses different sequence dimensions for keys and values.
        k_seq_dim, v_seq_dim = 3, 2
    elif "gpt_neox" in model_type:
        from .modify_gpt_neox import enable_gpt_neox_pos_shift_attention

        enable_gpt_neox_pos_shift_attention(model)
        k_seq_dim, v_seq_dim = 2, 2
    elif "falcon" in model_type:
        from .modify_falcon import enable_falcon_pos_shift_attention

        enable_falcon_pos_shift_attention(model)
        k_seq_dim, v_seq_dim = 1, 1
    else:
        raise ValueError(f"got {model.config.model_type}")

    return StartRecentKVCache(
        start_size=start_size,
        recent_size=recent_size,
        k_seq_dim=k_seq_dim,
        v_seq_dim=v_seq_dim,
    )

View file

@ -0,0 +1,162 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is copied from
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/kv_cache.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
def slice2d(x, start, end):
    """Return ``x`` restricted to ``start:end`` along dimension 2."""
    return x[(slice(None), slice(None), slice(start, end), Ellipsis)]


def slice3d(x, start, end):
    """Return ``x`` restricted to ``start:end`` along dimension 3."""
    return x[(slice(None), slice(None), slice(None), slice(start, end), Ellipsis)]


def slice1d(x, start, end):
    """Return ``x`` restricted to ``start:end`` along dimension 1."""
    return x[(slice(None), slice(start, end), Ellipsis)]


# Maps a sequence-dimension index to the helper that slices along it.
DIM_TO_SLICE = {
    1: slice1d,
    2: slice2d,
    3: slice3d,
}
class StartRecentKVCache:
    """KV-cache trimming policy that keeps a head and a tail of the sequence.

    Every method takes ``past_key_values`` as an iterable of ``(key, value)``
    pairs and returns a new list of ``[key, value]`` lists. ``None`` is passed
    through unchanged.
    """

    def __init__(
        self,
        start_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        """Store the sizes and pick the slicing helper for each sequence dim.

        Args:
            start_size: number of leading positions always kept.
            recent_size: number of trailing positions kept.
            k_seq_dim: sequence dimension of key tensors (a ``DIM_TO_SLICE`` key).
            v_seq_dim: sequence dimension of value tensors (a ``DIM_TO_SLICE`` key).
        """
        print(f"StartRecentKVCache: {start_size}, {recent_size}")
        self.start_size = start_size
        self.recent_size = recent_size
        self.cache_size = start_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]

    def _keep_head_and_tail(self, past_key_values, head_end, tail_start, seq_len):
        """Concatenate ``[0, head_end)`` and ``[tail_start, seq_len)`` per layer."""
        trimmed = []
        for k, v in past_key_values:
            kept_k = torch.cat(
                [self.k_slice(k, 0, head_end), self.k_slice(k, tail_start, seq_len)],
                dim=self.k_seq_dim,
            )
            kept_v = torch.cat(
                [self.v_slice(v, 0, head_end), self.v_slice(v, tail_start, seq_len)],
                dim=self.v_seq_dim,
            )
            trimmed.append([kept_k, kept_v])
        return trimmed

    def __call__(self, past_key_values):
        """Trim the cache to ``start_size`` head plus ``recent_size`` tail entries.

        The input is returned as-is when it already fits in ``cache_size``.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        tail_start = seq_len - self.recent_size
        return self._keep_head_and_tail(
            past_key_values, self.start_size, tail_start, seq_len
        )

    def evict_for_space(self, past_key_values, num_coming):
        """Trim the cache so that ``num_coming`` more entries fit in ``cache_size``.

        The input is returned as-is when the current length plus
        ``num_coming`` does not exceed ``cache_size``.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values
        # The tail shrinks by num_coming to leave room for the new entries.
        tail_start = seq_len - self.recent_size + num_coming
        return self._keep_head_and_tail(
            past_key_values, self.start_size, tail_start, seq_len
        )

    def evict_range(self, past_key_values, start, end):
        """Remove positions ``[start, end)`` from every key and value tensor."""
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        assert start <= end and end <= seq_len
        return self._keep_head_and_tail(past_key_values, start, end, seq_len)

View file

@ -0,0 +1,200 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is copied from
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/pos_shift/modify_falcon.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
import types
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
rotate_half,
)
__all__ = ["enable_falcon_pos_shift_attention"]
def falcon_pos_shift_attention_forward(
    self,
    hidden_states: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
):
    """Falcon attention forward with position shifting for a trimmed KV cache.

    Queries are rotated with an offset equal to the cached length, while keys
    are stored in the cache *before* rotation and re-rotated from position 0
    on every call, so positions follow the cache layout.

    Args:
        hidden_states: input to the fused QKV projection.
        alibi: ALiBi bias, or ``None`` to take the scaled-dot-product path.
        attention_mask: boolean mask, only used on the ALiBi path.
        layer_past: optional cached ``(key, value)`` pair from earlier calls.
        head_mask: optional multiplier applied to the attention probabilities
            (ALiBi path only).
        use_cache: when ``True`` the updated ``(key, value)`` pair is returned.
        output_attentions: only supported on the ALiBi path; asserted to be
            falsy otherwise.

    Returns:
        ``(output_tensor, present)``, plus ``attention_probs`` on the ALiBi
        path when ``output_attentions`` is set. ``present`` is ``None`` unless
        ``use_cache`` is ``True``.

    NOTE(review): ``self.maybe_rotary`` is assumed to take
    ``(query, key, position_offset)`` and return the rotated pair; confirm
    against the Falcon implementation in use.
    """
    fused_qkv = self.query_key_value(
        hidden_states
    )  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    # Fold the head dimension into the batch dimension.
    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads, q_length, self.head_dim
    )

    # dirty hack to fix the inconsistency between falcon-40b and falcon-7b
    num_kv = self.num_heads if self.num_heads == 128 else self.num_kv
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * num_kv,
        q_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * num_kv, q_length, self.head_dim
    )

    # Queries are rotated at an offset equal to the cached length; the second
    # argument is a throwaway copy whose rotated result is discarded.
    past_len = 0
    if layer_past is not None:
        past_len = layer_past[0].shape[1]

    query_layer_copy = query_layer.clone()
    query_layer, _ = self.maybe_rotary(query_layer, query_layer_copy, past_len)
    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension (dim 1):
        #  - key: [batch_size * num_kv, kv_length, head_dim]
        #  - value: [batch_size * num_kv, kv_length, head_dim]
        key_layer = torch.cat((past_key, key_layer), dim=1)
        value_layer = torch.cat((past_value, value_layer), dim=1)

    # The cache is captured here, before keys are rotated below.
    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    # Rotate the full key sequence starting from position 0; the first
    # argument is a throwaway copy whose rotated result is discarded.
    key_layer_copy = key_layer.clone()
    _, key_layer = self.maybe_rotary(key_layer_copy, key_layer, 0)

    _, kv_length, _ = key_layer.shape

    if alibi is None:
        # Restore separate batch and head dimensions for fused attention.
        query_layer_ = query_layer.reshape(
            batch_size, self.num_heads, -1, self.head_dim
        )
        key_layer_ = key_layer.reshape(batch_size, num_kv, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, num_kv, -1, self.head_dim)

        # With a cache present no causal mask is applied; without one the
        # causal mask is enabled.
        if layer_past is not None:
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False
            )
        else:
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
            )

        # Merge heads back: [batch_size, q_length, num_heads * head_dim]
        x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

        output_tensor = self.dense(attn_output)

        outputs = (output_tensor, present)
        assert not output_attentions  # not supported.
        return outputs
    else:
        # Additive mask: -1e9 where attention_mask is set, 0 elsewhere.
        attention_mask_float = (
            (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
        )
        matmul_result = query_layer @ key_layer.transpose(-1, -2)

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(
            batch_size, self.num_heads, q_length, kv_length
        )

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
            attention_scores = attention_scores.to(torch.float32)
        # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
        attention_probs = F.softmax(
            (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
            * self.inv_norm_factor
            + attention_mask_float,
            dim=-1,
            dtype=hidden_states.dtype,
        )
        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(
            batch_size * self.num_heads, q_length, kv_length
        )

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = attention_probs_reshaped @ value_layer

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)

        output_tensor = self.dense(context_layer)

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs
def enable_falcon_pos_shift_attention(model):
    """Recursively rebind ``forward`` on every submodule named ``*self_attention``.

    Direct submodules are visited in reverse registration order. Any
    submodule that has children is descended into first, then the submodule
    itself is patched when its name ends with ``self_attention``.
    """
    for child_name, child in reversed(model._modules.items()):
        if next(child.children(), None) is not None:
            enable_falcon_pos_shift_attention(child)

        if child_name.endswith("self_attention"):
            # Bind the replacement as an instance method of this submodule.
            target = model._modules[child_name]
            target.forward = types.MethodType(
                falcon_pos_shift_attention_forward, target
            )

View file

@ -0,0 +1,147 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is copied from
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/pos_shift/modify_gpt_neox.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
from transformers.models.gpt_neox.modeling_gpt_neox import (
apply_rotary_pos_emb,
rotate_half,
GPTNeoXAttention,
)
import types
__all__ = ["enable_gpt_neox_pos_shift_attention"]
def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """Apply rotary position embedding to a single tensor ``x``.

    ``cos`` and ``sin`` are repeated along dimension 0 to the batch size of
    ``position_ids`` and gathered along dimension 2 at the given positions.
    """
    # position_ids[:, None, :, None] adds two singleton axes, which are then
    # expanded to cos.shape[1] and cos.shape[3].
    index = position_ids[:, None, :, None]
    index = index.repeat(1, cos.shape[1], 1, cos.shape[3])

    batch = index.shape[0]
    cos_at_pos = torch.gather(cos.repeat(batch, 1, 1, 1), 2, index)
    sin_at_pos = torch.gather(sin.repeat(batch, 1, 1, 1), 2, index)

    return (x * cos_at_pos) + (rotate_half(x) * sin_at_pos)
def gpt_neox_pos_shift_attention_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: torch.FloatTensor,
    position_ids: torch.LongTensor,
    head_mask: Optional[torch.FloatTensor] = None,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
):
    """GPT-NeoX attention forward with position shifting for a trimmed KV cache.

    Queries are rotated at ``position_ids``. Keys are stored in the cache
    before rotation and, on every call, the full key sequence is rotated at
    positions ``0 .. seq_len - 1``.

    Args:
        hidden_states: input to the fused QKV projection.
        attention_mask: forwarded to ``self._attn``.
        position_ids: positions used to rotate the queries.
        head_mask: forwarded to ``self._attn``.
        layer_past: optional cached ``(key, value)`` pair from earlier calls.
        use_cache: when truthy the updated ``(key, value)`` pair is returned.
        output_attentions: when truthy the attention weights are appended.

    Returns:
        ``(attn_output, present)``, plus ``attn_weights`` when
        ``output_attentions`` is set. ``present`` is ``None`` unless
        ``use_cache`` is truthy.
    """
    has_layer_past = layer_past is not None

    # Compute QKV
    # Attention heads [batch, seq_len, hidden_size]
    #   --> [batch, seq_len, (np * 3 * head_size)]
    qkv = self.query_key_value(hidden_states)

    # [batch, seq_len, (num_heads * 3 * head_size)]
    #   --> [batch, seq_len, num_heads, 3 * head_size]
    new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
    qkv = qkv.view(*new_qkv_shape)

    # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
    query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
    key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
    value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

    # Compute rotary embeddings on rotary_ndims
    # (only the first rotary_ndims features are rotated; the rest pass through)
    query_rot = query[..., : self.rotary_ndims]
    query_pass = query[..., self.rotary_ndims :]

    # Compute token offset for rotary embeddings (when decoding)
    seq_len = key.shape[-2]
    if has_layer_past:
        seq_len += layer_past[0].shape[-2]
    cos, sin = self.rotary_emb(value, seq_len=seq_len)
    query = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids)
    query = torch.cat((query, query_pass), dim=-1)

    # Cache QKV values
    if has_layer_past:
        past_key = layer_past[0]
        past_value = layer_past[1]
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)

    # The cache is captured here, before keys are rotated below.
    present = (key, value) if use_cache else None

    # Rotate the whole key sequence at positions 0 .. seq_len - 1.
    key_rot = key[..., : self.rotary_ndims]
    key_pass = key[..., self.rotary_ndims :]
    key_position_ids = torch.arange(seq_len, device=position_ids.device).unsqueeze(0)
    key = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids)
    key = torch.cat((key, key_pass), dim=-1)

    # Compute attention
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

    # Reshape outputs
    attn_output = self._merge_heads(
        attn_output, self.num_attention_heads, self.head_size
    )
    attn_output = self.dense(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        outputs += (attn_weights,)

    return outputs
def enable_gpt_neox_pos_shift_attention(model):
    """Recursively rebind ``forward`` on every ``GPTNeoXAttention`` submodule.

    Direct submodules are visited in reverse registration order. Any
    submodule that has children is descended into first, then the submodule
    itself is patched when it is a ``GPTNeoXAttention`` instance.
    """
    for _, child in reversed(model._modules.items()):
        if next(child.children(), None) is not None:
            enable_gpt_neox_pos_shift_attention(child)

        if isinstance(child, GPTNeoXAttention):
            # Bind the replacement as an instance method of this submodule.
            child.forward = types.MethodType(
                gpt_neox_pos_shift_attention_forward, child
            )

View file

@ -0,0 +1,214 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is copied from
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/pos_shift/modify_llama.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import Optional, Tuple
import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import (
LlamaAttention,
rotate_half,
apply_rotary_pos_emb,
repeat_kv,
)
import types
__all__ = ["enable_llama_pos_shift_attention"]
def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """Apply rotary position embedding to a single tensor ``x``.

    ``cos`` and ``sin`` have their two leading dimensions squeezed away, are
    indexed by ``position_ids`` and get a new axis inserted at dimension 1
    before being combined with ``x``.
    """
    # Squeezing dims 1 and 0 only removes them if they are size 1.
    # NOTE(review): assumes cos/sin arrive as [1, 1, seq_len, dim]; confirm
    # against the rotary embedding module in use.
    cos_table = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin_table = sin.squeeze(1).squeeze(0)  # [seq_len, dim]

    cos_at_pos = cos_table[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin_at_pos = sin_table[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

    return (x * cos_at_pos) + (rotate_half(x) * sin_at_pos)
def llama_pos_shift_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward pass with positions assigned relative to the KV cache.

    Installed as the ``forward`` of ``LlamaAttention`` modules by
    ``enable_llama_pos_shift_attention``. The difference from a plain
    rotary-attention forward is where the rotary embedding is applied:

    * queries are rotated with the caller-supplied ``position_ids``;
    * keys are stored in the cache *before* rotation and are rotated on every
      call with positions ``0 .. kv_seq_len - 1``, i.e. their index inside the
      (possibly evicted) cache.

    Args:
        hidden_states: Input of size ``(bsz, q_len, hidden)``; the last
            dimension is consumed by the q/k/v projections.
        attention_mask: Optional additive mask of size
            ``(bsz, 1, q_len, kv_seq_len)``.
        position_ids: Positions used for the query rotation. Although the
            default is ``None``, the value is dereferenced unconditionally
            (``position_ids.device``), so callers must pass a tensor.
        past_key_value: Optional ``(keys, values)`` pair from earlier calls;
            concatenated with the new keys/values along dimension 2.
        output_attentions: When false, the returned attention weights are
            ``None``.
        use_cache: When true, the concatenated (un-rotated) keys and values
            are returned as the new cache entry.

    Returns:
        ``(attn_output, attn_weights_or_None, past_key_value_or_None)``.

    Raises:
        ValueError: If the attention weights, the attention mask or the
            attention output do not have the expected size.
    """
    bsz, q_len, _ = hidden_states.size()
    if self.config.pretraining_tp > 1:
        # Tensor-parallel pretraining layout: split each projection weight into
        # `pretraining_tp` row slices, project slice by slice, then concatenate.
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
        query_states = [
            F.linear(hidden_states, query_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)
        key_states = [
            F.linear(hidden_states, key_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)
        value_states = [
            F.linear(hidden_states, value_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
    # Split the projected features into heads: (bsz, heads, q_len, head_dim).
    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    # Total key/value length = new tokens + tokens already held in the cache.
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    ### Shift Pos: query pos is min(cache_size, idx)
    # Replaces the upstream joint call
    # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # so that only the query is rotated with the caller's position_ids here.
    query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
    ###
    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    # The cache is captured BEFORE the keys are rotated below, so cached keys
    # stay un-rotated and can be re-positioned on the next call.
    past_key_value = (key_states, value_states) if use_cache else None
    ### Shift Pos: key pos is the pos in cache
    # Every key (cached and new) gets the position of its slot: 0 .. kv_seq_len - 1.
    key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
    key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
    ###
    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    # Scaled dot-product scores: (bsz, heads, q_len, kv_seq_len).
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        self.head_dim
    )
    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )
    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        # The mask is additive: it is summed onto the raw scores before softmax.
        attn_weights = attn_weights + attention_mask
    # upcast attention to fp32
    # (softmax is computed in float32, then cast back to the query dtype)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )
    attn_output = torch.matmul(attn_weights, value_states)
    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )
    # Merge the heads back: (bsz, q_len, hidden_size).
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    if self.config.pretraining_tp > 1:
        # Tensor-parallel layout for the output projection: split the activations
        # and the weight columns into matching slices and sum the partial products.
        attn_output = attn_output.split(
            self.hidden_size // self.config.pretraining_tp, dim=2
        )
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.config.pretraining_tp, dim=1
        )
        attn_output = sum(
            [
                F.linear(attn_output[i], o_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
        )
    else:
        attn_output = self.o_proj(attn_output)
    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value
def enable_llama_pos_shift_attention(model):
    """Patch every ``LlamaAttention`` under ``model`` to use position shifting.

    Walks the direct sub-modules of ``model`` in reverse registration order,
    recurses into any sub-module that has children of its own, and rebinds the
    ``forward`` of each ``LlamaAttention`` instance to
    ``llama_pos_shift_attention_forward``. The model is modified in place and
    nothing is returned.

    Args:
        model: Module whose ``_modules`` mapping is traversed.
    """
    for child_name, child in reversed(model._modules.items()):
        # Descend first so nested attention layers are patched as well.
        has_children = any(True for _ in child.children())
        if has_children:
            enable_llama_pos_shift_attention(child)
        if isinstance(child, LlamaAttention):
            # Bind the replacement as an instance method so `self` is the layer.
            bound_forward = types.MethodType(
                llama_pos_shift_attention_forward, model._modules[child_name]
            )
            model._modules[child_name].forward = bound_forward

View file

@ -0,0 +1,122 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/utils.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
import argparse
import os.path as osp
import ssl
import urllib.request
import os
import json
# code change to import from bigdl-llm API instead of using transformers API
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer
import intel_extension_for_pytorch as ipex
def load(model_name_or_path):
    """Load a 4-bit BigDL-LLM causal language model and its Llama tokenizer.

    Args:
        model_name_or_path: Value forwarded to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. ``model.eval()`` has been called on the
        model, and the tokenizer is guaranteed to have a ``pad_token_id``.
    """
    print(f"Loading model from {model_name_or_path} ...")
    # NOTE (kept from upstream): tensor parallelism was reported to be buggy
    # when running falcon; no device mapping is requested here.
    tokenizer = LlamaTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=True
    )

    # load_in_4bit=True gives the low-bit performance boost; optimize_model is
    # switched off for now.
    # TODO align logics of optimize_model and streaming
    model_options = dict(
        load_in_4bit=True,
        optimize_model=False,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_options)

    # Fall back to the EOS id, or to 0 when no EOS id is defined, if the
    # tokenizer has no padding token.
    if tokenizer.pad_token_id is None:
        eos_id = tokenizer.eos_token_id
        tokenizer.pad_token_id = 0 if eos_id is None else eos_id

    model.eval()
    return model, tokenizer
def download_url(url: str, folder="folder"):
    """
    Downloads the content of an url to a folder. Modified from \
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    The file name is the last path segment of ``url`` with any query string
    removed (unless the segment itself starts with ``?``). If that file is
    already present in ``folder`` it is reused and nothing is downloaded.

    The payload is first written to ``<path>.part`` and only renamed to its
    final name once it has been received completely, so an interrupted
    download can never be mistaken for a valid cached file on a later call.

    Args:
        url (string): The url of target file.
        folder (string): The target folder.

    Returns:
        string: File path of downloaded files.

    Raises:
        ValueError: If no file name can be derived from ``url`` (for example
            when it ends with ``/``).
    """
    file = url.rpartition("/")[2]
    # startswith() is safe on an empty segment, unlike indexing file[0].
    if not file.startswith("?"):
        file = file.split("?")[0]
    if not file:
        raise ValueError(f"Cannot derive a file name from url: {url!r}")

    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, use existing file.")
        return path

    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # NOTE(security): certificate verification is disabled, as in the original
    # implementation, so the payload is not authenticated. Only use this with
    # trusted sources/networks, or switch to ssl.create_default_context().
    ctx = ssl._create_unverified_context()
    partial_path = path + ".part"
    try:
        # Open the connection first so a failed request creates no file at all;
        # both the response and the file are closed by the context managers.
        with urllib.request.urlopen(url, context=ctx) as response, open(
            partial_path, "wb"
        ) as f:
            f.write(response.read())
        os.replace(partial_path, path)
    finally:
        # After a successful rename this is a no-op; after a failure it removes
        # the incomplete download.
        if osp.exists(partial_path):
            os.remove(partial_path)
    return path
def load_jsonl(
    file_path,
):
    """Read a JSON Lines file into a list of decoded objects.

    The file is always decoded as UTF-8 (instead of the platform's locale
    encoding), and lines that contain only whitespace are skipped so that a
    trailing newline or blank separator line does not raise a decode error.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        list: One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]