Add streaming-llm using llama2 on CPU (#9265)
Enable streaming-llm to let model take infinite inputs, tested on desktop and SPR10
This commit is contained in:
parent
21631209a9
commit
aa319de5e8
9 changed files with 1144 additions and 0 deletions
40
python/llm/example/CPU/applications/streaming-llm/README.md
Normal file
40
python/llm/example/CPU/applications/streaming-llm/README.md
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
# Low-Bit Streaming LLM using BigDL-LLM
|
||||
|
||||
In this example, we apply low-bit optimizations to [Streaming-LLM](https://github.com/mit-han-lab/streaming-llm/tree/main#efficient-streaming-language-models-with-attention-sinks) using BigDL-LLM, which can deploy low-bit(including FP4/INT4/FP8/INT8) LLMs for infinite-length inputs.
|
||||
Only one code change is needed to load the model using bigdl-llm as follows:
|
||||
```python
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, load_in_4bit=True, trust_remote_code=True, optimize_model=False)
|
||||
```
|
||||
|
||||
## Prepare Environment
|
||||
We suggest using conda to manage environment:
|
||||
```bash
|
||||
conda create -n llm python=3.9
|
||||
conda activate llm
|
||||
|
||||
pip install --pre --upgrade bigdl-llm[all]
|
||||
```
|
||||
|
||||
## Run Example
|
||||
```bash
|
||||
python ./run_streaming_llama.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --enable-streaming
|
||||
```
|
||||
arguments info:
|
||||
- `--repo-id-or-model-path`: str value, argument defining the huggingface repo id for the large language model to be downloaded, or the path to the huggingface checkpoint folder, the value is 'meta-llama/Llama-2-7b-chat-hf' by default.
|
||||
- `--data-root`: str value, the directory to save downloaded questions data.
|
||||
- `--enable-streaming`: to enable efficient streaming while computing.
|
||||
- `--start-size`: int value, the start size of recent KV cache.
|
||||
- `--recent-size`: optional str value. The path to load low-bit model.
|
||||
|
||||
|
||||
## Sample Output for Inference
|
||||
### 'decapoda-research/llama-7b-hf' Model
|
||||
```log
|
||||
USER: Draft a professional email seeking your supervisor's feedback on the 'Quarterly Financial Report' you prepared. Ask specifically about the data analysis, presentation style, and the clarity of conclusions drawn. Keep the email short and to the point.
|
||||
|
||||
ASSISTANT: Dear Mr. Smith,
|
||||
|
||||
I am writing to seek your feedback on the 'Quarterly Financial Report' I prepared for the company. I have attached the report for your reference.
|
||||
The report contains data analysis of the company's performance during the quarter ending 31st March 2019...
|
||||
```
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/examples/run_streaming_llama.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import warnings
|
||||
import torch
|
||||
import argparse
|
||||
import os
|
||||
from streaming_llm.utils import load, download_url, load_jsonl
|
||||
from streaming_llm.enable_streaming_llm import enable_streaming_llm
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
|
||||
outputs = model(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
|
||||
generated_ids = [pred_token_idx.item()]
|
||||
pos = 0
|
||||
for _ in range(max_gen_len - 1):
|
||||
outputs = model(
|
||||
input_ids=pred_token_idx,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
past_key_values = outputs.past_key_values
|
||||
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
|
||||
generated_ids.append(pred_token_idx.item())
|
||||
generated_text = (
|
||||
tokenizer.decode(
|
||||
generated_ids,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
spaces_between_special_tokens=False,
|
||||
)
|
||||
.strip()
|
||||
.split(" ")
|
||||
)
|
||||
|
||||
now = len(generated_text) - 1
|
||||
if now > pos:
|
||||
print(" ".join(generated_text[pos:now]), end=" ", flush=True)
|
||||
pos = now
|
||||
|
||||
if pred_token_idx == tokenizer.eos_token_id:
|
||||
break
|
||||
print(" ".join(generated_text[pos:]), flush=True)
|
||||
return past_key_values
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
|
||||
past_key_values = None
|
||||
for idx, prompt in enumerate(prompts):
|
||||
prompt = "USER: " + prompt + "\n\nASSISTANT: "
|
||||
print("\n" + prompt, end="")
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(model.device)
|
||||
seq_len = input_ids.shape[1]
|
||||
if kv_cache is not None:
|
||||
space_needed = seq_len + max_gen_len
|
||||
past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
|
||||
|
||||
past_key_values = greedy_generate(
|
||||
model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
model, tokenizer = load(args.repo_id_or_model_path)
|
||||
test_filepath = os.path.join(args.data_root, "mt_bench.jsonl")
|
||||
print(f"Loading data from {test_filepath} ...")
|
||||
|
||||
if not os.path.exists(test_filepath):
|
||||
download_url(
|
||||
"https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
|
||||
args.data_root,
|
||||
)
|
||||
os.rename(os.path.join(args.data_root, "question.jsonl"), test_filepath)
|
||||
|
||||
list_data = load_jsonl(test_filepath)
|
||||
prompts = []
|
||||
for sample in list_data[1:5]:
|
||||
prompts += sample["turns"]
|
||||
|
||||
if args.enable_streaming:
|
||||
kv_cache = enable_streaming_llm(
|
||||
model, start_size=args.start_size, recent_size=args.recent_size
|
||||
)
|
||||
else:
|
||||
kv_cache = None
|
||||
|
||||
streaming_inference(
|
||||
model,
|
||||
tokenizer,
|
||||
prompts,
|
||||
kv_cache,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--repo-id-or-model-path", type=str, default="meta-llama/Llama-2-7b-chat-hf"
|
||||
)
|
||||
parser.add_argument("--data-root", type=str, default="data/")
|
||||
parser.add_argument("--enable-streaming", action="store_true")
|
||||
parser.add_argument("--start-size", type=int, default=4)
|
||||
parser.add_argument("--recent-size", type=int, default=2000)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# This would makes sure Python is aware there is more than one sub-package within bigdl,
|
||||
# physically located elsewhere.
|
||||
# Otherwise there would be module not found error in non-pip's setting as Python would
|
||||
# only search the first bigdl package and end up finding only one sub-package.
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is copied from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/enable_streaming_llm.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from .kv_cache import StartRecentKVCache
|
||||
|
||||
|
||||
def enable_streaming_llm(model, start_size, recent_size):
|
||||
if "llama" in model.config.model_type:
|
||||
k_seq_dim = v_seq_dim = 2
|
||||
from .modify_llama import (
|
||||
enable_llama_pos_shift_attention,
|
||||
)
|
||||
|
||||
enable_llama_pos_shift_attention(model)
|
||||
elif "mpt" in model.config.model_type:
|
||||
v_seq_dim = 2
|
||||
k_seq_dim = 3
|
||||
elif "gpt_neox" in model.config.model_type:
|
||||
k_seq_dim = v_seq_dim = 2
|
||||
from .modify_gpt_neox import (
|
||||
enable_gpt_neox_pos_shift_attention,
|
||||
)
|
||||
|
||||
enable_gpt_neox_pos_shift_attention(model)
|
||||
elif "falcon" in model.config.model_type:
|
||||
v_seq_dim = 1
|
||||
k_seq_dim = 1
|
||||
from .modify_falcon import (
|
||||
enable_falcon_pos_shift_attention,
|
||||
)
|
||||
|
||||
enable_falcon_pos_shift_attention(model)
|
||||
else:
|
||||
raise ValueError(f"got {model.config.model_type}")
|
||||
kv_cache = StartRecentKVCache(
|
||||
start_size=start_size,
|
||||
recent_size=recent_size,
|
||||
k_seq_dim=k_seq_dim,
|
||||
v_seq_dim=v_seq_dim,
|
||||
)
|
||||
return kv_cache
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is copied from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/kv_cache.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def slice2d(x, start, end):
|
||||
return x[:, :, start:end, ...]
|
||||
|
||||
|
||||
def slice3d(x, start, end):
|
||||
return x[:, :, :, start:end, ...]
|
||||
|
||||
|
||||
def slice1d(x, start, end):
|
||||
return x[:, start:end, ...]
|
||||
|
||||
|
||||
DIM_TO_SLICE = {
|
||||
1: slice1d,
|
||||
2: slice2d,
|
||||
3: slice3d,
|
||||
}
|
||||
|
||||
|
||||
class StartRecentKVCache:
|
||||
def __init__(
|
||||
self,
|
||||
start_size=4,
|
||||
recent_size=512,
|
||||
k_seq_dim=2,
|
||||
v_seq_dim=2,
|
||||
):
|
||||
print(f"StartRecentKVCache: {start_size}, {recent_size}")
|
||||
self.start_size = start_size
|
||||
self.recent_size = recent_size
|
||||
self.cache_size = start_size + recent_size
|
||||
self.k_seq_dim = k_seq_dim
|
||||
self.v_seq_dim = v_seq_dim
|
||||
self.k_slice = DIM_TO_SLICE[k_seq_dim]
|
||||
self.v_slice = DIM_TO_SLICE[v_seq_dim]
|
||||
|
||||
def __call__(self, past_key_values):
|
||||
if past_key_values is None:
|
||||
return None
|
||||
seq_len = past_key_values[0][0].size(self.k_seq_dim)
|
||||
if seq_len <= self.cache_size:
|
||||
return past_key_values
|
||||
return [
|
||||
[
|
||||
torch.cat(
|
||||
[
|
||||
self.k_slice(k, 0, self.start_size),
|
||||
self.k_slice(k, seq_len - self.recent_size, seq_len),
|
||||
],
|
||||
dim=self.k_seq_dim,
|
||||
),
|
||||
torch.cat(
|
||||
[
|
||||
self.v_slice(v, 0, self.start_size),
|
||||
self.v_slice(v, seq_len - self.recent_size, seq_len),
|
||||
],
|
||||
dim=self.v_seq_dim,
|
||||
),
|
||||
]
|
||||
for k, v in past_key_values
|
||||
]
|
||||
|
||||
def evict_for_space(self, past_key_values, num_coming):
|
||||
if past_key_values is None:
|
||||
return None
|
||||
seq_len = past_key_values[0][0].size(self.k_seq_dim)
|
||||
if seq_len + num_coming <= self.cache_size:
|
||||
return past_key_values
|
||||
return [
|
||||
[
|
||||
torch.cat(
|
||||
[
|
||||
self.k_slice(k, 0, self.start_size),
|
||||
self.k_slice(
|
||||
k, seq_len - self.recent_size + num_coming, seq_len
|
||||
),
|
||||
],
|
||||
dim=self.k_seq_dim,
|
||||
),
|
||||
torch.cat(
|
||||
[
|
||||
self.v_slice(v, 0, self.start_size),
|
||||
self.v_slice(
|
||||
v, seq_len - self.recent_size + num_coming, seq_len
|
||||
),
|
||||
],
|
||||
dim=self.v_seq_dim,
|
||||
),
|
||||
]
|
||||
for k, v in past_key_values
|
||||
]
|
||||
|
||||
def evict_range(self, past_key_values, start, end):
|
||||
if past_key_values is None:
|
||||
return None
|
||||
seq_len = past_key_values[0][0].size(self.k_seq_dim)
|
||||
assert start <= end and end <= seq_len
|
||||
return [
|
||||
[
|
||||
torch.cat(
|
||||
[
|
||||
self.k_slice(k, 0, start),
|
||||
self.k_slice(k, end, seq_len),
|
||||
],
|
||||
dim=self.k_seq_dim,
|
||||
),
|
||||
torch.cat(
|
||||
[
|
||||
self.v_slice(v, 0, start),
|
||||
self.v_slice(v, end, seq_len),
|
||||
],
|
||||
dim=self.v_seq_dim,
|
||||
),
|
||||
]
|
||||
for k, v in past_key_values
|
||||
]
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is copied from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/pos_shift/modify_falcon.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import torch.nn.functional as F
|
||||
import types
|
||||
from transformers.models.falcon.modeling_falcon import (
|
||||
FalconAttention,
|
||||
rotate_half,
|
||||
)
|
||||
|
||||
__all__ = ["enable_falcon_pos_shift_attention"]
|
||||
|
||||
|
||||
def falcon_pos_shift_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(
|
||||
hidden_states
|
||||
) # [batch_size, seq_length, 3 x hidden_size]
|
||||
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
|
||||
batch_size, q_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(
|
||||
batch_size * self.num_heads, q_length, self.head_dim
|
||||
)
|
||||
|
||||
# dirty hack to fix the inconsistency between falcon-40b and falcon-7b
|
||||
num_kv = self.num_heads if self.num_heads == 128 else self.num_kv
|
||||
key_layer = key_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv,
|
||||
q_length,
|
||||
self.head_dim,
|
||||
)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv, q_length, self.head_dim
|
||||
)
|
||||
|
||||
past_len = 0
|
||||
if layer_past is not None:
|
||||
past_len = layer_past[0].shape[1]
|
||||
|
||||
query_layer_copy = query_layer.clone()
|
||||
query_layer, _ = self.maybe_rotary(query_layer, query_layer_copy, past_len)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
key_layer_copy = key_layer.clone()
|
||||
_, key_layer = self.maybe_rotary(key_layer_copy, key_layer, 0)
|
||||
|
||||
_, kv_length, _ = key_layer.shape
|
||||
|
||||
if alibi is None:
|
||||
query_layer_ = query_layer.reshape(
|
||||
batch_size, self.num_heads, -1, self.head_dim
|
||||
)
|
||||
key_layer_ = key_layer.reshape(batch_size, num_kv, -1, self.head_dim)
|
||||
value_layer_ = value_layer.reshape(batch_size, num_kv, -1, self.head_dim)
|
||||
|
||||
if layer_past is not None:
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False
|
||||
)
|
||||
else:
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
|
||||
)
|
||||
|
||||
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
|
||||
|
||||
output_tensor = self.dense(attn_output)
|
||||
|
||||
outputs = (output_tensor, present)
|
||||
assert not output_attentions # not supported.
|
||||
return outputs
|
||||
else:
|
||||
attention_mask_float = (
|
||||
(attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
|
||||
)
|
||||
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attention_scores = matmul_result.view(
|
||||
batch_size, self.num_heads, q_length, kv_length
|
||||
)
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attention_scores.dtype
|
||||
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||
if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
|
||||
attention_scores = attention_scores.to(torch.float32)
|
||||
# attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
||||
attention_probs = F.softmax(
|
||||
(attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
|
||||
* self.inv_norm_factor
|
||||
+ attention_mask_float,
|
||||
dim=-1,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
# [batch_size, num_heads, q_length, kv_length]
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
# change view [batch_size x num_heads, q_length, kv_length]
|
||||
attention_probs_reshaped = attention_probs.view(
|
||||
batch_size * self.num_heads, q_length, kv_length
|
||||
)
|
||||
|
||||
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||
context_layer = attention_probs_reshaped @ value_layer
|
||||
|
||||
# change view [batch_size, num_heads, q_length, head_dim]
|
||||
context_layer = self._merge_heads(context_layer)
|
||||
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
outputs = (output_tensor, present)
|
||||
if output_attentions:
|
||||
outputs += (attention_probs,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def enable_falcon_pos_shift_attention(model):
|
||||
for name, module in reversed(model._modules.items()):
|
||||
if len(list(module.children())) > 0:
|
||||
enable_falcon_pos_shift_attention(
|
||||
module,
|
||||
)
|
||||
|
||||
if "self_attention" == name[-14:]:
|
||||
model._modules[name].forward = types.MethodType(
|
||||
falcon_pos_shift_attention_forward, model._modules[name]
|
||||
)
|
||||
|
|
@ -0,0 +1,147 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is copied from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/pos_shift/modify_gpt_neox.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from transformers.models.gpt_neox.modeling_gpt_neox import (
|
||||
apply_rotary_pos_emb,
|
||||
rotate_half,
|
||||
GPTNeoXAttention,
|
||||
)
|
||||
import types
|
||||
|
||||
__all__ = ["enable_gpt_neox_pos_shift_attention"]
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
|
||||
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
|
||||
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
|
||||
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed
|
||||
|
||||
|
||||
def gpt_neox_pos_shift_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
):
|
||||
has_layer_past = layer_past is not None
|
||||
|
||||
# Compute QKV
|
||||
# Attention heads [batch, seq_len, hidden_size]
|
||||
# --> [batch, seq_len, (np * 3 * head_size)]
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
|
||||
# [batch, seq_len, (num_heads * 3 * head_size)]
|
||||
# --> [batch, seq_len, num_heads, 3 * head_size]
|
||||
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
|
||||
qkv = qkv.view(*new_qkv_shape)
|
||||
|
||||
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
|
||||
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
|
||||
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
|
||||
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
|
||||
|
||||
# Compute rotary embeddings on rotary_ndims
|
||||
query_rot = query[..., : self.rotary_ndims]
|
||||
query_pass = query[..., self.rotary_ndims :]
|
||||
|
||||
# Compute token offset for rotary embeddings (when decoding)
|
||||
seq_len = key.shape[-2]
|
||||
if has_layer_past:
|
||||
seq_len += layer_past[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value, seq_len=seq_len)
|
||||
query = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids)
|
||||
query = torch.cat((query, query_pass), dim=-1)
|
||||
|
||||
# Cache QKV values
|
||||
if has_layer_past:
|
||||
past_key = layer_past[0]
|
||||
past_value = layer_past[1]
|
||||
key = torch.cat((past_key, key), dim=-2)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
|
||||
present = (key, value) if use_cache else None
|
||||
|
||||
key_rot = key[..., : self.rotary_ndims]
|
||||
key_pass = key[..., self.rotary_ndims :]
|
||||
key_position_ids = torch.arange(seq_len, device=position_ids.device).unsqueeze(0)
|
||||
key = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids)
|
||||
key = torch.cat((key, key_pass), dim=-1)
|
||||
|
||||
# Compute attention
|
||||
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
|
||||
# Reshape outputs
|
||||
attn_output = self._merge_heads(
|
||||
attn_output, self.num_attention_heads, self.head_size
|
||||
)
|
||||
attn_output = self.dense(attn_output)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def enable_gpt_neox_pos_shift_attention(model):
|
||||
for name, module in reversed(model._modules.items()):
|
||||
if len(list(module.children())) > 0:
|
||||
enable_gpt_neox_pos_shift_attention(
|
||||
module,
|
||||
)
|
||||
|
||||
if isinstance(module, GPTNeoXAttention):
|
||||
module.forward = types.MethodType(
|
||||
gpt_neox_pos_shift_attention_forward, module
|
||||
)
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is copied from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/pos_shift/modify_llama.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.utils.checkpoint
|
||||
import torch.nn.functional as F
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
rotate_half,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
import types
|
||||
|
||||
__all__ = ["enable_llama_pos_shift_attention"]
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
x_embed = (x * cos) + (rotate_half(x) * sin)
|
||||
return x_embed
|
||||
|
||||
|
||||
def llama_pos_shift_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.config.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i])
|
||||
for i in range(self.config.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i])
|
||||
for i in range(self.config.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i])
|
||||
for i in range(self.config.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
### Shift Pos: query pos is min(cache_size, idx)
|
||||
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
|
||||
###
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
### Shift Pos: key pos is the pos in cache
|
||||
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
|
||||
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
|
||||
###
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
|
||||
self.head_dim
|
||||
)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
||||
query_states.dtype
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
if self.config.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(
|
||||
self.hidden_size // self.config.pretraining_tp, dim=2
|
||||
)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.config.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
[
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.config.pretraining_tp)
|
||||
]
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
def enable_llama_pos_shift_attention(model):
|
||||
for name, module in reversed(model._modules.items()):
|
||||
if len(list(module.children())) > 0:
|
||||
enable_llama_pos_shift_attention(
|
||||
module,
|
||||
)
|
||||
|
||||
if isinstance(module, LlamaAttention):
|
||||
model._modules[name].forward = types.MethodType(
|
||||
llama_pos_shift_attention_forward, model._modules[name]
|
||||
)
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/utils.py
|
||||
# which is licensed under the MIT license:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 MIT HAN Lab
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import torch
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import ssl
|
||||
import urllib.request
|
||||
import os
|
||||
import json
|
||||
# code change to import from bigdl-llm API instead of using transformers API
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from transformers import LlamaTokenizer
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
|
||||
def load(model_name_or_path):
|
||||
print(f"Loading model from {model_name_or_path} ...")
|
||||
# however, tensor parallel for running falcon will occur bugs
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# set load_in_4bit=True to get performance boost, set optimize_model=False for now
|
||||
# TODO align logics of optimize_model and streaming
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
load_in_4bit=True,
|
||||
optimize_model=False,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.pad_token_id = 0
|
||||
|
||||
model.eval()
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def download_url(url: str, folder="folder"):
|
||||
"""
|
||||
Downloads the content of an url to a folder. Modified from \
|
||||
https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric
|
||||
|
||||
Args:
|
||||
url (string): The url of target file.
|
||||
folder (string): The target folder.
|
||||
|
||||
Returns:
|
||||
string: File path of downloaded files.
|
||||
"""
|
||||
|
||||
file = url.rpartition("/")[2]
|
||||
file = file if file[0] == "?" else file.split("?")[0]
|
||||
path = osp.join(folder, file)
|
||||
if osp.exists(path):
|
||||
print(f"File {file} exists, use existing file.")
|
||||
return path
|
||||
|
||||
print(f"Downloading {url}")
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
ctx = ssl._create_unverified_context()
|
||||
data = urllib.request.urlopen(url, context=ctx)
|
||||
with open(path, "wb") as f:
|
||||
f.write(data.read())
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def load_jsonl(
|
||||
file_path,
|
||||
):
|
||||
list_data_dict = []
|
||||
with open(file_path, "r") as f:
|
||||
for line in f:
|
||||
list_data_dict.append(json.loads(line))
|
||||
return list_data_dict
|
||||
Loading…
Reference in a new issue