diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 26ce1936..5799b3f4 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -155,13 +155,22 @@ def optimize(model): if "chatglm2" in model.config._name_or_path: modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name) - from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward_8eb45c - from bigdl.llm.transformers.models.chatglm import core_attn_forward_8eb45c + from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c + from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c convert_forward(model, module.SelfAttention, - chatglm_attention_forward_8eb45c + chatglm2_attention_forward_8eb45c ) convert_forward(model, module.CoreAttention, core_attn_forward_8eb45c) + elif "chatglm" in model.config._name_or_path: + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward + convert_forward(model, + module.SelfAttention, + chatglm_attention_forward + ) + return model diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm.py b/python/llm/src/bigdl/llm/transformers/models/chatglm.py index d8a55f19..4503e0d2 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm.py @@ -14,198 +14,124 @@ # limitations under the License. # # This file is adapted from -# https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py +# https://huggingface.co/THUDM/chatglm-6b/blob/63ce1bac4a7a7da57c67448bab39ddbe0e115a19/configuration_chatglm.py # +import math import torch -from typing import Optional, Tuple, Union, List, Callable, Dict, Any +import torch.utils.checkpoint import torch.nn.functional as F +from typing import Optional, Tuple +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): + # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] + cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \ + F.embedding(position_id, sin.squeeze(1)).unsqueeze(2) + q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + return q, k + KV_CACHE_ALLOC_BLOCK_LENGTH = 256 KV_CACHE_ALLOC_MIN_LENGTH = 512 -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. 
- if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -def chatglm_attention_forward_8eb45c( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True +def attention_fn( + self, + query_layer, + key_layer, + value_layer, + attention_mask, + hidden_size_per_partition, + layer_id, + layer_past=None, + scaling_attention_score=True, + use_cache=False, ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, - self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + key_layer = key_layer.permute(1, 2, 0, 3).contiguous() + value_layer = value_layer.permute(1, 2, 0, 3).contiguous() + # query_layer = query_layer.permute(1, 2, 0, 3) cur_length, batch_size = query_layer.shape[0], query_layer.shape[1] - if self.multi_query_attention: - key_length = key_layer.size(0) - query_group_size = self.num_attention_heads_per_partition // \ - self.num_multi_query_groups_per_partition - key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn] - key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1) - key_layer = key_layer.contiguous().view((batch_size, - 
self.num_attention_heads_per_partition, - key_length, - self.hidden_size_per_attention_head)) - value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3) - value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1) - value_layer = value_layer.contiguous().view((batch_size, - self.num_attention_heads_per_partition, - key_length, - self.hidden_size_per_attention_head)) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - past_length = cache_k.size(2) - + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + past_length = past_key.size(2) if past_length + cur_length > self.max_cache_length: self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH self.kv_cache = (torch.empty(batch_size, - self.num_attention_heads_per_partition, + self.num_attention_heads, self.max_cache_length, self.hidden_size_per_attention_head,), torch.empty(batch_size, - self.num_attention_heads_per_partition, + self.num_attention_heads, self.max_cache_length, self.hidden_size_per_attention_head,)) - self.kv_cache[0][:, :, :past_length, :] = cache_k - self.kv_cache[1][:, :, :past_length, :] = cache_v + self.kv_cache[0][:, :, :past_length, :] = past_key + self.kv_cache[1][:, :, :past_length, :] = past_value + self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer - key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :] value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :] elif use_cache: self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \ + KV_CACHE_ALLOC_BLOCK_LENGTH - self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition, + self.kv_cache = (torch.empty(batch_size, self.num_attention_heads, self.max_cache_length, self.hidden_size_per_attention_head,), - torch.empty(batch_size, self.num_attention_heads_per_partition, + torch.empty(batch_size, self.num_attention_heads, self.max_cache_length, self.hidden_size_per_attention_head,)) self.kv_cache[0][:, :, :cur_length, :] = key_layer self.kv_cache[1][:, :, :cur_length, :] = value_layer + # seqlen, batch, num_attention_heads, hidden_size_per_attention_head + b, nh, seq_len, hidden_size = key_layer.shape + if use_cache: - kv_cache = (key_layer, value_layer) + present = (key_layer, value_layer) else: - kv_cache = None + present = None - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. 
[sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask): pytorch_major_version = int(torch.__version__.split('.')[0]) if query_layer.size(0) > 1 and pytorch_major_version >= 2: query_layer = query_layer.permute(1, 2, 0, 3) if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, - key_layer, - value_layer, - is_causal=True) + + if torch.is_autocast_cpu_enabled(): + attention_mask = torch.ones(query_layer.shape[2], + key_layer.shape[2], + dtype=torch.bool).tril(diagonal=0) + attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), ) + attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) + query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) + key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) + value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attention_mask, + is_causal=False) + else: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attention_mask, + is_causal=True) else: if attention_mask is not None: attention_mask = ~attention_mask + attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), ) + if torch.is_autocast_cpu_enabled(): + query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) + key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) + value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) + attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, @@ -213,8 +139,16 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.reshape(*new_context_layer_shape) + attention_probs = None + else: - # Raw attention scores + query_key_layer_scaling_coeff = float(layer_id + 1) + if scaling_attention_score: + query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== # [b, np, sq, sk] output_size = (query_layer.size(1), query_layer.size(2), @@ -225,47 +159,44 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - # preallocting input tensor: [b * np, sq, sk] matmul_input_buffer = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, device=query_layer.device ) - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], output_size[3], dtype=query_layer.dtype, + ) + + torch.baddbmm( matmul_input_buffer, query_layer.transpose(0, 1), # [b * np, sq, hn] key_layer.transpose(1, 2), # [b * np, hn, sk] beta=0.0, - alpha=(1.0 / self.norm_factor), - ) + alpha=1.0, + out=matmul_result) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: + if self.scale_mask_softmax: + self.scale_mask_softmax.scale = query_key_layer_scaling_coeff + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask.contiguous()) + else: + if not (attention_mask == 0).all(): + # if auto-regressive, skip + attention_scores.masked_fill_(attention_mask, -10000.0) + dtype = attention_scores.dtype attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) + attention_scores = attention_scores * query_key_layer_scaling_coeff + + attention_probs = F.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.type(dtype) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) # ========================= # Context layer. 
[sq, b, hp] # ========================= @@ -276,18 +207,97 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) + # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer) + context_layer = torch.empty( + output_size[0] * output_size[1], + output_size[2], value_layer.size(-1), dtype=value_layer.dtype,) + torch.bmm(attention_probs, value_layer, out=context_layer) + # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) - return context_layer + outputs = (context_layer, present, attention_probs) + + return outputs + + +def chatglm_attention_forward( + self, + hidden_states: torch.Tensor, + position_ids, + attention_mask: torch.Tensor, + layer_id, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, + use_cache: bool = False, + output_attentions: bool = False, +): + """ + hidden_states: [seq_len, batch, hidden_size] + attention_mask: [(1, 1), seq_len, seq_len] + """ + + # [seq_len, batch, 3 * hidden_size] + mixed_raw_layer = self.query_key_value(hidden_states) + + # [seq_len, batch, 3 * hidden_size] --> + # [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] + new_tensor_shape = mixed_raw_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) + + if self.position_encoding_2d: + q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) + k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) + cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) + position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ + position_ids[:, 1, :].transpose(0, 1).contiguous() + q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) + q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) + query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) + key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) + else: + position_ids = position_ids.transpose(0, 1) + cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) + # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] + query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, + cos, sin, position_ids) + + # [seq_len, batch, hidden_size] + context_layer, present, attention_probs = attention_fn( + self=self, + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + hidden_size_per_partition=self.hidden_size_per_partition, + layer_id=layer_id, + 
layer_past=layer_past, + use_cache=use_cache + ) + + output = self.dense(context_layer) + + outputs = (output, present) + + if output_attentions: + outputs += (attention_probs,) + + return outputs # output, present, attention_probs diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py new file mode 100644 index 00000000..9f6fe65c --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -0,0 +1,326 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file is adapted from +# https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py +# + +import torch +from typing import Optional, Tuple, Union, List, Callable, Dict, Any +import torch.nn.functional as F + + +KV_CACHE_ALLOC_BLOCK_LENGTH = 256 +KV_CACHE_ALLOC_MIN_LENGTH = 512 + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +def chatglm2_attention_forward_8eb45c( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True +): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. 
+ # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, + self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + cur_length, batch_size = query_layer.shape[0], query_layer.shape[1] + + if self.multi_query_attention: + key_length = key_layer.size(0) + query_group_size = self.num_attention_heads_per_partition // \ + self.num_multi_query_groups_per_partition + key_layer = key_layer.permute(1, 2, 0, 3).unsqueeze(-3) # [bs, nh/k, sl, hn] + key_layer = key_layer.expand(-1, -1, query_group_size, -1, -1) + key_layer = key_layer.contiguous().view((batch_size, + self.num_attention_heads_per_partition, + key_length, + self.hidden_size_per_attention_head)) + value_layer = value_layer.permute(1, 2, 0, 3).unsqueeze(-3) + value_layer = value_layer.expand(-1, -1, query_group_size, -1, -1) + value_layer = value_layer.contiguous().view((batch_size, + self.num_attention_heads_per_partition, + key_length, + self.hidden_size_per_attention_head)) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + past_length = cache_k.size(2) + + if past_length + cur_length > self.max_cache_length: + self.max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH + self.kv_cache = (torch.empty(batch_size, + self.num_attention_heads_per_partition, + self.max_cache_length, + self.hidden_size_per_attention_head,), + torch.empty(batch_size, + self.num_attention_heads_per_partition, + self.max_cache_length, + self.hidden_size_per_attention_head,)) + self.kv_cache[0][:, :, :past_length, :] = cache_k + self.kv_cache[1][:, :, :past_length, :] = cache_v + self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer + self.kv_cache[1][:, :, past_length:past_length + cur_length, :] = value_layer + + key_layer = self.kv_cache[0][:, :, :past_length + cur_length, :] + value_layer = self.kv_cache[1][:, :, :past_length + cur_length, :] + + elif use_cache: + self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \ + + KV_CACHE_ALLOC_BLOCK_LENGTH + self.kv_cache = (torch.empty(batch_size, 
self.num_attention_heads_per_partition, + self.max_cache_length, self.hidden_size_per_attention_head,), + torch.empty(batch_size, self.num_attention_heads_per_partition, + self.max_cache_length, self.hidden_size_per_attention_head,)) + self.kv_cache[0][:, :, :cur_length, :] = key_layer + self.kv_cache[1][:, :, :cur_length, :] = value_layer + + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if query_layer.size(0) > 1 and pytorch_major_version >= 2: + query_layer = query_layer.permute(1, 2, 0, 3) + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + + if torch.is_autocast_cpu_enabled(): + attention_mask = torch.ones(query_layer.shape[2], + key_layer.shape[2], + dtype=torch.bool).tril(diagonal=0) + attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), ) + attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) + query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) + key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) + value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attention_mask, + is_causal=False) + else: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attention_mask, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), ) + if torch.is_autocast_cpu_enabled(): + query_layer = query_layer.to(torch.get_autocast_cpu_dtype()) + key_layer = key_layer.to(torch.get_autocast_cpu_dtype()) + value_layer = value_layer.to(torch.get_autocast_cpu_dtype()) + attention_mask = attention_mask.to(torch.get_autocast_cpu_dtype()) + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, + key_layer, + value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(2)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], + output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + matmul_result = torch.empty( + output_size[0] * output_size[1], + output_size[2], output_size[3], dtype=query_layer.dtype, + ) + + # Raw attention scores. 
[b * np, sq, sk] + torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + out=matmul_result + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(0), value_layer.size(1), + query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.empty( + output_size[0] * output_size[1], + output_size[2], value_layer.size(-1), dtype=value_layer.dtype, + ) + torch.bmm(attention_probs, value_layer, out=context_layer) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer
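
Two usage notes on the patch above.

First, the dispatch order in optimize() matters: because "chatglm" is a substring of "chatglm2", the "chatglm2" branch has to be tested before the plain "chatglm" branch, otherwise ChatGLM2 checkpoints would be patched with the ChatGLM forward. A minimal, self-contained sketch of that check (the helper name and the sample model ids are illustrative, not part of the patch):

def pick_attention_forward(name_or_path: str) -> str:
    # Mirrors the if/elif chain in optimize(): most specific substring first.
    if "chatglm2" in name_or_path:
        return "chatglm2_attention_forward_8eb45c"   # from bigdl.llm.transformers.models.chatglm2
    elif "chatglm" in name_or_path:
        return "chatglm_attention_forward"           # from bigdl.llm.transformers.models.chatglm
    return "unpatched"

assert pick_attention_forward("THUDM/chatglm2-6b") == "chatglm2_attention_forward_8eb45c"
assert pick_attention_forward("THUDM/chatglm-6b") == "chatglm_attention_forward"

Second, the ChatGLM attention_fn and the ChatGLM2 forward share the same KV-cache pre-allocation policy: the cache buffer is sized in KV_CACHE_ALLOC_BLOCK_LENGTH blocks so that most decode steps write into already-allocated storage instead of growing the cache on every token. A sketch of just the capacity rule, assuming capacity == 0 means "no cache allocated yet" (the helper itself is illustrative):

KV_CACHE_ALLOC_BLOCK_LENGTH = 256
KV_CACHE_ALLOC_MIN_LENGTH = 512

def next_capacity(past_length: int, cur_length: int, capacity: int = 0) -> int:
    if capacity == 0:
        # first allocation: room for at least KV_CACHE_ALLOC_MIN_LENGTH tokens plus a spare block
        return max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) + KV_CACHE_ALLOC_BLOCK_LENGTH
    if past_length + cur_length > capacity:
        # overflow: grow to the new total length plus a spare block, then copy the old cache in
        return past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
    return capacity  # still fits: keep writing into the existing buffer

assert next_capacity(0, 16) == 512 + 256
assert next_capacity(768, 1, capacity=768) == 768 + 1 + 256
assert next_capacity(100, 1, capacity=768) == 768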