Merge remote-tracking branch 'upstream/main'

Wang 2023-09-25 13:59:19 +08:00
commit e8f436453d
29 changed files with 471 additions and 130 deletions

View file

@ -22,13 +22,13 @@ Follow [here](https://github.com/kubeflow/mpi-operator/tree/master#installation)
 Follow [here](https://github.com/intel-analytics/BigDL/tree/main/docker/llm/finetune/lora/docker#prepare-bigdl-image-for-lora-finetuning) to prepare the BigDL Lora Finetuning image in your cluster.
-As finetuning is from a base model, first download [Llama 7b hf model from the public download site of Hugging Face](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main). Then, download [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which contains all kinds of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server. In addition, make an empty directory under the same destination to save the finetuned model output later.
+As finetuning starts from a base model, first download the [Llama 7b hf model from the public download site of Hugging Face](https://huggingface.co/decapoda-research/llama-7b-hf/tree/main). Then, download the [cleaned alpaca data](https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json), which contains all kinds of general knowledge and has already been cleaned. Next, move the downloaded files to a shared directory on your NFS server.
 ### 3. Deploy through Helm Chart
 You can edit and experiment with different parameters in `./kubernetes/values.yaml` to improve finetuning performance and accuracy. For example, you can adjust `trainerNum` and `cpuPerPod` according to the node and CPU core counts in your cluster to make full use of these resources, and different `microBatchSize` values result in different training speed and loss (note that `microBatchSize`×`trainerNum` should not be more than 128, as that product is the batch size).
-**Note: `dataSubPath`, `modelSubPath` and `outputPath` need to have the same names as files under the NFS directory in step 2.**
+**Note: `dataSubPath` and `modelSubPath` need to have the same names as the files under the NFS directory in step 2.**
 After preparing parameters in `./kubernetes/values.yaml`, submit the job as below:
@ -52,7 +52,9 @@ kubectl exec -it <launcher_pod_name> bash -n bigdl-ppml-finetuning # enter launc
 cat launcher.log # display logs collected from other workers
 ```
-From the log, you can see whether finetuning process has been invoked successfully in all MPI worker pods, and a progress bar with finetuning speed and estimated time will be showed after some data preprocessing steps (this may take quiet a while). For the fine-tuned model, it is written by the worker 0 (who holds rank 0), so you can find the model output inside the pod or the `output` folder under the NFS path (because it has been mounted to worker 0 as output path).
+From the log, you can see whether the finetuning process has been invoked successfully in all MPI worker pods; a progress bar with finetuning speed and estimated time will be shown after some data preprocessing steps (this may take quite a while).
+For the fine-tuned model, it is written by worker 0 (which holds rank 0), so you can find the model output inside that pod; it can be copied to the host with command-line tools like `kubectl cp` or `scp`.
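Once the `finetuned_model` directory has been copied off the launcher pod, a minimal sketch for loading it follows, assuming it holds a standard PEFT LoRA adapter for the Llama base model (local paths are placeholders for wherever `kubectl cp` put the files):

```python
from transformers import AutoModelForCausalLM, LlamaTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("./llama-7b-hf")
model = PeftModel.from_pretrained(base, "./finetuned_model")  # apply the LoRA adapter
tokenizer = LlamaTokenizer.from_pretrained("./llama-7b-hf")

inputs = tokenizer("Below is an instruction...", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))
```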
 ## To run in TDX-CoCo and enable Remote Attestation API

View file

@ -8,7 +8,6 @@ if [ "$WORKER_ROLE" = "launcher" ]
 then
 sed "s/:1/ /g" /etc/mpi/hostfile > /home/mpiuser/hostfile
 export DATA_PATH="/ppml/data/$DATA_SUB_PATH"
-export SAVE_PATH="/ppml/output"
 sleep 10
 mpirun \
 -n $WORLD_SIZE \
@ -22,13 +21,13 @@ then
 python /ppml/lora_finetune.py \
 --base_model '/ppml/model/' \
 --data_path "$DATA_PATH" \
---output_dir "$SAVE_PATH/finetuned_model" \
+--output_dir "/home/mpiuser/finetuned_model" \
 --micro_batch_size $MICRO_BATCH_SIZE \
---bf16 > $SAVE_PATH/launcher.log 2>&1
+--bf16 > /home/mpiuser/launcher.log 2>&1
 exit_status=$?
 if [ $exit_status -ne 0 ];
 then
-cat $SAVE_PATH/launcher.log
+cat /home/mpiuser/launcher.log
 exit $exit_status
 else
 while true

View file

@ -51,9 +51,6 @@ spec:
 - name: nfs-storage
 subPath: {{ .Values.dataSubPath }}
 mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-- name: nfs-storage
-subPath: {{ .Values.outputSubPath }}
-mountPath: "/ppml/output"
 Worker:
 replicas: {{ .Values.trainerNum }}
 template:
@ -86,9 +83,6 @@ spec:
 - name: nfs-storage
 subPath: {{ .Values.dataSubPath }}
 mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-- name: nfs-storage
-subPath: {{ .Values.outputSubPath }}
-mountPath: "/ppml/output"
 resources:
 requests:
 cpu: {{ .Values.cpuPerPod }}
@ -96,4 +90,4 @@ spec:
 - name: nfs-storage
 persistentVolumeClaim:
 claimName: nfs-pvc
 {{- end }}

View file

@ -71,9 +71,6 @@ spec:
 - name: nfs-storage
 subPath: {{ .Values.dataSubPath }}
 mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-- name: nfs-storage
-subPath: {{ .Values.outputSubPath }}
-mountPath: "/ppml/output"
 - name: dev
 mountPath: /dev
 {{- if eq .Values.enableTLS true }}
@ -118,9 +115,6 @@ spec:
 - name: nfs-storage
 subPath: {{ .Values.dataSubPath }}
 mountPath: "/ppml/data/{{ .Values.dataSubPath }}"
-- name: nfs-storage
-subPath: {{ .Values.outputSubPath }}
-mountPath: "/ppml/output"
 - name: dev
 mountPath: /dev
 resources:

View file

@ -6,11 +6,10 @@ nfsServerIp: your_nfs_server_ip
 nfsPath: a_nfs_shared_folder_path_on_the_server
 dataSubPath: alpaca_data_cleaned_archive.json # a subpath of the data file under nfs directory
 modelSubPath: llama-7b-hf # a subpath of the model file (dir) under nfs directory
-outputSubPath: output # a subpath of the empty directory under the nfs directory to save finetuned model, for example, if you make an empty dir named 'output' at the nfsPath, the value should be 'output'
 ompNumThreads: 14
 cpuPerPod: 42
 attestionApiServicePort: 9870
 enableTLS: false # true or false
 base64ServerCrt: "your_base64_format_server_crt"
 base64ServerKey: "your_base64_format_server_key"

View file

@ -176,18 +176,16 @@ def run_pytorch_autocast_bf16(repo_id,
 st = time.perf_counter()
 if repo_id in ['THUDM/chatglm-6b', 'THUDM/chatglm2-6b']:
 # TODO: need verify chatglm family run bf16.
-model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto').float()
-#model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto').bfloat()
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+invalidInputError(False, "Currently pytorch does not support bfloat16 on cpu for chatglm models.")
 elif repo_id in ['meta-llama/Llama-2-7b-chat-hf','meta-llama/Llama-2-13b-chat-hf',
 'meta-llama/Llama-2-70b-chat-hf','decapoda-research/llama-7b-hf',
 'decapoda-research/llama-65b-hf','lmsys/vicuna-7b-v1.5',
 'lmsys/vicuna-13b-v1.3','project-baize/merged-baize-30b']:
-model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto')
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
 # Need to use LlamaTokenizer, reason please refer to issue: https://github.com/intel-analytics/BigDL/issues/8944
 tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
 else:
-model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype='auto')
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 end = time.perf_counter()
 print(">> loading of model costs {}s".format(end - st))

View file

@ -42,7 +42,6 @@ if __name__ == '__main__':
 # which convert the relevant layers in the model into INT4 format
 model = AutoModelForCausalLM.from_pretrained(model_path,
 load_in_4bit=True,
-optimize_model=False,
 trust_remote_code=True,
 use_cache=True)
 model = model.to('xpu')
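This hunk and the three identical ones below drop `optimize_model=False` from the GPU examples, falling back to the library default. A sketch of the resulting load-and-generate path with the bigdl-llm transformers-style API (model path and prompt are placeholders):

```python
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/model",
                                             load_in_4bit=True,  # INT4 weight-only quantization
                                             trust_remote_code=True,
                                             use_cache=True)
model = model.to('xpu')  # run on an Intel GPU

tokenizer = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True)
input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to('xpu')
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```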

View file

@ -46,7 +46,6 @@ if __name__ == '__main__':
 # to obtain optimal performance with BigDL-LLM INT4 optimizations
 model = AutoModelForCausalLM.from_pretrained(model_path,
 load_in_4bit=True,
-optimize_model=False,
 trust_remote_code=True,
 use_cache=True)
 model = model.to('xpu')

View file

@ -44,7 +44,6 @@ if __name__ == '__main__':
 # which convert the relevant layers in the model into INT4 format
 model = AutoModelForCausalLM.from_pretrained(model_path,
 load_in_4bit=True,
-optimize_model=False,
 trust_remote_code=True,
 use_cache=True)
 model = model.to('xpu')

View file

@ -42,7 +42,6 @@ if __name__ == '__main__':
 # which convert the relevant layers in the model into INT4 format
 model = AutoModelForCausalLM.from_pretrained(model_path,
 load_in_4bit=True,
-optimize_model=False,
 trust_remote_code=True,
 use_cache=True)
 model = model.to('xpu')

View file

@ -0,0 +1,2 @@
python-embed
portable-executable.zip

View file

@ -0,0 +1,33 @@
# BigDL-LLM Portable Executable For Windows: User Guide
This portable executable includes everything you need to run an LLM (except the models). Please refer to the How to use section to get started.
## 13B model running on an Intel 11-Gen Core PC (real-time screen capture)
<p align="left">
<img src=https://llm-assets.readthedocs.io/en/latest/_images/one-click-installer-screen-capture.gif width='80%' />
</p>
## Verified Models
- ChatGLM2-6b
- Baichuan-13B-Chat
- Baichuan2-7B-Chat
- internlm-chat-7b-8k
- Llama-2-7b-chat-hf
## How to use
1. Download the model to your computer. Please ensure there is a file named `config.json` in the model folder, otherwise the script won't work.
![](https://llm-assets.readthedocs.io/en/latest/_images/one-click-installer-user-guide-step1.png)
2. Run `chat.bat` in Terminal and input the path of the model (e.g. `path\to\model`, note that there's no slash at the end of the path).
![](https://llm-assets.readthedocs.io/en/latest/_images/one-click-installer-user-guide-step2.png)
3. Press Enter and wait until the model finishes loading. Then enjoy chatting with the model!
4. If you want to stop chatting, just input `stop` and the model will stop running.
![](https://llm-assets.readthedocs.io/en/latest/_images/one-click-installer-user-guide-step34.png)

View file

@ -0,0 +1,8 @@
@echo off
:: execute chat script
set PYTHONUNBUFFERED=1
set /p modelpath="Please enter the model path: "
.\python-embed\python.exe .\chat.py --model-path="%modelpath%"
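For readers who prefer invoking the embedded interpreter programmatically, a Python equivalent of the launcher command above (paths are placeholders):

```python
# Sketch: what chat.bat effectively runs from the package root.
import subprocess

subprocess.run([r".\python-embed\python.exe", r".\chat.py",
                r"--model-path=path\to\model"], check=True)
```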

View file

@ -0,0 +1,116 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import argparse
import sys
# TODO: support more model classes
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers import TextIteratorStreamer
from transformers.tools.agents import StopSequenceCriteria
from transformers.generation.stopping_criteria import StoppingCriteriaList
from colorama import Fore
from bigdl.llm import optimize_model
SYSTEM_PROMPT = "A chat between a curious human <human> and an artificial intelligence assistant <bot>.\
The assistant gives helpful, detailed, and polite answers to the human's questions."
HUMAN_ID = "<human>"
BOT_ID = "<bot>"
# chat_history is formatted as [(input_str, output_str)]
def format_prompt(input_str,
chat_history):
prompt = [f"{SYSTEM_PROMPT}\n"]
for history_input_str, history_output_str in chat_history:
prompt.append(f"{HUMAN_ID} {history_input_str}\n{BOT_ID} {history_output_str}\n")
prompt.append(f"{HUMAN_ID} {input_str}\n{BOT_ID} ")
return "".join(prompt)
def stream_chat(model,
tokenizer,
stopping_criteria,
input_str,
chat_history):
prompt = format_prompt(input_str, chat_history)
# print(prompt)
input_ids = tokenizer([prompt], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=512, stopping_criteria=stopping_criteria)
from threading import Thread
# to ensure non-blocking access to the generated text, the generation process should be run in a separate thread
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
output_str = []
print(Fore.BLUE+"BigDL-LLM: "+Fore.RESET, end="")
for partial_output_str in streamer:
output_str.append(partial_output_str)
# remove the last HUMAN_ID if exists
print(partial_output_str.replace(f"{HUMAN_ID}", ""), end="")
chat_history.append((input_str, "".join(output_str).replace(f"{HUMAN_ID}", "").rstrip()))
def auto_select_model(model_name):
try:
try:
model = AutoModelForCausalLM.from_pretrained(model_name,
low_cpu_mem_usage=True,
torch_dtype="auto",
trust_remote_code=True,
use_cache=True)
except:
model = AutoModel.from_pretrained(model_name,
low_cpu_mem_usage=True,
torch_dtype="auto",
trust_remote_code=True,
use_cache=True)
except:
print("Sorry, the model you entered is not supported in installer.")
sys.exit()
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, help="path to an llm")
args = parser.parse_args()
model_path = args.model_path
model = auto_select_model(model_path)
model = optimize_model(model)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(HUMAN_ID, tokenizer)])
chat_history = []
while True:
with torch.inference_mode():
user_input = input(Fore.GREEN+"\nHuman: "+Fore.RESET)
if user_input == "stop": # let's stop the conversation when user input "stop"
break
stream_chat(model=model,
tokenizer=tokenizer,
stopping_criteria=stopping_criteria,
input_str=user_input,
chat_history=chat_history)
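To make the prompt layout concrete, here is a small, self-contained sketch of what `format_prompt` produces for a hypothetical one-turn history (the function body is copied from the script above):

```python
SYSTEM_PROMPT = "A chat between a curious human <human> and an artificial intelligence assistant <bot>.\
The assistant gives helpful, detailed, and polite answers to the human's questions."
HUMAN_ID = "<human>"
BOT_ID = "<bot>"

def format_prompt(input_str, chat_history):
    prompt = [f"{SYSTEM_PROMPT}\n"]
    for history_input_str, history_output_str in chat_history:
        prompt.append(f"{HUMAN_ID} {history_input_str}\n{BOT_ID} {history_output_str}\n")
    prompt.append(f"{HUMAN_ID} {input_str}\n{BOT_ID} ")
    return "".join(prompt)

# One earlier exchange plus a new question:
print(format_prompt("And Kotlin?", [("What is Java?", "A programming language.")]))
# -> <system prompt>
#    <human> What is Java?
#    <bot> A programming language.
#    <human> And Kotlin?
#    <bot>
```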

View file

@ -0,0 +1,23 @@
:: download python and extract zip
powershell -Command "Start-BitsTransfer -Source https://www.python.org/ftp/python/3.9.13/python-3.9.13-embed-amd64.zip -Destination python-3.9.13-embed-amd64.zip"
powershell -Command "Expand-Archive .\python-3.9.13-embed-amd64.zip -DestinationPath .\python-embed"
del .\python-3.9.13-embed-amd64.zip
set "python-embed=.\python-embed\python.exe"
:: download get-pip.py and install
powershell -Command "Invoke-WebRequest https://bootstrap.pypa.io/get-pip.py -OutFile .\python-embed\get-pip.py"
%python-embed% .\python-embed\get-pip.py
:: enable running site.main() automatically
cd .\python-embed
set "search=#import site"
set "replace=import site"
powershell -Command "(gc python39._pth) -replace '%search%', '%replace%' | Out-File -encoding ASCII python39._pth"
cd ..
:: install pip packages
%python-embed% -m pip install bigdl-llm[all] transformers_stream_generator tiktoken einops colorama
:: compress the python and scripts
powershell -Command "Compress-Archive -Path '.\python-embed', '.\chat.bat', '.\chat.py', '.\README.md' -DestinationPath .\portable-executable.zip"

View file

@ -0,0 +1,5 @@
# BigDL-LLM Portable Executable Setup Script For Windows
## How to use
Simply run `setup.bat`; it will download and install all dependencies and generate a zip file for users to use.

View file

@ -173,6 +173,15 @@ def optimize(model):
 module.SelfAttention,
 chatglm_attention_forward
 )
+elif "mpt" in model.config._name_or_path:
+modeling_module_name = model.__class__.__module__
+attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
+module = importlib.import_module(attention_module_name)
+from bigdl.llm.transformers.models.mpt import mpt_multihead_attention_forward
+convert_forward(model,
+module.MultiheadAttention,
+mpt_multihead_attention_forward
+)
 elif "gptj" in model.config.model_type:
 # dolly-v1-6b
 modeling_module_name = model.__class__.__module__
@ -181,7 +190,7 @@ def optimize(model):
 convert_forward(model,
 module.GPTJAttention,
 gptj_attention_forward)
-elif "bloom" in model.config._name_or_path:
+elif "bloom" in model.config.model_type:
 modeling_module_name = model.__class__.__module__
 module = importlib.import_module(modeling_module_name)
 from bigdl.llm.transformers.models.bloom import bloom_attention_forward
@ -189,17 +198,18 @@ def optimize(model):
 module.BloomAttention,
 bloom_attention_forward
 )
-elif "falcon" in model.config._name_or_path:
+elif "falcon" in model.config.model_type or "RefinedWeb" in model.config.model_type:
 modeling_module_name = model.__class__.__module__
 module = importlib.import_module(modeling_module_name)
 if "RWForCausalLM" in model.config.architectures:
 if hasattr(model.config, "multi_query"):
-# falcon-7b
-from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
-convert_forward(model,
-module.Attention,
-rw_attention_forward_7b
-)
+# falcon-7b need to check performance drop after kv cache support.
+# from bigdl.llm.transformers.models.falcon import rw_attention_forward_7b
+# convert_forward(model,
+# module.Attention,
+# rw_attention_forward_7b
+# )
+pass
 else:
 # falcon-40b
 from bigdl.llm.transformers.models.falcon import rw_attention_forward_40b
@ -262,5 +272,4 @@ def optimize(model):
 transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention,
 gptneox_attention_forward
 )
 return model
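The dispatch above relies on `convert_forward` to monkey-patch each model's own attention class with a BigDL implementation. Its body is not shown in this diff; a schematic sketch of the pattern (names and binding details assumed, not the library's actual implementation):

```python
import types

def convert_forward(model, target_cls, new_forward):
    # Rebind `new_forward` as the bound `forward` method of every
    # attention instance of `target_cls` found in the loaded model.
    for _, submodule in model.named_modules():
        if isinstance(submodule, target_cls):
            submodule.forward = types.MethodType(new_forward, submodule)
```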

View file

@ -30,6 +30,7 @@ def save_low_bit(self, *args, **kwargs):
 invalidInputError(self.config.to_dict().get("bigdl_transformers_low_bit", False),
 f"Detected this model is not a low-bit model, please use from_pretrained's"
 f" load_in_4bit or load_in_low_bit parameter to load a 4-bit model first.")
+self.to('cpu')
 self.save_pretrained(*args, **kwargs)
 import json
 import os
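The added `self.to('cpu')` moves the model off the accelerator before `save_pretrained` runs, so low-bit checkpoints serialize from host memory. A sketch of the save/load round-trip this hook supports, assuming the bigdl-llm transformers-style API (paths are placeholders):

```python
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/model",
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to('xpu')
# ... run inference ...
model.save_low_bit("path/to/low-bit-checkpoint")  # now implicitly moves to CPU first
loaded = AutoModelForCausalLM.load_low_bit("path/to/low-bit-checkpoint",
                                           trust_remote_code=True)
```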

View file

@ -26,7 +26,7 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -71,7 +71,7 @@ def baichuan_attention_forward_7b(
 cache_v = past_key_value[1]
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(bsz,
+new_cache_k, new_cache_v = extend_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 cache_k.size(2),
@ -87,13 +87,13 @@ def baichuan_attention_forward_7b(
 elif use_cache:
 max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(bsz,
+new_key_states, new_value_states = init_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 kv_seq_len,
 max_cache_length,
 dtype=key_states.dtype,
 device=device)
 new_key_states[:] = key_states
 new_value_states[:] = value_states
 key_states = new_key_states
@ -169,7 +169,7 @@ def baichuan_attention_forward_13b(
 cache_v = past_key_value[1]
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(bsz,
+new_cache_k, new_cache_v = extend_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 cache_k.size(2),
@ -185,13 +185,13 @@ def baichuan_attention_forward_13b(
 elif use_cache:
 max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(bsz,
+new_key_states, new_value_states = init_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 kv_seq_len,
 max_cache_length,
 dtype=key_states.dtype,
 device=device)
 new_key_states[:] = key_states
 new_value_states[:] = value_states
 key_states = new_key_states

View file

@ -26,7 +26,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from transformers.utils import logging, ContextManagers
 logger = logging.get_logger(__name__)
@ -83,7 +83,7 @@ def baichuan_attention_forward_7b(
 cache_v = past_key_value[1]
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(bsz,
+new_cache_k, new_cache_v = extend_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 cache_k.size(2),
@ -99,13 +99,13 @@ def baichuan_attention_forward_7b(
 elif use_cache:
 max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(bsz,
+new_key_states, new_value_states = init_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 kv_seq_len,
 max_cache_length,
 dtype=key_states.dtype,
 device=device)
 new_key_states[:] = key_states
 new_value_states[:] = value_states
 key_states = new_key_states
@ -177,8 +177,10 @@ def baichuan_attention_forward_13b(
 cache_k = past_key_value[0]
 cache_v = past_key_value[1]
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
+if device.type == 'xpu':
+torch.xpu.empty_cache()
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(bsz,
+new_cache_k, new_cache_v = extend_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 cache_k.size(2),
@ -194,13 +196,13 @@ def baichuan_attention_forward_13b(
 elif use_cache:
 max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(bsz,
+new_key_states, new_value_states = init_kv_cache(bsz,
 self.num_heads,
 self.head_dim,
 kv_seq_len,
 max_cache_length,
 dtype=key_states.dtype,
 device=device)
 new_key_states[:] = key_states
 new_value_states[:] = value_states
 key_states = new_key_states

View file

@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -96,6 +96,8 @@ def bloom_attention_forward(
 self.head_dim
 )
 _, _, kv_length = key_layer.shape
+if layer_past is not None:
+kv_length += layer_past[0].shape[-1]
 query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 key_layer = key_layer.transpose(1, 2).view(batch_size, self.num_heads, q_length, self.head_dim)
 value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
@ -106,7 +108,7 @@ def bloom_attention_forward(
 cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(
+new_cache_k, new_cache_v = extend_kv_cache(
 batch_size,
 self.num_heads,
 self.head_dim,
@ -124,7 +126,7 @@ def bloom_attention_forward(
 elif use_cache:
 max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(
+new_key_states, new_value_states = init_kv_cache(
 batch_size,
 self.num_heads,
 self.head_dim,

View file

@ -22,7 +22,7 @@ import torch
 import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 def rotate_half(x):
@ -68,7 +68,7 @@ def attention_fn(
 past_length = cache_k.size(2)
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_cache_k, new_cache_v = create_kv_cache(batch_size,
+new_cache_k, new_cache_v = extend_kv_cache(batch_size,
 self.num_attention_heads_per_partition,
 self.hidden_size_per_attention_head,
 past_length,
@ -82,10 +82,10 @@ def attention_fn(
 elif use_cache:
 max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
 + KV_CACHE_ALLOC_BLOCK_LENGTH
-key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
 self.hidden_size_per_attention_head, cur_length,
 max_cache_length,
 dtype=query_layer.dtype, device=device)
 key_cache[:] = key_layer
 value_cache[:] = value_layer
 key_layer = key_cache

View file

@ -20,7 +20,7 @@
 import torch
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 import torch.nn.functional as F
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -152,7 +152,7 @@ def chatglm2_attention_forward_8eb45c(
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_cache_k, new_cache_v = create_kv_cache(batch_size,
+new_cache_k, new_cache_v = extend_kv_cache(batch_size,
 self.num_attention_heads_per_partition,
 self.hidden_size_per_attention_head,
 past_length,
@ -170,10 +170,10 @@ def chatglm2_attention_forward_8eb45c(
 max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
 + KV_CACHE_ALLOC_BLOCK_LENGTH
-key_cache, value_cache = create_kv_cache(batch_size, self.num_attention_heads_per_partition,
+key_cache, value_cache = init_kv_cache(batch_size, self.num_attention_heads_per_partition,
 self.hidden_size_per_attention_head, cur_length,
 max_cache_length,
 dtype=query_layer.dtype, device=device)
 key_cache[:] = key_layer
 value_cache[:] = value_layer
 key_layer = key_cache

View file

@ -38,7 +38,7 @@ from typing import Optional, Tuple
 import torch
 from torch.nn import functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -86,7 +86,8 @@ def rw_attention_forward_7b(
 query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
 _, kv_length, _ = key_layer.shape
+if layer_past is not None:
+kv_length += layer_past[0].shape[-2]
 query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 key_layer = key_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
 value_layer = value_layer.view(batch_size, self.num_kv, q_length, self.head_dim)
@ -98,7 +99,7 @@ def rw_attention_forward_7b(
 cache_v = layer_past[1].view(batch_size, self.num_kv, -1, self.head_dim)
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(
+new_cache_k, new_cache_v = extend_kv_cache(
 batch_size,
 self.num_kv,
 self.head_dim,
@ -116,7 +117,7 @@ def rw_attention_forward_7b(
 elif use_cache:
 max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(
+new_key_states, new_value_states = init_kv_cache(
 batch_size,
 self.num_kv,
 self.head_dim,
@ -264,6 +265,8 @@ def rw_attention_forward_40b(
 query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, seq_len)
 _, kv_length, _ = key_layer.shape
+if layer_past is not None:
+kv_length += layer_past[0].shape[-2]
 query_layer = query_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 key_layer = key_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
 value_layer = value_layer.view(batch_size, self.num_heads, q_length, self.head_dim)
@ -275,7 +278,7 @@ def rw_attention_forward_40b(
 cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(
+new_cache_k, new_cache_v = extend_kv_cache(
 batch_size,
 self.num_heads,
 self.head_dim,
@ -293,7 +296,7 @@ def rw_attention_forward_40b(
 elif use_cache:
 max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(
+new_key_states, new_value_states = init_kv_cache(
 batch_size,
 self.num_heads,
 self.head_dim,
@ -437,7 +440,8 @@ def falcon_attention_forward(
 query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
 _, kv_length, _ = key_layer.shape
+if layer_past is not None:
+kv_length += layer_past[0].shape[-2]
 query_layer = query_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
 key_layer = key_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
 value_layer = value_layer.view(batch_size, self.num_heads, query_length, self.head_dim)
@ -448,7 +452,7 @@ def falcon_attention_forward(
 cache_v = layer_past[1].view(batch_size, self.num_heads, -1, self.head_dim)
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(
+new_cache_k, new_cache_v = extend_kv_cache(
 batch_size,
 self.num_heads,
 self.head_dim,
@ -466,7 +470,7 @@ def falcon_attention_forward(
 elif use_cache:
 max_cache_length = kv_length + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(
+new_key_states, new_value_states = init_kv_cache(
 batch_size,
 self.num_heads,
 self.head_dim,

View file

@ -19,8 +19,8 @@
 import torch
 from typing import Optional, Tuple, Union
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache, \
-apply_rotary_pos_emb
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
+apply_rotary_pos_emb, append_kv_cache
 from transformers.utils.import_utils import is_torch_fx_proxy
@ -144,7 +144,7 @@ def gptj_attention_forward(
 past_length = cache_k.size(2)
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
-new_cache_k, new_cache_v = create_kv_cache(batch_size,
+new_cache_k, new_cache_v = extend_kv_cache(batch_size,
 self.num_attention_heads,
 self.head_dim,
 past_length,
@ -158,13 +158,13 @@ def gptj_attention_forward(
 key, value = append_kv_cache(cache_k, cache_v, key, value)
 elif use_cache:
-key_cache, value_cache = create_kv_cache(batch_size,
+key_cache, value_cache = init_kv_cache(batch_size,
 self.num_attention_heads,
 self.head_dim,
 kv_seq_len,
 kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 dtype=key.dtype,
 device=device)
 key_cache[:] = key
 value_cache[:] = value
 key = key_cache

View file

@ -34,7 +34,7 @@
 import torch
 from typing import Optional, Tuple
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -91,7 +91,7 @@ def gptneox_attention_forward(
 past_value = layer_past[1]
 if past_key.stride()[1] <= past_key.size(2) * past_key.size(3):
 # allocate new
-new_past_key, new_past_value = create_kv_cache(bsz,
+new_past_key, new_past_value = extend_kv_cache(bsz,
 self.num_attention_heads,
 self.head_size,
 past_key.size(2),
@ -106,13 +106,13 @@ def gptneox_attention_forward(
 key, value = append_kv_cache(past_key, past_value, key, value)
 elif use_cache:
 max_cache_length = seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key, new_value = create_kv_cache(bsz,
+new_key, new_value = init_kv_cache(bsz,
 self.num_attention_heads,
 self.head_size,
 seq_len,
 max_cache_length,
 dtype=key.dtype,
 device=device)
 new_key[:] = key
 new_value[:] = value
 key = new_key

View file

@ -37,7 +37,7 @@ from typing import Optional, Tuple
 import math
 import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
-from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
@ -113,7 +113,7 @@ def llama_attention_forward_4_31(
 cache_v = past_key_value[1]
 if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 # allocate new
-new_cache_k, new_cache_v = create_kv_cache(bsz,
+new_cache_k, new_cache_v = extend_kv_cache(bsz,
 self.num_key_value_heads, # Support GQA
 self.head_dim,
 cache_k.size(2),
@ -129,13 +129,13 @@ def llama_attention_forward_4_31(
 elif use_cache:
 max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
-new_key_states, new_value_states = create_kv_cache(bsz,
+new_key_states, new_value_states = init_kv_cache(bsz,
 self.num_key_value_heads,
 self.head_dim,
 kv_seq_len,
 max_cache_length,
 dtype=key_states.dtype,
 device=device)
 new_key_states[:] = key_states
 new_value_states[:] = value_states
 key_states = new_key_states

View file

@ -0,0 +1,149 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://huggingface.co/mosaicml/mpt-7b-chat/blob/main/attention.py
#
import warnings
import torch
from einops import rearrange
import math
import torch.nn.functional as F
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def mpt_multihead_attention_forward(self, x, past_key_value=None, attn_bias=None,
attention_mask=None, is_causal=True, needs_weights=False):
qkv = self.Wqkv(x)
if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
(query, key, value) = qkv.chunk(3, dim=2)
key_padding_mask = attention_mask
if self.qk_ln:
dtype = query.dtype
query = self.q_ln(query).to(dtype)
key = self.k_ln(key).to(dtype)
(context, attn_weights, past_key_value) = \
mpt_scaled_multihead_dot_product_attention(query, key, value, self.n_heads,
past_key_value=past_key_value,
softmax_scale=self.softmax_scale,
attn_bias=attn_bias,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights)
return (self.out_proj(context), attn_weights, past_key_value)
def mpt_scaled_multihead_dot_product_attention(query, key, value, n_heads,
past_key_value=None,
softmax_scale=None,
attn_bias=None,
key_padding_mask=None,
is_causal=False,
dropout_p=0.0,
training=False,
needs_weights=False,
multiquery=False):
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
bsz, n_heads, q_len, head_dim = q.size()
device = q.device
kv_n_heads = 1 if multiquery else n_heads
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
kv_seq_len = k.shape[-1]
if past_key_value is not None:
if len(past_key_value) != 0:
# k = torch.cat([past_key_value[0], k], dim=3)
# v = torch.cat([past_key_value[1], v], dim=2)
cache_k = past_key_value[0].transpose(2, 3)
cache_v = past_key_value[1]
kv_seq_len += cache_k.shape[-2]
if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
# allocate new
new_cache_k, new_cache_v = extend_kv_cache(bsz,
kv_n_heads, # Support GQA
head_dim,
cache_k.size(2),
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_states, value_states = append_kv_cache(cache_k, cache_v, k.transpose(2, 3), v)
k = key_states.transpose(2, 3)
v = value_states
else:
max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
new_key_states, new_value_states = init_kv_cache(bsz,
kv_n_heads,
head_dim,
kv_seq_len,
max_cache_length,
dtype=k.dtype,
device=device)
new_key_states[:] = k.transpose(2, 3)
new_value_states[:] = v
k = new_key_states.transpose(2, 3)
v = new_value_states
past_key_value = (k, v)
(b, _, s_q, d) = q.shape
s_k = k.size(-1)
if softmax_scale is None:
softmax_scale = 1 / math.sqrt(d)
attn_weight = q.matmul(k) * softmax_scale
if attn_bias is not None:
_s_q = max(0, attn_bias.size(2) - s_q)
_s_k = max(0, attn_bias.size(3) - s_k)
attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k \
or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
invalidInputError(False, f'attn_bias (shape: {attn_bias.shape}) '
f'is expected to broadcast to shape: {attn_weight.shape}.')
attn_weight = attn_weight + attn_bias
min_val = torch.finfo(q.dtype).min
if key_padding_mask is not None:
if attn_bias is not None:
warnings.warn('Propagating key_padding_mask to the attention module '
+ 'and applying it within the attention module can cause '
+ 'unnecessary computation/memory usage. Consider integrating '
+ 'into attn_bias once and passing that to each attention '
+ 'module instead.')
attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
if is_causal and (not q.size(2) == 1):
s = max(s_q, s_k)
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
causal_mask = causal_mask.tril()
causal_mask = causal_mask.to(torch.bool)
causal_mask = ~causal_mask
causal_mask = causal_mask[-s_q:, -s_k:]
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
attn_weight = torch.softmax(attn_weight, dim=-1)
if dropout_p:
attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p,
training=training, inplace=True)
out = attn_weight.to(v.dtype).matmul(v)
out = rearrange(out, 'b h s d -> b s (h d)')
if needs_weights:
return (out, attn_weight, past_key_value)
return (out, None, past_key_value)

View file

@ -18,9 +18,7 @@ import torch
 from bigdl.llm.utils.common import invalidInputError
-def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
-if device.type == 'xpu':
-torch.xpu.empty_cache()
+def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
 key_cache_storage = torch.empty(batch_size, num_heads,
 max_length, head_dim,
 dtype=dtype, device=device)
@ -29,7 +27,7 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
 dtype=dtype, device=device)
 key_cache = key_cache_storage.as_strided((batch_size, num_heads,
 current_length, head_dim),
 key_cache_storage.stride(),
 storage_offset=0)
 value_cache = value_cache_storage.as_strided((batch_size, num_heads,
@ -39,6 +37,13 @@ def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length,
 return key_cache, value_cache
+def extend_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
+# empty cache to reduce gpu memory
+if device.type == 'xpu':
+torch.xpu.empty_cache()
+return init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device)
 def append_kv_cache(cache_k, cache_v, key_states, value_states):
 new_size = (cache_k.size(0),
 cache_k.size(1),
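Taken together, the renamed helpers implement a pre-allocate-then-append KV cache: `init_kv_cache` allocates `max_length` of storage and returns `current_length`-sized strided views over it, while `extend_kv_cache` additionally empties the XPU allocator cache before re-allocating. A minimal sketch of the intended call pattern on CPU, with shapes invented for illustration:

```python
import torch
from bigdl.llm.transformers.models.utils import init_kv_cache, append_kv_cache

bsz, heads, head_dim = 1, 32, 128
k = torch.randn(bsz, heads, 16, head_dim)
v = torch.randn(bsz, heads, 16, head_dim)

# First forward pass: allocate a cache with headroom and copy the states in.
key_cache, value_cache = init_kv_cache(bsz, heads, head_dim,
                                       current_length=16,
                                       max_length=16 + 256,
                                       dtype=k.dtype, device=k.device)
key_cache[:] = k
value_cache[:] = v

# Subsequent passes: append in place until the headroom is exhausted,
# at which point the call sites above fall back to extend_kv_cache.
new_k = torch.randn(bsz, heads, 1, head_dim)
new_v = torch.randn(bsz, heads, 1, head_dim)
key_cache, value_cache = append_kv_cache(key_cache, value_cache, new_k, new_v)
```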