#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is adapted from
# https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
#

import math
import torch
from typing import Optional, Tuple, List
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
    restore_fp8_kv_cache, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import use_sdp
import os

# Extra sequence positions reserved whenever a KV cache buffer is (re)allocated,
# so that subsequent tokens can be appended without reallocating every step.
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
# Lower bound on the sequence capacity of a freshly allocated KV cache.
KV_CACHE_ALLOC_MIN_LENGTH = 512


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor into ``num_partitions`` equal chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: if True, make each chunk contiguous in memory.

    Returns:
        A tuple of tensors (``torch.split`` returns a tuple; the contiguous
        variant builds one too).
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


@torch.jit.script
def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # Rotates adjacent (even, odd) channel pairs of the first `rot_dim` channels of x
    # using rope_cache[..., 0] as cos and rope_cache[..., 1] as sin; the remaining
    # channels are passed through unchanged.
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


def _apply_rotary_to_qk(query_layer, key_layer, rotary_pos_emb):
    """Apply ChatGLM2 rotary embedding to query and key laid out as [sq, b, heads, hn].

    ``rotary_pos_emb`` is either:
      * ``None`` -- query/key are returned unchanged;
      * a ``(cos, sin)`` tuple -- the fused-rope form built by ``chatglm2_model_forward``,
        applied with ``torch.ops.torch_ipex.apply_rotary_embedding``;
      * a tensor -- applied with ``apply_rotary_pos_emb_chatglm``.

    Returns:
        ``(query_layer, key_layer)`` with the same [sq, b, heads, hn] layout.
    """
    if rotary_pos_emb is None:
        return query_layer, key_layer

    if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
        # use_fuse_rope, see chatglm2_model_forward
        cos, sin = rotary_pos_emb
        rot_dim = cos.shape[-1]
        query_layer = query_layer.transpose(0, 1)
        key_layer = key_layer.transpose(0, 1)
        query_layer_cur = query_layer[..., :rot_dim]
        key_layer_cur = key_layer[..., :rot_dim]
        # The kernel is given each `*_cur` slice as both input and output, so it
        # writes in place and query_layer / key_layer see the rotated values.
        torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
        torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
        query_layer = query_layer.transpose(0, 1)
        key_layer = key_layer.transpose(0, 1)
    else:
        query_layer = apply_rotary_pos_emb_chatglm(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb_chatglm(key_layer, rotary_pos_emb)
    return query_layer, key_layer


def repeat_kv(key: torch.Tensor, value: torch.Tensor,
              n_head: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expand grouped key/value heads so that there is one K/V head per query head.

    Each of the ``n_kv_head`` heads is repeated ``n_head // n_kv_head`` times
    consecutively, and the result is materialized contiguously.

    Args:
        key, value: tensors of shape [bs, n_kv_head, seq_len, head_dim].
        n_head: number of query heads (assumed to be a multiple of n_kv_head).

    Returns:
        ``(key, value)``, each of shape [bs, n_head, seq_len, head_dim].
    """
    batch_size, n_kv_head, seq_len, head_dim = key.shape
    key = key.unsqueeze(2)
    key = key.expand(-1, -1, n_head // n_kv_head, -1, -1)
    key = key.contiguous().view(batch_size, n_head, seq_len, head_dim)
    value = value.unsqueeze(2)
    value = value.expand(-1, -1, n_head // n_kv_head, -1, -1)
    value = value.contiguous().view(batch_size, n_head, seq_len, head_dim)
    return key, value


def should_split_qkv_tensor(query_layer, bsz, n_head, seq_len):
    """Decide whether first-token attention should be computed in head blocks.

    The environment variable ``IPEX_LLM_SPLIT_QKV`` overrides the heuristic
    ("1" forces splitting, any other value disables it). Otherwise splitting is
    enabled for fp16 queries with at least 5000 positions in dim 2, or when a
    full ``bsz * n_head * seq_len * seq_len`` attention-weight tensor of the
    query's element size would reach 4 GiB.
    """
    split_env = os.environ.get("IPEX_LLM_SPLIT_QKV", None)
    if split_env is not None:
        return split_env == "1"
    if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
        # split tensor for memory block limitation
        # support fp16 and set input length threshold at 5000 for now
        return True
    if query_layer.element_size() * bsz * n_head * seq_len * seq_len >= 4 * 1024 ** 3:
        # attn_weight size larger than memory block limitation 4GB
        return True
    return False


def chatglm_rms_norm_forward(self, hidden_states):
    """RMSNorm forward reading ``self.weight`` and ``self.eps``.

    On XPU, unless training with a grad-requiring input, this uses the fused
    ``linear_q4_0.rms_norm`` kernel on a 2-D view of the input. Otherwise it
    normalizes in float32 and casts back to the input dtype before scaling by
    ``self.weight``.
    """
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
        import linear_q4_0
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
        output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
        return output.reshape(hidden_states.shape)

    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
    return self.weight * hidden_states.to(input_dtype)


def chatglm2_model_forward(
    self,
    input_ids,
    position_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.BoolTensor] = None,
    full_attention_mask: Optional[torch.BoolTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """ChatGLM2 model forward: embeddings, mask, rotary embedding, then the encoder.

    The full attention mask is only built when there is real padding in
    ``attention_mask`` or when ``past_key_values`` is given with more than one
    new token. On XPU outside training, rotary embeddings are pre-expanded into
    a ``(cos, sin)`` tuple once here, for the fused rope kernel used by the
    attention layers. Otherwise the rotary tensor is transposed to sequence-first.

    Returns:
        ``BaseModelOutputWithPast``, or a tuple of its non-None fields when
        ``return_dict`` is false.
    """
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    _, seq_length = input_ids.shape

    if inputs_embeds is None:
        inputs_embeds = self.embedding(input_ids)

    if full_attention_mask is None:
        if (attention_mask is not None and not attention_mask.all()) or (
                past_key_values and seq_length != 1):
            full_attention_mask = self.get_masks(input_ids,
                                                 past_key_values,
                                                 padding_mask=attention_mask)

    use_fuse_rope = input_ids.device.type == "xpu"
    use_fuse_rope = use_fuse_rope and not self.training

    # Rotary positional embeddings
    rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
    if position_ids is not None:
        rotary_pos_emb = rotary_pos_emb[position_ids]
    else:
        rotary_pos_emb = rotary_pos_emb[None, :seq_length]
    if use_fuse_rope:
        # Repeat cos sin here, call only once for each token.
        # Chatglm2's rotary embedding is similar to gptj's, is rotate_every_two.
        # If put this to attention forward, it will generate too many times.
        cos, sin = rotary_pos_emb.split(rotary_pos_emb.shape[-1] // 2, dim=-1)
        cos = cos.squeeze(-1)
        sin = sin.squeeze(-1)
        cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
        sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
        rotary_pos_emb = (cos, sin)
    else:
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

    # Run encoder.
    hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
        inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
        kv_caches=past_key_values, use_cache=use_cache,
        output_hidden_states=output_hidden_states
    )

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states,
                                 all_self_attentions] if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )


def chatglm2_attention_forward(
    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
):
    """Dispatch to the fp8-KV-cache attention or the regular KV-cache attention.

    The choice is made by ``use_quantize_kv_cache`` on the QKV projection and the
    batch-first view of ``hidden_states``.
    """
    if use_quantize_kv_cache(self.query_key_value, hidden_states.transpose(0, 1)):
        forward_function = chatglm2_quantized_attention_forward_8eb45c
    else:
        forward_function = chatglm2_attention_forward_8eb45c
    return forward_function(
        self=self,
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        rotary_pos_emb=rotary_pos_emb,
        kv_cache=kv_cache,
        use_cache=use_cache
    )


def chatglm2_quantized_attention_forward_8eb45c(
    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
):
    """Self-attention that stores the KV cache in fp8.

    Args:
        hidden_states: [seq_len, bs, hidden_size].
        attention_mask: None, or a mask in which True marks positions that must
            not be attended (the convention used by the kv_cache branch below).
        rotary_pos_emb: see ``_apply_rotary_to_qk``.
        kv_cache: None on the first forward, else ``(k, v)`` in
            [seq_len, bs, n_kv_head, head_dim] layout, as returned by this function.
        use_cache: whether to return the updated cache.

    Returns:
        ``(output, kv_cache)``, where ``kv_cache`` is None when ``use_cache`` is false.
    """
    # hidden_states: [seq_len, bs, hidden_size]
    mixed_x_layer = self.query_key_value(hidden_states)
    n_head = self.num_attention_heads_per_partition
    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention \
        else n_head
    head_dim = self.hidden_size_per_attention_head

    # NOTE(review): this split assumes a [q | k | v] fused layout; the non-MQA path of
    # chatglm2_attention_forward_8eb45c instead splits per head. Confirm for non-MQA models.
    query_layer, key_layer, value_layer = mixed_x_layer.split(
        [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim],
        dim=-1,
    )
    query_layer = query_layer.view(query_layer.shape[:-1] + (n_head, head_dim))
    key_layer = key_layer.view(key_layer.shape[:-1] + (n_kv_head, head_dim))
    value_layer = value_layer.view(value_layer.shape[:-1] + (n_kv_head, head_dim))
    # query, key, value's shape: [seq_len, bs, n_head/n_kv_head, head_dim]

    # apply relative positional encoding (rotary embedding)
    query_layer, key_layer = _apply_rotary_to_qk(query_layer, key_layer, rotary_pos_emb)

    query_layer = query_layer.permute(1, 2, 0, 3)
    key_layer = key_layer.permute(1, 2, 0, 3)
    value_layer = value_layer.permute(1, 2, 0, 3)
    # query, key, value's shape: [bs, n_head/n_kv_head, seq_len, head_dim]
    batch_size, _, seq_len, _ = query_layer.shape

    if kv_cache is None:
        # first token
        if self.multi_query_attention:
            key, value = repeat_kv(key_layer, value_layer, n_head)
        else:
            key, value = key_layer, value_layer

        # The incoming bool mask marks *excluded* positions with True, while
        # F.scaled_dot_product_attention treats True as "may attend", so invert it
        # (the kv_cache branch below and core_attn_forward_8eb45c do the same).
        if attention_mask is not None and attention_mask.dtype == torch.bool:
            sdpa_mask = ~attention_mask
        else:
            sdpa_mask = attention_mask

        if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
            # split second dim to block size = 8
            block_size = 8
            query_split = torch.split(query_layer, block_size, dim=1)
            key_split = torch.split(key, block_size, dim=1)
            value_split = torch.split(value, block_size, dim=1)
            results = []
            for q, k, v in zip(query_split, key_split, value_split):
                if sdpa_mask is None:
                    result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
                else:
                    result = F.scaled_dot_product_attention(q, k, v, sdpa_mask)
                results.append(result)
            context_layer = torch.cat(results, dim=1)
        else:
            if sdpa_mask is None:
                context_layer = F.scaled_dot_product_attention(query_layer, key,
                                                               value, is_causal=True)
            else:
                context_layer = F.scaled_dot_product_attention(query_layer, key,
                                                               value, sdpa_mask)
        context_layer = context_layer.to(query_layer.dtype)

        if use_cache:
            k_cache, v_cache = init_fp8_kv_cache(batch_size,
                                                 n_kv_head,
                                                 seq_len,
                                                 head_dim,
                                                 query_layer.device)
            k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
    else:
        k_cache, v_cache = kv_cache
        k_cache = k_cache.permute(1, 2, 0, 3)
        v_cache = v_cache.permute(1, 2, 0, 3)
        # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]

        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)

        # Turn the mask into an additive bias: -inf where the original mask is True.
        if attention_mask is not None:
            attention_mask = ~attention_mask
            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                    device=query_layer.device)
            if attention_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attention_mask
        else:
            attn_bias = None

        if seq_len != 1:
            # Multiple new tokens: dequantize the cache and use plain matmul attention.
            key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
            key, value = repeat_kv(key, value, n_head)
            attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
            if attn_bias is not None:
                attn += attn_bias
            attn = F.softmax(attn, dim=-1, dtype=torch.float32)
            context_layer = torch.matmul(attn.to(value.dtype), value)
        else:
            # Single new token: fused kernel reads the fp8 cache directly.
            key, value = k_cache, v_cache
            import linear_q4_0
            context_layer = linear_q4_0.sdp_fp8(query_layer, key, value, attn_bias)

    # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1)

    if use_cache:
        kv_cache = (k_cache.permute(2, 0, 1, 3), v_cache.permute(2, 0, 1, 3))
    else:
        kv_cache = None

    output = self.dense(context_layer)

    return output, kv_cache


def chatglm2_attention_forward_8eb45c(
    self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
):
    """Self-attention with a pre-allocated (non-quantized) KV cache.

    Args:
        hidden_states: [sq, b, h].
        attention_mask: forwarded to ``core_attn_forward_8eb45c``.
        rotary_pos_emb: see ``_apply_rotary_to_qk``.
        kv_cache: None on the first forward, else ``(k, v)`` in [sl, b, heads, hn]
            layout, as returned by this function.
        use_cache: whether to allocate/extend and return the cache.

    Returns:
        ``(output, kv_cache)``: the cache is returned in [sl, b, heads, hn] layout,
        or None when ``use_cache`` is false. For MQA on XPU with batch_size > 1 the
        cache keeps the un-expanded group heads; otherwise it holds all query heads.
    """
    # hidden_states: [sq, b, h]

    # =====================
    # Query, Key, and Value
    # =====================

    # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
    device = hidden_states.device
    mixed_x_layer = self.query_key_value(hidden_states)

    if self.multi_query_attention:
        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )
        query_layer = query_layer.view(
            query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                       self.hidden_size_per_attention_head)
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                     self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                       self.hidden_size_per_attention_head)
        )
    else:
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

    cur_length, batch_size = query_layer.shape[0], query_layer.shape[1]

    # apply relative positional encoding (rotary embedding)
    query_layer, key_layer = _apply_rotary_to_qk(query_layer, key_layer, rotary_pos_emb)

    # NOTE(review): only the MQA path below permutes key/value to [bs, heads, sl, hn],
    # which the cache helpers and core_attn_forward_8eb45c expect; a non-MQA model
    # would reach them still in [sq, b, np, hn]. Confirm whether non-MQA is supported.
    if self.multi_query_attention:
        if device.type == "xpu" and batch_size > 1:
            # use beam_search for generation.
            # If batch_size > 1 on gpu, permute key/value_layer to [bs, np, sl, hn]
            # to reduce memory usage. Otherwise, expand key/value_layer to [bs, nh, sl, hn].
            key_layer = key_layer.permute(1, 2, 0, 3)  # [bs, np, sl, hn]
            value_layer = value_layer.permute(1, 2, 0, 3)  # [bs, np, sl, hn]
        else:
            # [sl, b, np, hn] -> [bs, nh, sl, hn]
            key_layer, value_layer = repeat_kv(key_layer.permute(1, 2, 0, 3),
                                               value_layer.permute(1, 2, 0, 3),
                                               self.num_attention_heads_per_partition)

    # adjust key and value for inference
    if kv_cache is not None:
        cache_k, cache_v = kv_cache
        cache_k = cache_k.permute(1, 2, 0, 3)
        cache_v = cache_v.permute(1, 2, 0, 3)
        past_length = cache_k.size(2)

        # The head stride tells how many positions the underlying buffer can hold;
        # reallocate (with extra slack) when the new tokens would not fit.
        if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
            max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
            if device.type == "xpu" and batch_size > 1:
                # use beam_search for generation.
                # If batch_size > 1 on gpu, use init_kv_cache to avoid empty cache for
                # ensuring generation correctness.
                # Set the num_heads in init_kv_cache to np, ensuring that the tensors of
                # new_cache_k/v and key/value_layer have the same size.
                new_cache_k, new_cache_v = init_kv_cache(batch_size,
                                                         self.num_multi_query_groups_per_partition,
                                                         self.hidden_size_per_attention_head,
                                                         past_length,
                                                         max_cache_length,
                                                         dtype=query_layer.dtype,
                                                         device=device)
            else:
                new_cache_k, new_cache_v = extend_kv_cache(batch_size,
                                                           self.num_attention_heads_per_partition,
                                                           self.hidden_size_per_attention_head,
                                                           past_length,
                                                           max_cache_length,
                                                           dtype=query_layer.dtype,
                                                           device=device)
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            cache_k = new_cache_k
            cache_v = new_cache_v

        key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)

    elif use_cache:
        max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
            + KV_CACHE_ALLOC_BLOCK_LENGTH

        if device.type == "xpu" and batch_size > 1:
            # use beam_search for generation.
            # Ensure the tensors of key/value_cache and key/value_layer have the same size.
            nums_per_partition = self.num_multi_query_groups_per_partition
        else:
            nums_per_partition = self.num_attention_heads_per_partition

        key_cache, value_cache = init_kv_cache(batch_size,
                                               nums_per_partition,
                                               self.hidden_size_per_attention_head,
                                               cur_length,
                                               max_cache_length,
                                               dtype=query_layer.dtype,
                                               device=device)
        key_cache[:] = key_layer
        value_cache[:] = value_layer
        key_layer = key_cache
        value_layer = value_cache

    # If batch_size > 1 on XPU, these stay [bs, np, sl, hn] (not expanded to nh heads),
    # which reduces cache memory. Otherwise they are views of [bs, nh, sl, hn].
    cache_key_layer = key_layer
    cache_value_layer = value_layer

    # ==================================
    # core attention computation
    # ==================================

    if device.type == "xpu" and batch_size > 1 and self.multi_query_attention:
        # use beam_search for generation.
        # Expand key/value_layer to [bs, nh, sl, hn] for core attention only;
        # the expanded tensors are not returned as past_key_values.
        key_layer, value_layer = repeat_kv(key_layer, value_layer,
                                           self.num_attention_heads_per_partition)

    context_layer = core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask)

    # =================
    # Output. [sq, b, h]
    # =================

    output = self.dense(context_layer)

    if use_cache:
        kv_cache = (cache_key_layer.permute(2, 0, 1, 3), cache_value_layer.permute(2, 0, 1, 3))
    else:
        kv_cache = None

    return output, kv_cache


def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask):
    """Scaled dot-product attention core.

    Args:
        query_layer: [sq, b, np, hn].
        key_layer, value_layer: [b, np, sk, hn] (already expanded to np heads).
        attention_mask: None, or a mask in which True marks positions that must not
            be attended. With no mask and sq == sk, causal attention is used.

    Returns:
        Context of shape [sq, b, np * hn].
    """
    pytorch_major_version = int(torch.__version__.split('.')[0])
    if pytorch_major_version >= 2:
        query_layer = query_layer.permute(1, 2, 0, 3)
        L, S = query_layer.shape[2], key_layer.shape[2]
        if attention_mask is None and L == S:
            batch_size, n_head, seq_len, head_dim = query_layer.shape
            if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
                # split second dim to block size = 8
                block_size = 8
                query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
                key_split = torch.split(key_layer, block_size, dim=1)
                value_split = torch.split(value_layer, block_size, dim=1)
                results = []
                for q, k, v in zip(query_split, key_split, value_split):
                    result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
                    results.append(result)
                context_layer = torch.cat(results, dim=1)
            else:
                context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                               key_layer,
                                                               value_layer,
                                                               is_causal=True).to(key_layer.dtype)
        else:
            # A mask is present (chatglm2_model_forward builds one for padded inputs,
            # or for past_key_values with more than one new token), or sq != sk.
            # Convert it into an additive bias: -inf where the original mask is True.
            if attention_mask is not None:
                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                        device=query_layer.device)
                attention_mask = ~attention_mask
                if attention_mask.dtype == torch.bool:
                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
                else:
                    attn_bias += attention_mask
            else:
                attn_bias = None
            if use_sdp(query_layer.shape[2], key_layer.shape[2],
                       query_layer.shape[-1], query_layer):
                import linear_q4_0
                attn_output = linear_q4_0.sdp(query_layer, key_layer, value_layer, attn_bias)
                context_layer = attn_output.view(query_layer.shape)
            else:
                head_dim = query_layer.size(-1)
                attn = torch.matmul(query_layer.to(key_layer.dtype),
                                    key_layer.transpose(2, 3)) / math.sqrt(head_dim)
                if attn_bias is not None:
                    attn += attn_bias
                attn = F.softmax(attn, dim=-1,
                                 dtype=torch.float32).to(value_layer.dtype)
                context_layer = torch.matmul(attn, value_layer)
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (-1,)
        context_layer = context_layer.reshape(*new_context_layer_shape)
    else:
        # PyTorch < 2 fallback (no F.scaled_dot_product_attention).
        # The previous version of this branch read module attributes (self.norm_factor,
        # self.coeff, self.attention_softmax_in_fp32, self.attention_dropout,
        # self.hidden_size_per_partition), but this is a free function without `self`,
        # so it always raised NameError. It now uses plain tensor ops with a
        # 1/sqrt(head_dim) scale (assumed to be the net effect of upstream's
        # norm_factor/coeff pair -- TODO confirm), an fp32 softmax, and no dropout
        # (the PyTorch >= 2 branch above applies none either).
        seq_q, batch_size = query_layer.size(0), query_layer.size(1)
        seq_k = key_layer.size(2)
        head_dim = query_layer.size(-1)
        # [sq, b, np, hn] -> [b, np, sq, hn]
        query = query_layer.permute(1, 2, 0, 3).to(key_layer.dtype)
        # Raw attention scores: [b, np, sq, sk]
        attention_scores = torch.matmul(query, key_layer.transpose(2, 3)) / math.sqrt(head_dim)
        attention_scores = attention_scores.float()
        if attention_mask is None and seq_q == seq_k:
            # Causal mask; True marks positions above the diagonal (not attended).
            attention_mask = torch.ones(batch_size, 1, seq_q, seq_k,
                                        device=attention_scores.device, dtype=torch.bool)
            attention_mask.tril_()
            attention_mask = ~attention_mask
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
        attention_probs = F.softmax(attention_scores, dim=-1).type_as(value_layer)
        # [b, np, sq, sk] x [b, np, sk, hn] -> [b, np, sq, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, np, sq, hn] -> [sq, b, np, hn] -> [sq, b, np * hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        context_layer = context_layer.view(context_layer.size()[:-2] + (-1,))

    return context_layer