LLM: Support optimized kv_cache for baichuan family (#8997)
* add initial support for baichuan attantion * support baichuan1 * update based on comment * update based on comment * support baichuan2 * update link, change how to jusge baichuan2 * fix style * add model parameter for pob emb * update based on comment
This commit is contained in:
		
							parent
							
								
									37bb0cbf8f
								
							
						
					
					
						commit
						004c45c2be
					
				
					 7 changed files with 317 additions and 25 deletions
				
			
		| 
						 | 
					@ -173,4 +173,24 @@ def optimize(model):
 | 
				
			||||||
                        chatglm_attention_forward
 | 
					                        chatglm_attention_forward
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif model.config.model_type == "baichuan" and model.config.vocab_size == 125696:
 | 
				
			||||||
 | 
					        # baichuan2
 | 
				
			||||||
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					        from bigdl.llm.transformers.models.baichuan2 import baichuan_attention_forward
 | 
				
			||||||
 | 
					        convert_forward(model,
 | 
				
			||||||
 | 
					                        module.Attention,
 | 
				
			||||||
 | 
					                        baichuan_attention_forward
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif model.config.model_type == "baichuan":
 | 
				
			||||||
 | 
					        # baichuan1
 | 
				
			||||||
 | 
					        modeling_module_name = model.__class__.__module__
 | 
				
			||||||
 | 
					        module = importlib.import_module(modeling_module_name)
 | 
				
			||||||
 | 
					        from bigdl.llm.transformers.models.baichuan import baichuan_attention_forward
 | 
				
			||||||
 | 
					        convert_forward(model,
 | 
				
			||||||
 | 
					                        module.Attention,
 | 
				
			||||||
 | 
					                        baichuan_attention_forward
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return model
 | 
					    return model
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
							
								
								
									
										135
									
								
								python/llm/src/bigdl/llm/transformers/models/baichuan.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								python/llm/src/bigdl/llm/transformers/models/baichuan.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,135 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# This file is adapted from
 | 
				
			||||||
 | 
					# https://huggingface.co/baichuan-inc/Baichuan-7B/blob/c1a5c7d5b7f50ecc51bb0e08150a9f12e5656756/modeling_baichuan.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					from typing import List, Optional, Tuple, Union
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.utils.checkpoint
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 | 
				
			||||||
 | 
					from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 | 
				
			||||||
 | 
					from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def baichuan_attention_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
				
			||||||
 | 
					    output_attentions: bool = False,
 | 
				
			||||||
 | 
					    use_cache: bool = False,
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
				
			||||||
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    proj = self.W_pack(hidden_states)
 | 
				
			||||||
 | 
					    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
 | 
				
			||||||
 | 
					    # batch_size x source_len x hidden_size
 | 
				
			||||||
 | 
					    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					    # batch_size x target_len x head_size
 | 
				
			||||||
 | 
					    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					    # batch_size x source_len x hidden_size
 | 
				
			||||||
 | 
					    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kv_seq_len = key_states.shape[-2]
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        kv_seq_len += past_key_value[0].shape[-2]
 | 
				
			||||||
 | 
					    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
				
			||||||
 | 
					    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
				
			||||||
 | 
					                                                    cos, sin, position_ids, "baichuan")
 | 
				
			||||||
 | 
					    # [bsz, nh, t, hd]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # if past_key_value is not None:
 | 
				
			||||||
 | 
					    #     # reuse k, v, self_attention
 | 
				
			||||||
 | 
					    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
 | 
				
			||||||
 | 
					    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        # reuse k, v, self_attention
 | 
				
			||||||
 | 
					        cache_k = past_key_value[0]
 | 
				
			||||||
 | 
					        cache_v = past_key_value[1]
 | 
				
			||||||
 | 
					        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
				
			||||||
 | 
					            # allocate new
 | 
				
			||||||
 | 
					            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
				
			||||||
 | 
					                                                       self.num_heads,
 | 
				
			||||||
 | 
					                                                       self.head_dim,
 | 
				
			||||||
 | 
					                                                       cache_k.size(2),
 | 
				
			||||||
 | 
					                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
				
			||||||
 | 
					                                                       dtype=cache_k.dtype,
 | 
				
			||||||
 | 
					                                                       device=device)
 | 
				
			||||||
 | 
					            new_cache_k[:] = cache_k
 | 
				
			||||||
 | 
					            new_cache_v[:] = cache_v
 | 
				
			||||||
 | 
					            cache_k = new_cache_k
 | 
				
			||||||
 | 
					            cache_v = new_cache_v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif use_cache:
 | 
				
			||||||
 | 
					        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
				
			||||||
 | 
					        new_key_states, new_value_states = create_kv_cache(bsz,
 | 
				
			||||||
 | 
					                                                           self.num_heads,
 | 
				
			||||||
 | 
					                                                           self.head_dim,
 | 
				
			||||||
 | 
					                                                           kv_seq_len,
 | 
				
			||||||
 | 
					                                                           max_cache_length,
 | 
				
			||||||
 | 
					                                                           dtype=key_states.dtype,
 | 
				
			||||||
 | 
					                                                           device=device)
 | 
				
			||||||
 | 
					        new_key_states[:] = key_states
 | 
				
			||||||
 | 
					        new_value_states[:] = value_states
 | 
				
			||||||
 | 
					        key_states = new_key_states
 | 
				
			||||||
 | 
					        value_states = new_value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    past_key_value = (key_states, value_states) if use_cache else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 | 
				
			||||||
 | 
					        invalidInputError(False,
 | 
				
			||||||
 | 
					                          f"Attention weights should be of size "
 | 
				
			||||||
 | 
					                          f"{(bsz, self.num_heads, q_len, kv_seq_len)}"
 | 
				
			||||||
 | 
					                          f", but is {attn_weights.size()}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if attention_mask is not None:
 | 
				
			||||||
 | 
					        invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
 | 
				
			||||||
 | 
					                          f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
 | 
				
			||||||
 | 
					                          f"but is {attention_mask.size()}")
 | 
				
			||||||
 | 
					        attn_weights = attn_weights + attention_mask
 | 
				
			||||||
 | 
					        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # upcast attention to fp32
 | 
				
			||||||
 | 
					    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
 | 
				
			||||||
 | 
					                                         dtype=torch.float32).to(query_states.dtype)
 | 
				
			||||||
 | 
					    attn_output = torch.matmul(attn_weights, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
 | 
				
			||||||
 | 
					                      f"`attn_output` should be of size "
 | 
				
			||||||
 | 
					                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
 | 
				
			||||||
 | 
					                      f"but is {attn_output.size()}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = attn_output.transpose(1, 2)
 | 
				
			||||||
 | 
					    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    attn_output = self.o_proj(attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not output_attentions:
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return attn_output, attn_weights, past_key_value
 | 
				
			||||||
							
								
								
									
										135
									
								
								python/llm/src/bigdl/llm/transformers/models/baichuan2.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								python/llm/src/bigdl/llm/transformers/models/baichuan2.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,135 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# This file is adapted from
 | 
				
			||||||
 | 
					# https://huggingface.co/baichuan-inc/Baichuan2-7B-Base/blob/cb7fc748b78b7ea99772e4cf76db155729ce774e/modeling_baichuan.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import math
 | 
				
			||||||
 | 
					from typing import List, Optional, Tuple, Union
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torch.utils.checkpoint
 | 
				
			||||||
 | 
					from torch import nn
 | 
				
			||||||
 | 
					from torch.nn import functional as F
 | 
				
			||||||
 | 
					from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 | 
				
			||||||
 | 
					from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 | 
				
			||||||
 | 
					from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 | 
				
			||||||
 | 
					from transformers.utils import logging, ContextManagers
 | 
				
			||||||
 | 
					logger = logging.get_logger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					try:
 | 
				
			||||||
 | 
					    from xformers import ops as xops
 | 
				
			||||||
 | 
					except ImportError:
 | 
				
			||||||
 | 
					    xops = None
 | 
				
			||||||
 | 
					    logger.warning(
 | 
				
			||||||
 | 
					        "Xformers is not installed correctly. If you want to use memory_efficient_attention to "
 | 
				
			||||||
 | 
					        "accelerate training use the following command to install Xformers\npip install xformers."
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def baichuan_attention_forward(
 | 
				
			||||||
 | 
					    self,
 | 
				
			||||||
 | 
					    hidden_states: torch.Tensor,
 | 
				
			||||||
 | 
					    attention_mask: Optional[torch.Tensor] = None,
 | 
				
			||||||
 | 
					    position_ids: Optional[torch.LongTensor] = None,
 | 
				
			||||||
 | 
					    past_key_value: Optional[Tuple[torch.Tensor]] = None,
 | 
				
			||||||
 | 
					    output_attentions: bool = False,
 | 
				
			||||||
 | 
					    use_cache: bool = False,
 | 
				
			||||||
 | 
					) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 | 
				
			||||||
 | 
					    bsz, q_len, _ = hidden_states.size()
 | 
				
			||||||
 | 
					    device = hidden_states.device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    proj = self.W_pack(hidden_states)
 | 
				
			||||||
 | 
					    proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
 | 
				
			||||||
 | 
					    # batch_size x source_len x hidden_size
 | 
				
			||||||
 | 
					    query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					    # batch_size x target_len x head_size
 | 
				
			||||||
 | 
					    key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					    # batch_size x source_len x hidden_size
 | 
				
			||||||
 | 
					    value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    kv_seq_len = key_states.shape[-2]
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        kv_seq_len += past_key_value[0].shape[-2]
 | 
				
			||||||
 | 
					    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
				
			||||||
 | 
					    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
				
			||||||
 | 
					                                                    cos, sin, position_ids, "baichuan")
 | 
				
			||||||
 | 
					    # [bsz, nh, t, hd]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # if past_key_value is not None:
 | 
				
			||||||
 | 
					    #     # reuse k, v, self_attention
 | 
				
			||||||
 | 
					    #     key_states = torch.cat([past_key_value[0], key_states], dim=2)
 | 
				
			||||||
 | 
					    #     value_states = torch.cat([past_key_value[1], value_states], dim=2)
 | 
				
			||||||
 | 
					    if past_key_value is not None:
 | 
				
			||||||
 | 
					        # reuse k, v, self_attention
 | 
				
			||||||
 | 
					        cache_k = past_key_value[0]
 | 
				
			||||||
 | 
					        cache_v = past_key_value[1]
 | 
				
			||||||
 | 
					        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
				
			||||||
 | 
					            # allocate new
 | 
				
			||||||
 | 
					            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
				
			||||||
 | 
					                                                       self.num_heads,
 | 
				
			||||||
 | 
					                                                       self.head_dim,
 | 
				
			||||||
 | 
					                                                       cache_k.size(2),
 | 
				
			||||||
 | 
					                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
 | 
				
			||||||
 | 
					                                                       dtype=cache_k.dtype,
 | 
				
			||||||
 | 
					                                                       device=device)
 | 
				
			||||||
 | 
					            new_cache_k[:] = cache_k
 | 
				
			||||||
 | 
					            new_cache_v[:] = cache_v
 | 
				
			||||||
 | 
					            cache_k = new_cache_k
 | 
				
			||||||
 | 
					            cache_v = new_cache_v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        key_states, value_states = append_kv_cache(cache_k, cache_v, key_states, value_states)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    elif use_cache:
 | 
				
			||||||
 | 
					        max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
				
			||||||
 | 
					        new_key_states, new_value_states = create_kv_cache(bsz,
 | 
				
			||||||
 | 
					                                                           self.num_heads,
 | 
				
			||||||
 | 
					                                                           self.head_dim,
 | 
				
			||||||
 | 
					                                                           kv_seq_len,
 | 
				
			||||||
 | 
					                                                           max_cache_length,
 | 
				
			||||||
 | 
					                                                           dtype=key_states.dtype,
 | 
				
			||||||
 | 
					                                                           device=device)
 | 
				
			||||||
 | 
					        new_key_states[:] = key_states
 | 
				
			||||||
 | 
					        new_value_states[:] = value_states
 | 
				
			||||||
 | 
					        key_states = new_key_states
 | 
				
			||||||
 | 
					        value_states = new_value_states
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    past_key_value = (key_states, value_states) if use_cache else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if xops is not None and self.training:
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					        query_states = query_states.transpose(1, 2)
 | 
				
			||||||
 | 
					        key_states = key_states.transpose(1, 2)
 | 
				
			||||||
 | 
					        value_states = value_states.transpose(1, 2)
 | 
				
			||||||
 | 
					        attn_output = xops.memory_efficient_attention(
 | 
				
			||||||
 | 
					            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
 | 
				
			||||||
 | 
					                                            enable_mem_efficient=True):
 | 
				
			||||||
 | 
					            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
 | 
				
			||||||
 | 
					                                                         attn_mask=attention_mask)
 | 
				
			||||||
 | 
					        attn_output = attn_output.transpose(1, 2)
 | 
				
			||||||
 | 
					    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 | 
				
			||||||
 | 
					    attn_output = self.o_proj(attn_output)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not output_attentions:
 | 
				
			||||||
 | 
					        attn_weights = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return attn_output, attn_weights, past_key_value
 | 
				
			||||||
| 
						 | 
					@ -67,8 +67,6 @@ def attention_fn(
 | 
				
			||||||
        cache_v = cache_v.permute(1, 2, 0, 3)
 | 
					        cache_v = cache_v.permute(1, 2, 0, 3)
 | 
				
			||||||
        past_length = cache_k.size(2)
 | 
					        past_length = cache_k.size(2)
 | 
				
			||||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
					        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
				
			||||||
            if device.type == 'xpu':
 | 
					 | 
				
			||||||
                torch.xpu.empty_cache()
 | 
					 | 
				
			||||||
            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
					            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
				
			||||||
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
 | 
					            new_cache_k, new_cache_v = create_kv_cache(batch_size,
 | 
				
			||||||
                                                       self.num_attention_heads_per_partition,
 | 
					                                                       self.num_attention_heads_per_partition,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -151,8 +151,6 @@ def chatglm2_attention_forward_8eb45c(
 | 
				
			||||||
        past_length = cache_k.size(2)
 | 
					        past_length = cache_k.size(2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
					        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
				
			||||||
            if device.type == 'xpu':
 | 
					 | 
				
			||||||
                torch.xpu.empty_cache()
 | 
					 | 
				
			||||||
            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
					            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
				
			||||||
            new_cache_k, new_cache_v = create_kv_cache(batch_size,
 | 
					            new_cache_k, new_cache_v = create_kv_cache(batch_size,
 | 
				
			||||||
                                                       self.num_attention_heads_per_partition,
 | 
					                                                       self.num_attention_heads_per_partition,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -38,24 +38,7 @@ import math
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
from bigdl.llm.utils.common import invalidInputError
 | 
					from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 | 
					from bigdl.llm.transformers.models.utils import create_kv_cache, append_kv_cache
 | 
				
			||||||
 | 
					from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 | 
				
			||||||
 | 
					 | 
				
			||||||
def rotate_half(x):
 | 
					 | 
				
			||||||
    """Rotates half the hidden dims of the input."""
 | 
					 | 
				
			||||||
    x1 = x[..., :x.shape[-1] // 2]
 | 
					 | 
				
			||||||
    x2 = x[..., x.shape[-1] // 2:]
 | 
					 | 
				
			||||||
    return torch.cat((-x2, x1), dim=-1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
 | 
					 | 
				
			||||||
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
 | 
					 | 
				
			||||||
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
					 | 
				
			||||||
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
					 | 
				
			||||||
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
					 | 
				
			||||||
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
					 | 
				
			||||||
    q_embed = (q * cos) + (rotate_half(q) * sin)
 | 
					 | 
				
			||||||
    k_embed = (k * cos) + (rotate_half(k) * sin)
 | 
					 | 
				
			||||||
    return q_embed, k_embed
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
					def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
				
			||||||
| 
						 | 
					@ -122,15 +105,13 @@ def llama_attention_forward_4_31(
 | 
				
			||||||
        kv_seq_len += past_key_value[0].shape[-2]
 | 
					        kv_seq_len += past_key_value[0].shape[-2]
 | 
				
			||||||
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
					    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
				
			||||||
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
					    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
				
			||||||
                                                    cos, sin, position_ids)
 | 
					                                                    cos, sin, position_ids, "llama")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if past_key_value is not None:
 | 
					    if past_key_value is not None:
 | 
				
			||||||
        # reuse k, v, self_attention
 | 
					        # reuse k, v, self_attention
 | 
				
			||||||
        cache_k = past_key_value[0]
 | 
					        cache_k = past_key_value[0]
 | 
				
			||||||
        cache_v = past_key_value[1]
 | 
					        cache_v = past_key_value[1]
 | 
				
			||||||
        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
					        if cache_k.stride()[1] <= cache_k.size(2) * cache_k.size(3):
 | 
				
			||||||
            if device.type == 'xpu':
 | 
					 | 
				
			||||||
                torch.xpu.empty_cache()
 | 
					 | 
				
			||||||
            # allocate new
 | 
					            # allocate new
 | 
				
			||||||
            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
					            new_cache_k, new_cache_v = create_kv_cache(bsz,
 | 
				
			||||||
                                                       self.num_key_value_heads,  # Support GQA
 | 
					                                                       self.num_key_value_heads,  # Support GQA
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,9 +15,12 @@
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from bigdl.llm.utils.common import invalidInputError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
 | 
					def create_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
 | 
				
			||||||
 | 
					    if device.type == 'xpu':
 | 
				
			||||||
 | 
					        torch.xpu.empty_cache()
 | 
				
			||||||
    key_cache_storage = torch.empty(batch_size, num_heads,
 | 
					    key_cache_storage = torch.empty(batch_size, num_heads,
 | 
				
			||||||
                                    max_length, head_dim,
 | 
					                                    max_length, head_dim,
 | 
				
			||||||
                                    dtype=dtype, device=device)
 | 
					                                    dtype=dtype, device=device)
 | 
				
			||||||
| 
						 | 
					@ -46,3 +49,25 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 | 
				
			||||||
    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
 | 
					    new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0)
 | 
				
			||||||
    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
 | 
					    new_cache_v[:, :, cache_v.size(2):cache_k.size(2) + key_states.size(2), :] = value_states
 | 
				
			||||||
    return new_cache_k, new_cache_v
 | 
					    return new_cache_k, new_cache_v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def rotate_half(x):
 | 
				
			||||||
 | 
					    """Rotates half the hidden dims of the input."""
 | 
				
			||||||
 | 
					    x1 = x[..., :x.shape[-1] // 2]
 | 
				
			||||||
 | 
					    x2 = x[..., x.shape[-1] // 2:]
 | 
				
			||||||
 | 
					    return torch.cat((-x2, x1), dim=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
 | 
				
			||||||
 | 
					    if model_family in ["llama", "baichuan"]:
 | 
				
			||||||
 | 
					        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
 | 
				
			||||||
 | 
					        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
				
			||||||
 | 
					        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
				
			||||||
 | 
					        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
				
			||||||
 | 
					        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
 | 
				
			||||||
 | 
					        q_embed = (q * cos) + (rotate_half(q) * sin)
 | 
				
			||||||
 | 
					        k_embed = (k * cos) + (rotate_half(k) * sin)
 | 
				
			||||||
 | 
					        return q_embed, k_embed
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        invalidInputError(False,
 | 
				
			||||||
 | 
					                          f"{model_family} is not supported.")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue