diff --git a/python/llm/example/CPU/Speculative-Decoding/mistral/README.md b/python/llm/example/CPU/Speculative-Decoding/mistral/README.md
index b0845056..32e9ec64 100644
--- a/python/llm/example/CPU/Speculative-Decoding/mistral/README.md
+++ b/python/llm/example/CPU/Speculative-Decoding/mistral/README.md
@@ -1,5 +1,5 @@
 # Mistral
-In this directory, you will find examples on how you could run Baichuan2 BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
+In this directory, you will find examples of how you could run Mistral BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes, we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
 
 ## 0. Requirements
 To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
diff --git a/python/llm/example/CPU/Speculative-Decoding/starcoder/README.md b/python/llm/example/CPU/Speculative-Decoding/starcoder/README.md
new file mode 100644
index 00000000..36636fab
--- /dev/null
+++ b/python/llm/example/CPU/Speculative-Decoding/starcoder/README.md
@@ -0,0 +1,97 @@
+# Starcoder
+In this directory, you will find examples of how you could run Starcoder BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes, we utilize the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) and [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py) as reference Starcoder models.
+
+## 0. Requirements
+To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example: Predict Tokens using `generate()` API
+In the example [speculative.py](./speculative.py), we show a basic use case for a Starcoder model to predict the next N tokens using the `generate()` API, with BigDL-LLM speculative decoding optimizations on Intel CPUs.
+### 1. Install
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.9
+conda activate llm
+pip install --pre --upgrade bigdl-llm[all]
+pip install intel_extension_for_pytorch==2.1.0
+pip install transformers==4.31.0
+```
+### 2. Configure high-performing processor environment variables
+```bash
+source bigdl-llm-init -t
+export OMP_NUM_THREADS=48 # you can change 48 here to the number of cores of one processor socket
+```
+### 3. Run
+
+We recommend using `numactl` to bind the program to a specified processor socket:
+
+```bash
+numactl -C 0-47 -m 0 python ./speculative.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
+```
+
+For example, `-C 0-47` binds the Python program to cores 0-47 of a 48-core socket.
+
+Arguments info:
+
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Starcoder model (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded, or the path to the huggingface checkpoint folder. It defaults to `bigcode/starcoder`.
+- `--prompt PROMPT`: argument defining the prompt to be inferred. A default prompt is provided.
+- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `128`.
+
+#### Sample Output
+#### [bigcode/starcoder](https://huggingface.co/bigcode/starcoder)
+
+```log
+def dfs_print_Fibonacci_sequence(n):
+    if n == 0:
+        return
+    elif n == 1:
+        print(0)
+        return
+    elif n == 2:
+        print(0)
+        print(1)
+        return
+    else:
+        print(0)
+        print(1)
+        dfs_print_Fibonacci_sequence(n-2)
+        print(dfs_Fibonacci_sequence(n-1))
+
+def dfs_Fibonacci_sequence(n):
+    if n == 0:
+        return 0
+    elif n == 1:
+        return 1
+    else:
+        return dfs_Fibonacci_sequence
+Tokens generated 128
+E2E Generation time xx.xxxxs
+First token latency xx.xxxxs
+```
+
+#### [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py)
+```log
+def dfs_print_Fibonacci_sequence(n):
+    if n == 0:
+        return
+    print(n)
+    for i in range(2, n):
+        print(dfs_print_Fibonacci_sequence(i))
+
+
+def dfs_print_Fibonacci_sequence_2(n):
+    if n == 0:
+        return
+    print(n)
+    for i in range(2, n):
+        print(dfs_print_Fibonacci_sequence_2(i))
+
+
+def dfs_print_Fibonacci_sequence_3(n):
+    if n == 0:
+        return
+    print(n)
+    for i in
+Tokens generated 128
+E2E Generation time xx.xxxxs
+First token latency xx.xxxxs
+```
diff --git a/python/llm/example/CPU/Speculative-Decoding/starcoder/speculative.py b/python/llm/example/CPU/Speculative-Decoding/starcoder/speculative.py
new file mode 100644
index 00000000..08305401
--- /dev/null
+++ b/python/llm/example/CPU/Speculative-Decoding/starcoder/speculative.py
@@ -0,0 +1,87 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from bigdl.llm.transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
+import argparse
+import time
+import numpy as np
+
+
+torch.nn.Linear.reset_parameters = lambda x: None
+seed = 42
+torch.manual_seed(seed)
+np.random.seed(seed)
+
+STARCODER_PROMPT_FORMAT = "{prompt}"
+prompt = "def dfs_print_Fibonacci_sequence(n):"
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Starcoder model')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="bigcode/starcoder",
+                        help='The huggingface repo id for the Starcoder model (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    parser.add_argument('--prompt', type=str, default=prompt,
+                        help='Prompt to infer')
+    parser.add_argument('--n-predict', type=int, default=128,
+                        help='Max tokens to predict')
+
+    args = parser.parse_args()
+    model_path = args.repo_id_or_model_path
+
+    # Load model in optimized bf16 here.
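+    # Note: self-speculative decoding derives its draft model from this same
+    # checkpoint, so no separate draft model needs to be downloaded or specified.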
+    # Set `speculative=True` to enable speculative decoding;
+    # it only works when load_in_low_bit="fp16" on Intel GPU or load_in_low_bit="bf16" on the latest Intel Xeon CPU.
+    model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                 optimize_model=True,
+                                                 torch_dtype=torch.bfloat16,
+                                                 load_in_low_bit="bf16",
+                                                 speculative=True,
+                                                 torchscript=True,
+                                                 trust_remote_code=True,
+                                                 use_cache=True)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    with torch.inference_mode():
+        prompt = STARCODER_PROMPT_FORMAT.format(prompt=args.prompt)
+        inputs = tokenizer(prompt, return_tensors='pt')
+        input_ids = inputs.input_ids.to(model.device)
+        actual_in_len = input_ids.shape[1]
+        print("actual input_ids length:" + str(actual_in_len))
+        attention_mask = inputs.attention_mask.to(model.device)
+
+        # warmup
+        output = model.generate(input_ids,
+                                max_new_tokens=args.n_predict,
+                                attention_mask=attention_mask,
+                                do_sample=False)
+        output_str = tokenizer.decode(output[0])
+
+        # speculative decoding
+        st = time.perf_counter()
+        output = model.generate(input_ids,
+                                max_new_tokens=args.n_predict,
+                                attention_mask=attention_mask,
+                                do_sample=False)
+        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
+        end = time.perf_counter()
+
+        print(output_str)
+        print(f"Tokens generated {model.n_token_generated}")
+        print(f"E2E Generation time {(end - st):.4f}s")
+        print(f"First token latency {model.first_token_time:.4f}s")
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index e3e93b3b..6a7c9099 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -1173,6 +1173,10 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
+        from bigdl.llm.transformers.models.gptbigcode import gptbigcode_attention_forward
+        convert_forward(model,
+                        module.GPTBigCodeAttention,
+                        gptbigcode_attention_forward)
         _attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
         replace_func(model,
                      module.GPTBigCodeAttention,
diff --git a/python/llm/src/bigdl/llm/transformers/models/bloom.py b/python/llm/src/bigdl/llm/transformers/models/bloom.py
index d8d5aab0..4438270f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/bloom.py
+++ b/python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -74,7 +74,10 @@ def bloom_layer_norm_forward(self, hidden_states):
     # if nelement == 0, means fused norm failed, go back to python implement.
     if result.nelement != 0:
         return result
-    return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+    input_dtype = hidden_states.dtype
+    result = F.layer_norm(hidden_states.to(self.weight.dtype),
+                          self.normalized_shape, self.weight, self.bias, self.eps)
+    return result.to(input_dtype)
 
 
 def bloom_attention_forward(
diff --git a/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py b/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
index 600fbc1d..6f9895b1 100644
--- a/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
+++ b/python/llm/src/bigdl/llm/transformers/models/gptbigcode.py
@@ -15,10 +15,14 @@
 #
+from typing import Optional, Tuple, Union
+import torch
+
+
 def _attn_wrapper(origin_attn):
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_output, attn_weights = origin_attn(self,
-                                                query=query,
+                                                query=query.to(key.dtype),
                                                 key=key,
                                                 value=value,
                                                 attention_mask=attention_mask,
@@ -27,3 +31,84 @@ def _attn_wrapper(origin_attn):
         attn_output = attn_output.clone()
         return attn_output, attn_weights
     return _attn
+
+
+def gptbigcode_attention_forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_past: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs):
+    if "padding_mask" in kwargs:
+        from transformers.utils import logging
+        logging.get_logger(__name__).warning_once(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. " +
+            "Please make sure to use `attention_mask` instead.")
+
+    if encoder_hidden_states is not None:
+        if not hasattr(self, "q_attn") or not self.is_cross_attention:
+            from bigdl.llm.utils.common import invalidInputError
+            invalidInputError(
+                False,
+                "If class is used as cross attention, " +
+                "the weights `q_attn` have to be defined. " +
+                "Please make sure to instantiate class with " +
+                "`GPTBigCodeAttention(..., is_cross_attention=True)`."
+            )
+        query = self.q_attn(hidden_states)
+        key_value = self.c_attn(encoder_hidden_states)
+        attention_mask = encoder_attention_mask
+    elif self.multi_query:
+        query, key_value = self.c_attn(hidden_states).split(
+            (self.embed_dim, 2 * self.kv_dim), dim=2)
+    else:
+        # Note: We split as (self.num_heads, 3, self.head_dim)
+        # instead of (3, self.num_heads, self.head_dim),
+        # i.e., the memory layout is not the same as GPT2.
+        # This makes the concatenation with past_key_value more efficient.
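+        # After the view/transpose/split below, `query` has shape
+        # (batch, num_heads, seq_len, head_dim) and `key_value` packs key and value
+        # along the last dimension as (batch, num_heads, seq_len, 2 * head_dim).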
+ query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + if layer_past.shape[-2] == key_value.shape[-2]: + key_value = torch.cat((layer_past, key_value), dim=-2) + else: + fill_zeros = torch.zeros(layer_past.shape[0], + layer_past.shape[1], + key_value.shape[2] - layer_past.shape[2], + dtype=layer_past.dtype, + device=layer_past.device) + layer_past = torch.cat([layer_past, fill_zeros], dim=-1) + key_value = torch.cat((layer_past, key_value), dim=-2) + + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + attn_output, attn_weights = self._attn(query, + key.transpose(-1, -2), + value, + attention_mask, + head_mask) + + if not self.multi_query: + attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + if self.multi_query: + attn_weights = attn_weights.transpose(1, 2) + outputs += (attn_weights,) + + return outputs diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py index a8633fca..2b1fb319 100644 --- a/python/llm/src/bigdl/llm/transformers/speculative.py +++ b/python/llm/src/bigdl/llm/transformers/speculative.py @@ -168,8 +168,10 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values, if not _enable_ipex: len0 = past_key_values[0][0].size(0) len1 = past_key_values[0][0].size(1) - len2 = past_key_values[0][0].size(2) - len3 = past_key_values[0][0].size(3) + # gpt_bigcode has only 2-dimension kv + if len(past_key_values[0][0].shape) == 4: + len2 = past_key_values[0][0].size(2) + len3 = past_key_values[0][0].size(3) for i in range(len(past_key_values)): if self.config.model_type == "qwen": k0 = torch.ones(len0, len2, len1 + max_new_tokens, len3, @@ -195,6 +197,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values, torch.float32) past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to( torch.float32) + elif self.config.model_type == "gpt_bigcode": + kv = torch.ones(len0 + max_new_tokens, len1, + dtype=torch.float32) + past_key_values_storage.append(kv[None, :, :]) + past_key_values_storage[i][0][:len0, :] = past_key_values[i][0].to( + torch.float32) else: k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3, dtype=torch.float32) @@ -266,6 +274,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values, k0 = past_key_values_storage[i][0][:len0, :, :, :] v0 = past_key_values_storage[i][1][:len0, :, :, :] tmp_past_key_values.append((k0, v0)) + elif self.config.model_type == "gpt_bigcode": + len0 = past_key_values[0][0].size(0) + kv = past_key_values_storage[i][0][:len0, :] + tmp_past_key_values.append(kv[None, :, :]) else: len2 = past_key_values[0][0].size(2) k0 = past_key_values_storage[i][0][:, :, :len2, :] @@ -292,6 +304,12 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s past_key_values[i][0][size:size1, :, :, :].to(torch.float32) past_key_values_storage[i][1][size:size1, :, :, :] = \ past_key_values[i][1][size:size1, :, :, :].to(torch.float32) + elif self.config.model_type == "gpt_bigcode": + size = original_draft_past_key_values[i][0].size(0) + size1 = past_key_values[i][0].size(0) + if size < size1: + 
past_key_values_storage[i][0][size:size1, :] = \ + past_key_values[i][0][size:size1, :].to(torch.float32) else: size = original_draft_past_key_values[i][0].size(2) size1 = past_key_values[i][0].size(2) @@ -801,6 +819,11 @@ def speculative_generate(self, v[:, :, :-(max_of_max_matched - max_matched), :]) for k, v in past_key_values ] + elif self.config.model_type == "gpt_bigcode": + past_key_values = [ + kv[:, :-(max_of_max_matched - max_matched)] + for kv in past_key_values + ] else: past_key_values = [ (k[:, :, :-(max_of_max_matched - max_matched)],