Speculative Starcoder on CPU (#10138)
* Speculative Starcoder on CPU * enable kv-cache pre-allocation * refine codes * refine * fix style * fix style * fix style * refine * refine * Update speculative.py * Update gptbigcode.py * fix style * Update speculative.py * enable mixed-datatype layernorm on top of torch API * adaptive dtype * Update README.md
This commit is contained in:
		
							parent
							
								
									a47989c860
								
							
						
					
					
						commit
						36a9e88104
					
				
					 7 changed files with 304 additions and 5 deletions
				
			
		| 
						 | 
					@ -1,5 +1,5 @@
 | 
				
			||||||
# Mistral
 | 
					# Mistral
 | 
				
			||||||
In this directory, you will find examples on how you could run Baichuan2 BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
 | 
					In this directory, you will find examples on how you could run Mistral BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) as reference Mistral models.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## 0. Requirements
 | 
					## 0. Requirements
 | 
				
			||||||
To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
 | 
					To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,97 @@
 | 
				
			||||||
 | 
					# Starcoder
 | 
				
			||||||
 | 
					In this directory, you will find examples on how you could run Starcoder BF16 inference with self-speculative decoding using BigDL-LLM on [Intel CPUs](../README.md). For illustration purposes,we utilize the [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) and [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py) as reference Starcoder models.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 0. Requirements
 | 
				
			||||||
 | 
					To run these examples with BigDL-LLM on Intel CPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Example: Predict Tokens using `generate()` API
 | 
				
			||||||
 | 
					In the example [speculative.py](./speculative.py), we show a basic use case for a Starcoder model to predict the next N tokens using `generate()` API, with BigDL-LLM speculative decoding optimizations on Intel CPUs.
 | 
				
			||||||
 | 
					### 1. Install
 | 
				
			||||||
 | 
					We suggest using conda to manage environment:
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					conda create -n llm python=3.9
 | 
				
			||||||
 | 
					conda activate llm
 | 
				
			||||||
 | 
					pip install --pre --upgrade bigdl-llm[all]
 | 
				
			||||||
 | 
					pip install intel_extension_for_pytorch==2.1.0
 | 
				
			||||||
 | 
					pip install transformers==4.31.0
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					### 2. Configures high-performing processor environment variables
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					source bigdl-llm-init -t
 | 
				
			||||||
 | 
					export OMP_NUM_THREADS=48 # you can change 48 here to #cores of one processor socket
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					### 3. Run
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					We recommend to use `numactl` to bind the program to a specified processor socket:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					numactl -C 0-47 -m 0 python ./speculative.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --n-predict N_PREDICT
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					For example, 0-47 means bind the python program to core list 0-47 for a 48-core socket.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Arguments info:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the huggingface repo id for the Starcoder model (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded, or the path to the huggingface checkpoint folder. It is default to be `bigcode/starcoder`.
 | 
				
			||||||
 | 
					- `--prompt PROMPT`: argument defining the prompt to be infered (with integrated prompt format for chat). A default prompt is provided.
 | 
				
			||||||
 | 
					- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It is default to be `128`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### Sample Output
 | 
				
			||||||
 | 
					#### [bigcode/starcoder](https://huggingface.co/bigcode/starcoder)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```log
 | 
				
			||||||
 | 
					def dfs_print_Fibonacci_sequence(n):
 | 
				
			||||||
 | 
					    if n == 0:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    elif n == 1:
 | 
				
			||||||
 | 
					        print(0)
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    elif n == 2:
 | 
				
			||||||
 | 
					        print(0)
 | 
				
			||||||
 | 
					        print(1)
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        print(0)
 | 
				
			||||||
 | 
					        print(1)
 | 
				
			||||||
 | 
					        dfs_print_Fibonacci_sequence(n-2)
 | 
				
			||||||
 | 
					        print(dfs_Fibonacci_sequence(n-1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def dfs_Fibonacci_sequence(n):
 | 
				
			||||||
 | 
					    if n == 0:
 | 
				
			||||||
 | 
					        return 0
 | 
				
			||||||
 | 
					    elif n == 1:
 | 
				
			||||||
 | 
					        return 1
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return dfs_Fibonacci_sequence
 | 
				
			||||||
 | 
					Tokens generated 128
 | 
				
			||||||
 | 
					E2E Generation time xx.xxxxs
 | 
				
			||||||
 | 
					First token latency xx.xxxxs
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#### [bigcode/tiny_starcoder_py](https://huggingface.co/bigcode/tiny_starcoder_py)
 | 
				
			||||||
 | 
					```log
 | 
				
			||||||
 | 
					def dfs_print_Fibonacci_sequence(n):
 | 
				
			||||||
 | 
					    if n == 0:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    print(n)
 | 
				
			||||||
 | 
					    for i in range(2, n):
 | 
				
			||||||
 | 
					        print(dfs_print_Fibonacci_sequence(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def dfs_print_Fibonacci_sequence_2(n):
 | 
				
			||||||
 | 
					    if n == 0:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    print(n)
 | 
				
			||||||
 | 
					    for i in range(2, n):
 | 
				
			||||||
 | 
					        print(dfs_print_Fibonacci_sequence_2(i))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def dfs_print_Fibonacci_sequence_3(n):
 | 
				
			||||||
 | 
					    if n == 0:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					    print(n)
 | 
				
			||||||
 | 
					    for i in
 | 
				
			||||||
 | 
					Tokens generated 128
 | 
				
			||||||
 | 
					E2E Generation time xx.xxxxs
 | 
				
			||||||
 | 
					First token latency xx.xxxxs
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,87 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from bigdl.llm.transformers import AutoModelForCausalLM
 | 
				
			||||||
 | 
					from transformers import AutoTokenizer
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					torch.nn.Linear.reset_parameters = lambda x: None
 | 
				
			||||||
 | 
					seed=42
 | 
				
			||||||
 | 
					torch.manual_seed(seed)
 | 
				
			||||||
 | 
					np.random.seed(seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					STARCODER_PROMPT_FORMAT = "{prompt}"
 | 
				
			||||||
 | 
					prompt = "def dfs_print_Fibonacci_sequence(n):"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Mistral model')
 | 
				
			||||||
 | 
					    parser.add_argument('--repo-id-or-model-path', type=str, default="bigcode/starcoder",
 | 
				
			||||||
 | 
					                        help='The huggingface repo id for the Mistral (e.g. `bigcode/starcoder` and `bigcode/tiny_starcoder_py`) to be downloaded'
 | 
				
			||||||
 | 
					                             ', or the path to the huggingface checkpoint folder')
 | 
				
			||||||
 | 
					    parser.add_argument('--prompt', type=str, default=prompt,
 | 
				
			||||||
 | 
					                        help='Prompt to infer')
 | 
				
			||||||
 | 
					    parser.add_argument('--n-predict', type=int, default=128,
 | 
				
			||||||
 | 
					                        help='Max tokens to predict')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Load model in optimized bf16 here.
 | 
				
			||||||
 | 
					    # Set `speculative=True`` to enable speculative decoding,
 | 
				
			||||||
 | 
					    # it only works when load_in_low_bit="fp16" on Intel GPU or load_in_low_bit="bf16" on latest Intel Xeon CPU
 | 
				
			||||||
 | 
					    model = AutoModelForCausalLM.from_pretrained(model_path,
 | 
				
			||||||
 | 
					                                                 optimize_model=True,
 | 
				
			||||||
 | 
					                                                 torch_dtype=torch.bfloat16,
 | 
				
			||||||
 | 
					                                                 load_in_low_bit="bf16",
 | 
				
			||||||
 | 
					                                                 speculative=True,
 | 
				
			||||||
 | 
					                                                 torchscript=True,
 | 
				
			||||||
 | 
					                                                 trust_remote_code=True,
 | 
				
			||||||
 | 
					                                                 use_cache=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tokenizer = AutoTokenizer.from_pretrained(model_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with torch.inference_mode():
 | 
				
			||||||
 | 
					        prompt = STARCODER_PROMPT_FORMAT.format(prompt=args.prompt)
 | 
				
			||||||
 | 
					        inputs = tokenizer(prompt, return_tensors='pt')
 | 
				
			||||||
 | 
					        input_ids = inputs.input_ids.to(model.device)
 | 
				
			||||||
 | 
					        actual_in_len = input_ids.shape[1]
 | 
				
			||||||
 | 
					        print("actual input_ids length:" + str(actual_in_len))
 | 
				
			||||||
 | 
					        attention_mask = inputs.attention_mask.to(model.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # warmup
 | 
				
			||||||
 | 
					        output = model.generate(input_ids,
 | 
				
			||||||
 | 
					                                max_new_tokens=args.n_predict,
 | 
				
			||||||
 | 
					                                attention_mask=attention_mask,
 | 
				
			||||||
 | 
					                                do_sample=False)
 | 
				
			||||||
 | 
					        output_str = tokenizer.decode(output[0])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # speculative decoding
 | 
				
			||||||
 | 
					        st = time.perf_counter()
 | 
				
			||||||
 | 
					        output = model.generate(input_ids,
 | 
				
			||||||
 | 
					                                max_new_tokens=args.n_predict,
 | 
				
			||||||
 | 
					                                attention_mask=attention_mask,
 | 
				
			||||||
 | 
					                                do_sample=False)
 | 
				
			||||||
 | 
					        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
				
			||||||
 | 
					        end = time.perf_counter()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        print(output_str)
 | 
				
			||||||
 | 
					        print(f"Tokens generated {model.n_token_generated}")
 | 
				
			||||||
 | 
					        print(f"E2E Generation time {(end - st):.4f}s")
 | 
				
			||||||
 | 
					        print(f"First token latency {model.first_token_time:.4f}s")
 | 
				
			||||||
| 
						 | 
					@ -1173,6 +1173,10 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
				
			||||||
        modeling_module_name = model.__class__.__module__
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
        module = importlib.import_module(modeling_module_name)
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
        from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
 | 
					        from bigdl.llm.transformers.models.gptbigcode import _attn_wrapper
 | 
				
			||||||
 | 
					        from bigdl.llm.transformers.models.gptbigcode import gptbigcode_attention_forward
 | 
				
			||||||
 | 
					        convert_forward(model,
 | 
				
			||||||
 | 
					                        module.GPTBigCodeAttention,
 | 
				
			||||||
 | 
					                        gptbigcode_attention_forward)
 | 
				
			||||||
        _attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
 | 
					        _attn = _attn_wrapper(module.GPTBigCodeAttention._attn)
 | 
				
			||||||
        replace_func(model,
 | 
					        replace_func(model,
 | 
				
			||||||
                     module.GPTBigCodeAttention,
 | 
					                     module.GPTBigCodeAttention,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -74,7 +74,10 @@ def bloom_layer_norm_forward(self, hidden_states):
 | 
				
			||||||
        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
					        # if nelement == 0, means fused norm failed, go back to python implement.
 | 
				
			||||||
        if result.nelement != 0:
 | 
					        if result.nelement != 0:
 | 
				
			||||||
            return result
 | 
					            return result
 | 
				
			||||||
    return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
 | 
					    input_dtype = hidden_states.dtype
 | 
				
			||||||
 | 
					    result = F.layer_norm(hidden_states.to(self.weight.dtype),
 | 
				
			||||||
 | 
					                          self.normalized_shape, self.weight, self.bias, self.eps)
 | 
				
			||||||
 | 
					    return result.to(input_dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def bloom_attention_forward(
 | 
					def bloom_attention_forward(
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,10 +15,14 @@
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional, Tuple, Union
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _attn_wrapper(origin_attn):
 | 
					def _attn_wrapper(origin_attn):
 | 
				
			||||||
    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
 | 
					    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
 | 
				
			||||||
        attn_output, attn_weights = origin_attn(self,
 | 
					        attn_output, attn_weights = origin_attn(self,
 | 
				
			||||||
                                                query=query,
 | 
					                                                query=query.to(key.dtype),
 | 
				
			||||||
                                                key=key,
 | 
					                                                key=key,
 | 
				
			||||||
                                                value=value,
 | 
					                                                value=value,
 | 
				
			||||||
                                                attention_mask=attention_mask,
 | 
					                                                attention_mask=attention_mask,
 | 
				
			||||||
| 
						 | 
					@ -27,3 +31,84 @@ def _attn_wrapper(origin_attn):
 | 
				
			||||||
            attn_output = attn_output.clone()
 | 
					            attn_output = attn_output.clone()
 | 
				
			||||||
        return attn_output, attn_weights
 | 
					        return attn_output, attn_weights
 | 
				
			||||||
    return _attn
 | 
					    return _attn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def gptbigcode_attention_forward(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					        layer_past: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					        attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					        head_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					        encoder_hidden_states: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					        encoder_attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					        use_cache: Optional[bool] = False,
 | 
				
			||||||
 | 
					        output_attentions: Optional[bool] = False,
 | 
				
			||||||
 | 
					        **kwargs):
 | 
				
			||||||
 | 
					        if "padding_mask" in kwargs:
 | 
				
			||||||
 | 
					            logger.warning_once(
 | 
				
			||||||
 | 
					                "Passing `padding_mask` is deprecated and will be removed in v4.37." +
 | 
				
			||||||
 | 
					                "Please make sure use `attention_mask` instead.`"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if encoder_hidden_states is not None:
 | 
				
			||||||
 | 
					            if not hasattr(self, "q_attn") or not self.is_cross_attention:
 | 
				
			||||||
 | 
					                from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					                invalidInputError(
 | 
				
			||||||
 | 
					                    False,
 | 
				
			||||||
 | 
					                    "If class is used as cross attention," +
 | 
				
			||||||
 | 
					                    "the weights `q_attn` have to be defined. " +
 | 
				
			||||||
 | 
					                    "Please make sure to instantiate class with " +
 | 
				
			||||||
 | 
					                    "`GPTBigCodeAttention(..., is_cross_attention=True)`."
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            query = self.q_attn(hidden_states)
 | 
				
			||||||
 | 
					            key_value = self.c_attn(encoder_hidden_states)
 | 
				
			||||||
 | 
					            attention_mask = encoder_attention_mask
 | 
				
			||||||
 | 
					        elif self.multi_query:
 | 
				
			||||||
 | 
					            query, key_value = self.c_attn(hidden_states).split(
 | 
				
			||||||
 | 
					                (self.embed_dim, 2 * self.kv_dim), dim=2)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # Note: We split as (self.num_heads, 3, self.head_dim)
 | 
				
			||||||
 | 
					            # instead of (3, self.num_heads, self.head_dim),
 | 
				
			||||||
 | 
					            # i.e., the memory layout is not the same as GPT2.
 | 
				
			||||||
 | 
					            # This makes the concatenation with past_key_value more efficient.
 | 
				
			||||||
 | 
					            query, key_value = (
 | 
				
			||||||
 | 
					                self.c_attn(hidden_states)
 | 
				
			||||||
 | 
					                .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
 | 
				
			||||||
 | 
					                .transpose(1, 2)
 | 
				
			||||||
 | 
					                .split((self.head_dim, 2 * self.head_dim), dim=3)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if layer_past is not None:
 | 
				
			||||||
 | 
					                if layer_past.shape[-2] == key_value.shape[-2]:
 | 
				
			||||||
 | 
					                    key_value = torch.cat((layer_past, key_value), dim=-2)
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    fill_zeros = torch.zeros(layer_past.shape[0],
 | 
				
			||||||
 | 
					                                             layer_past.shape[1],
 | 
				
			||||||
 | 
					                                             key_value.shape[2] - layer_past.shape[2],
 | 
				
			||||||
 | 
					                                             dtype=layer_past.dtype,
 | 
				
			||||||
 | 
					                                             device=layer_past.device)
 | 
				
			||||||
 | 
					                    layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
 | 
				
			||||||
 | 
					                    key_value = torch.cat((layer_past, key_value), dim=-2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        present = key_value if use_cache else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_output, attn_weights = self._attn(query,
 | 
				
			||||||
 | 
					                                               key.transpose(-1, -2),
 | 
				
			||||||
 | 
					                                               value,
 | 
				
			||||||
 | 
					                                               attention_mask,
 | 
				
			||||||
 | 
					                                               head_mask)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not self.multi_query:
 | 
				
			||||||
 | 
					            attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
 | 
				
			||||||
 | 
					        attn_output = self.c_proj(attn_output)
 | 
				
			||||||
 | 
					        attn_output = self.resid_dropout(attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        outputs = (attn_output, present)
 | 
				
			||||||
 | 
					        if output_attentions:
 | 
				
			||||||
 | 
					            if self.multi_query:
 | 
				
			||||||
 | 
					                attn_weights = attn_weights.transpose(1, 2)
 | 
				
			||||||
 | 
					            outputs += (attn_weights,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return outputs
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -168,6 +168,8 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
 | 
				
			||||||
    if not _enable_ipex:
 | 
					    if not _enable_ipex:
 | 
				
			||||||
        len0 = past_key_values[0][0].size(0)
 | 
					        len0 = past_key_values[0][0].size(0)
 | 
				
			||||||
        len1 = past_key_values[0][0].size(1)
 | 
					        len1 = past_key_values[0][0].size(1)
 | 
				
			||||||
 | 
					        # gpt_bigcode has only 2-dimension kv
 | 
				
			||||||
 | 
					        if len(past_key_values[0][0].shape) == 4:
 | 
				
			||||||
            len2 = past_key_values[0][0].size(2)
 | 
					            len2 = past_key_values[0][0].size(2)
 | 
				
			||||||
            len3 = past_key_values[0][0].size(3)
 | 
					            len3 = past_key_values[0][0].size(3)
 | 
				
			||||||
        for i in range(len(past_key_values)):
 | 
					        for i in range(len(past_key_values)):
 | 
				
			||||||
| 
						 | 
					@ -195,6 +197,12 @@ def _prepare_past_key_values_storage_cpu(self, past_key_values,
 | 
				
			||||||
                    torch.float32)
 | 
					                    torch.float32)
 | 
				
			||||||
                past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
 | 
					                past_key_values_storage[i][1][:len0, :, :, :] = past_key_values[i][1].to(
 | 
				
			||||||
                    torch.float32)
 | 
					                    torch.float32)
 | 
				
			||||||
 | 
					            elif self.config.model_type == "gpt_bigcode":
 | 
				
			||||||
 | 
					                kv = torch.ones(len0 + max_new_tokens, len1,
 | 
				
			||||||
 | 
					                                dtype=torch.float32)
 | 
				
			||||||
 | 
					                past_key_values_storage.append(kv[None, :, :])
 | 
				
			||||||
 | 
					                past_key_values_storage[i][0][:len0, :] = past_key_values[i][0].to(
 | 
				
			||||||
 | 
					                    torch.float32)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
 | 
					                k0 = torch.ones(len0, len1, len2 + max_new_tokens, len3,
 | 
				
			||||||
                                dtype=torch.float32)
 | 
					                                dtype=torch.float32)
 | 
				
			||||||
| 
						 | 
					@ -266,6 +274,10 @@ def _prepare_draft_past_key_values_cpu(self, past_key_values,
 | 
				
			||||||
            k0 = past_key_values_storage[i][0][:len0, :, :, :]
 | 
					            k0 = past_key_values_storage[i][0][:len0, :, :, :]
 | 
				
			||||||
            v0 = past_key_values_storage[i][1][:len0, :, :, :]
 | 
					            v0 = past_key_values_storage[i][1][:len0, :, :, :]
 | 
				
			||||||
            tmp_past_key_values.append((k0, v0))
 | 
					            tmp_past_key_values.append((k0, v0))
 | 
				
			||||||
 | 
					        elif self.config.model_type == "gpt_bigcode":
 | 
				
			||||||
 | 
					            len0 = past_key_values[0][0].size(0)
 | 
				
			||||||
 | 
					            kv = past_key_values_storage[i][0][:len0, :]
 | 
				
			||||||
 | 
					            tmp_past_key_values.append(kv[None, :, :])
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            len2 = past_key_values[0][0].size(2)
 | 
					            len2 = past_key_values[0][0].size(2)
 | 
				
			||||||
            k0 = past_key_values_storage[i][0][:, :, :len2, :]
 | 
					            k0 = past_key_values_storage[i][0][:, :, :len2, :]
 | 
				
			||||||
| 
						 | 
					@ -292,6 +304,12 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
 | 
				
			||||||
                    past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
 | 
					                    past_key_values[i][0][size:size1, :, :, :].to(torch.float32)
 | 
				
			||||||
                past_key_values_storage[i][1][size:size1, :, :, :] = \
 | 
					                past_key_values_storage[i][1][size:size1, :, :, :] = \
 | 
				
			||||||
                    past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
 | 
					                    past_key_values[i][1][size:size1, :, :, :].to(torch.float32)
 | 
				
			||||||
 | 
					            elif self.config.model_type == "gpt_bigcode":
 | 
				
			||||||
 | 
					                size = original_draft_past_key_values[i][0].size(0)
 | 
				
			||||||
 | 
					                size1 = past_key_values[i][0].size(0)
 | 
				
			||||||
 | 
					                if size < size1:
 | 
				
			||||||
 | 
					                    past_key_values_storage[i][0][size:size1, :] = \
 | 
				
			||||||
 | 
					                        past_key_values[i][0][size:size1, :].to(torch.float32)
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                size = original_draft_past_key_values[i][0].size(2)
 | 
					                size = original_draft_past_key_values[i][0].size(2)
 | 
				
			||||||
                size1 = past_key_values[i][0].size(2)
 | 
					                size1 = past_key_values[i][0].size(2)
 | 
				
			||||||
| 
						 | 
					@ -801,6 +819,11 @@ def speculative_generate(self,
 | 
				
			||||||
                             v[:, :, :-(max_of_max_matched - max_matched), :])
 | 
					                             v[:, :, :-(max_of_max_matched - max_matched), :])
 | 
				
			||||||
                            for k, v in past_key_values
 | 
					                            for k, v in past_key_values
 | 
				
			||||||
                        ]
 | 
					                        ]
 | 
				
			||||||
 | 
					                    elif self.config.model_type == "gpt_bigcode":
 | 
				
			||||||
 | 
					                        past_key_values = [
 | 
				
			||||||
 | 
					                            kv[:, :-(max_of_max_matched - max_matched)]
 | 
				
			||||||
 | 
					                            for kv in past_key_values
 | 
				
			||||||
 | 
					                        ]
 | 
				
			||||||
                    else:
 | 
					                    else:
 | 
				
			||||||
                        past_key_values = [
 | 
					                        past_key_values = [
 | 
				
			||||||
                            (k[:, :, :-(max_of_max_matched - max_matched)],
 | 
					                            (k[:, :, :-(max_of_max_matched - max_matched)],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue