Speculative Starcoder on CPU (#10138)

* Speculative Starcoder on CPU

* enable kv-cache pre-allocation

* refine codes

* refine

* fix style

* fix style

* fix style

* refine

* refine

* Update speculative.py

* Update gptbigcode.py

* fix style

* Update speculative.py

* enable mixed-datatype layernorm on top of torch API

* adaptive dtype

* Update README.md
Heyang Sun 2024-02-27 09:57:29 +08:00 committed by GitHub
parent a47989c860
commit 36a9e88104
7 changed files with 304 additions and 5 deletions

View file

@@ -1,5 +1,5 @@
# Mistral
In this directory, you will find examples on how you could run Baichuan2 BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
In this directory, you will find examples on how you could run Mistral BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes, we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
## 0. Requirements
To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.

View file

@@ -0,0 +1,97 @@
# Starcoder
In this directory, you will find examples on how you could run Starcoder BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes, we utilize the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) and [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py) as reference Starcoder models.
## 0. Requirements
To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
## Example: Predict Tokens using `generate()` API
In the example [speculative.py](./speculative.py), we show a basic use case for a Starcoder model to predict the next N tokens using `generate()` API, with BigDL-LLM speculative decoding optimizations on Intel CPUs.
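At its core, the example only changes how the model is loaded compared with a plain `transformers` generation flow. The snippet below is a minimal sketch condensed from the `speculative.py` script added in this commit (warmup and timing omitted); the model id and prompt are just the example defaults.
```python
import torch
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "bigcode/starcoder"  # or a local checkpoint folder
# Loading in BF16 with `speculative=True` enables self-speculative decoding on CPU.
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             optimize_model=True,
                                             torch_dtype=torch.bfloat16,
                                             load_in_low_bit="bf16",
                                             speculative=True,
                                             torchscript=True,
                                             trust_remote_code=True,
                                             use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

with torch.inference_mode():
    inputs = tokenizer("def dfs_print_Fibonacci_sequence(n):", return_tensors='pt')
    output = model.generate(inputs.input_ids, max_new_tokens=128, do_sample=False)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```
The full script, including warmup and latency reporting, is shown in the diff later in this commit.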
### 1. Install
We suggest using conda to manage the environment:
```bash
conda create -n llm python=3.9
conda activate llm
pip install --pre --upgrade bigdl-llm[all]
pip install intel_extension_for_pytorch==2.1.0
pip install transformers==4.31.0
```
### 2. Configure high-performing processor environment variables
```bash
source bigdl-llm-init -t
export OMP_NUM_THREADS=48 # you can change 48 here to #cores of one processor socket
```
### 3. Run
We recommend using `numactl` to bind the program to a specified processor socket:
```bash
numactl -C 0-47 -m 0 python ./speculative.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
```
For example, `-C 0-47` binds the Python program to cores 0-47 of a 48-core socket.
Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Starcoder model (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `bigcode/starcoder`.
- `--prompt PROMPT`: argument defining the prompt to be inferred. A default prompt is provided.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `128`.
#### Sample Output
#### [bigcode/starcoder](https://huggingface.co/bigcode/starcoder)
```log
def dfs_print_Fibonacci_sequence(n):
if n == 0:
return
elif n == 1:
print(0)
return
elif n == 2:
print(0)
print(1)
return
else:
print(0)
print(1)
dfs_print_Fibonacci_sequence(n-2)
print(dfs_Fibonacci_sequence(n-1))
def dfs_Fibonacci_sequence(n):
if n == 0:
return 0
elif n == 1:
return 1
else:
return dfs_Fibonacci_sequence
Tokens generated 128
E2E Generation time xx.xxxxs
First token latency xx.xxxxs
```
#### [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py)
```log
def dfs_print_Fibonacci_sequence(n):
if n == 0:
return
print(n)
for i in range(2, n):
print(dfs_print_Fibonacci_sequence(i))
def dfs_print_Fibonacci_sequence_2(n):
if n == 0:
return
print(n)
for i in range(2, n):
print(dfs_print_Fibonacci_sequence_2(i))
def dfs_print_Fibonacci_sequence_3(n):
if n == 0:
return
print(n)
for i in
Tokens generated 128
E2E Generation time xx.xxxxs
First token latency xx.xxxxs
```

View file

@@ -0,0 +1,87 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import argparse
import time
import numpy as np
torch.nn.Linear.reset_parameters = lambda x: None
seed=42
torch.manual_seed(seed)
np.random.seed(seed)
STARCODER_PROMPT_FORMAT = "{prompt}"
prompt = "def dfs_print_Fibonacci_sequence(n):"
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Starcoder model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="bigcode/starcoder",
                        help='The huggingface repo id for the Starcoder model (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default=prompt,
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=128,
                        help='Max tokens to predict')
    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    # Load model in optimized bf16 here.
    # Set `speculative=True` to enable speculative decoding,
    # it only works when load_in_low_bit="fp16" on Intel GPU or load_in_low_bit="bf16" on latest Intel Xeon CPU
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 optimize_model=True,
                                                 torch_dtype=torch.bfloat16,
                                                 load_in_low_bit="bf16",
                                                 speculative=True,
                                                 torchscript=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    with torch.inference_mode():
        prompt = STARCODER_PROMPT_FORMAT.format(prompt=args.prompt)
        inputs = tokenizer(prompt, return_tensors='pt')
        input_ids = inputs.input_ids.to(model.device)
        actual_in_len = input_ids.shape[1]
        print("actual input_ids length:" + str(actual_in_len))
        attention_mask = inputs.attention_mask.to(model.device)
        # warmup
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict,
                                attention_mask=attention_mask,
                                do_sample=False)
        output_str = tokenizer.decode(output[0])
        # speculative decoding
        st = time.perf_counter()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict,
                                attention_mask=attention_mask,
                                do_sample=False)
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        end = time.perf_counter()
        print(output_str)
        print(f"Tokens generated {model.n_token_generated}")
        print(f"E2E Generation time {(end - st):.4f}s")
        print(f"First token latency {model.first_token_time:.4f}s")

View file

@@ -1173,6 +1173,10 @@ def _optimize_post(model, lightweight_bmm=False):
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
        from bigdl.llm.transformers.models.gptbigcode import gptbigcode_attention_forward
        convert_forward(model,
                        module.GPTBigCodeAttention,
                        gptbigcode_attention_forward)
        _attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
        replace_func(model,
                     module.GPTBigCodeAttention,

View file

@@ -74,7 +74,10 @@ def bloom_layer_norm_forward(self, hidden_states):
        # if nelement == 0, means fused norm failed, go back to python implement.
        if result.nelement != 0:
            return result
    return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
    input_dtype = hidden_states.dtype
    result = F.layer_norm(hidden_states.to(self.weight.dtype),
                          self.normalized_shape, self.weight, self.bias, self.eps)
    return result.to(input_dtype)
def bloom_attention_forward(
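This hunk is the "mixed-datatype layernorm on top of torch API" change from the commit message: when the LayerNorm weight dtype and the activation dtype disagree (for example BF16 weights meeting FP32 hidden states), the input is cast to the weight dtype before calling `F.layer_norm` and the result is cast back. A standalone sketch of the same pattern, with made-up shapes:
```python
import torch
import torch.nn.functional as F

# Illustrative shapes and dtypes only; the real module takes them from the model config.
hidden_states = torch.randn(1, 8, 64, dtype=torch.float32)  # e.g. fp32 activations
weight = torch.ones(64, dtype=torch.bfloat16)                # e.g. bf16 LayerNorm weight
bias = torch.zeros(64, dtype=torch.bfloat16)

input_dtype = hidden_states.dtype
# Run layer_norm with a single dtype by casting the input to the weight dtype...
result = F.layer_norm(hidden_states.to(weight.dtype), (64,), weight, bias, 1e-5)
# ...then cast the result back so downstream modules keep seeing the original dtype.
result = result.to(input_dtype)
print(result.dtype)  # torch.float32
```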

View file

@@ -15,10 +15,14 @@
#
from typing import Optional, Tuple, Union
import torch
def _attn_wrapper(origin_attn):
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_output, attn_weights = origin_attn(self,
                                                query=query,
                                                query=query.to(key.dtype),
                                                key=key,
                                                value=value,
                                                attention_mask=attention_mask,
@@ -27,3 +31,84 @@ def _attn_wrapper(origin_attn):
        attn_output = attn_output.clone()
        return attn_output, attn_weights
    return _attn


def gptbigcode_attention_forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs):
    if "padding_mask" in kwargs:
        logger.warning_once(
            "Passing `padding_mask` is deprecated and will be removed in v4.37." +
            "Please make sure use `attention_mask` instead.`"
        )

    if encoder_hidden_states is not None:
        if not hasattr(self, "q_attn") or not self.is_cross_attention:
            from bigdl.llm.utils.common import invalidInputError
            invalidInputError(
                False,
                "If class is used as cross attention," +
                "the weights `q_attn` have to be defined. " +
                "Please make sure to instantiate class with " +
                "`GPTBigCodeAttention(..., is_cross_attention=True)`."
            )

        query = self.q_attn(hidden_states)
        key_value = self.c_attn(encoder_hidden_states)
        attention_mask = encoder_attention_mask
    elif self.multi_query:
        query, key_value = self.c_attn(hidden_states).split(
            (self.embed_dim, 2 * self.kv_dim), dim=2)
    else:
        # Note: We split as (self.num_heads, 3, self.head_dim)
        # instead of (3, self.num_heads, self.head_dim),
        # i.e., the memory layout is not the same as GPT2.
        # This makes the concatenation with past_key_value more efficient.
        query, key_value = (
            self.c_attn(hidden_states)
            .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
            .transpose(1, 2)
            .split((self.head_dim, 2 * self.head_dim), dim=3)
        )

    if layer_past is not None:
        if layer_past.shape[-2] == key_value.shape[-2]:
            key_value = torch.cat((layer_past, key_value), dim=-2)
        else:
            fill_zeros = torch.zeros(layer_past.shape[0],
                                     layer_past.shape[1],
                                     key_value.shape[2] - layer_past.shape[2],
                                     dtype=layer_past.dtype,
                                     device=layer_past.device)
            layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
            key_value = torch.cat((layer_past, key_value), dim=-2)
    present = key_value if use_cache else None

    key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

    attn_output, attn_weights = self._attn(query,
                                           key.transpose(-1, -2),
                                           value,
                                           attention_mask,
                                           head_mask)

    if not self.multi_query:
        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        if self.multi_query:
            attn_weights = attn_weights.transpose(1, 2)
        outputs += (attn_weights,)

    return outputs
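In the multi-query path, GPTBigCode keeps key and value fused in a single `key_value` tensor with the sequence dimension second-to-last. That is why the cache above is grown with one `torch.cat(..., dim=-2)` and split along the last dimension, and why the `speculative.py` changes below handle a per-layer cache that is a single low-dimensional tensor rather than a (key, value) pair. A toy sketch of that layout, with sizes invented for illustration:
```python
import torch

# Toy illustration of GPTBigCode's fused multi-query KV cache, mirroring the
# cache handling in the forward above. All sizes are made up.
batch, past_len, new_len, head_dim = 1, 5, 1, 4

layer_past = torch.randn(batch, past_len, 2 * head_dim)     # cached [key | value] per past token
key_value = torch.randn(batch, new_len, 2 * head_dim)       # projection for the new token(s)

key_value = torch.cat((layer_past, key_value), dim=-2)      # grow the cache along the sequence dim
key, value = key_value.split((head_dim, head_dim), dim=-1)  # recover separate key / value views

print(key.shape, value.shape)  # torch.Size([1, 6, 4]) torch.Size([1, 6, 4])
```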

View file

@@ -168,8 +168,10 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
    if not _enable_ipex:
        len0 = past_key_values[0][0].size(0)
        len1 = past_key_values[0][0].size(1)
        len2 = past_key_values[0][0].size(2)
        len3 = past_key_values[0][0].size(3)
        # gpt_bigcode has only 2-dimension kv
        if len(past_key_values[0][0].shape) == 4:
            len2 = past_key_values[0][0].size(2)
            len3 = past_key_values[0][0].size(3)
        for i in range(len(past_key_values)):
            if self.config.model_type == "qwen":
                k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3,
@@ -195,6 +197,12 @@
                    torch.float32)
                past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
                    torch.float32)
            elif self.config.model_type == "gpt_bigcode":
                kv = torch.ones(len0 + max_new_tokens, len1,
                                dtype=torch.float32)
                past_key_values_storage.append(kv[None, :, :])
                past_key_values_storage[i][0][:len0, :] = past_key_values[i][0].to(
                    torch.float32)
            else:
                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
                                dtype=torch.float32)
@@ -266,6 +274,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values,
            k0 = past_key_values_storage[i][0][:len0, :, :, :]
            v0 = past_key_values_storage[i][1][:len0, :, :, :]
            tmp_past_key_values.append((k0, v0))
        elif self.config.model_type == "gpt_bigcode":
            len0 = past_key_values[0][0].size(0)
            kv = past_key_values_storage[i][0][:len0, :]
            tmp_past_key_values.append(kv[None, :, :])
        else:
            len2 = past_key_values[0][0].size(2)
            k0 = past_key_values_storage[i][0][:, :, :len2, :]
@@ -292,6 +304,12 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
                past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
            past_key_values_storage[i][1][size:size1, :, :, :] = \
                past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
        elif self.config.model_type == "gpt_bigcode":
            size = original_draft_past_key_values[i][0].size(0)
            size1 = past_key_values[i][0].size(0)
            if size < size1:
                past_key_values_storage[i][0][size:size1, :] = \
                    past_key_values[i][0][size:size1, :].to(torch.float32)
        else:
            size = original_draft_past_key_values[i][0].size(2)
            size1 = past_key_values[i][0].size(2)
@@ -801,6 +819,11 @@ def speculative_generate(self,
                             v[:, :, :-(max_of_max_matched - max_matched), :])
                        for k, v in past_key_values
                    ]
                elif self.config.model_type == "gpt_bigcode":
                    past_key_values = [
                        kv[:, :-(max_of_max_matched - max_matched)]
                        for kv in past_key_values
                    ]
                else:
                    past_key_values = [
                        (k[:, :, :-(max_of_max_matched - max_matched)],
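The hunks above extend BigDL-LLM's CPU KV-cache pre-allocation to `gpt_bigcode`: a storage tensor sized for the prompt plus `max_new_tokens` is allocated once, the existing cache is copied into its prefix, and drafting steps reuse slices of that buffer instead of re-allocating every iteration. A minimal sketch of the pattern, assuming a 2-D per-sequence cache as above (all names and sizes are illustrative, not BigDL internals):
```python
import torch

# Hypothetical sizes for illustration only.
prompt_len, kv_width, max_new_tokens = 16, 256, 128

# Per-layer cache from the prefill step, stored as (seq_len, 2 * head_dim) for gpt_bigcode.
past_kv = torch.randn(prompt_len, kv_width, dtype=torch.bfloat16)

# Pre-allocate storage for the whole generation budget once, in fp32 as in the code above.
kv_storage = torch.ones(prompt_len + max_new_tokens, kv_width, dtype=torch.float32)
kv_storage[:prompt_len, :] = past_kv.to(torch.float32)

# Each drafting round then slices a view of the storage and re-adds the batch dimension.
cur_len = prompt_len + 3  # e.g. after accepting 3 drafted tokens
draft_past_kv = kv_storage[:cur_len, :][None, :, :]
print(draft_past_kv.shape)  # torch.Size([1, 19, 256])
```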