* Rename bigdl/llm to ipex_llm * rm python/llm/src/bigdl * from bigdl.llm to from ipex_llm
101 lines
4.2 KiB
Python
101 lines
4.2 KiB
Python
#
|
|
# Copyright 2016 The BigDL Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
|
|
import functools
from typing import Optional, Tuple, Union

import torch
|
|
|
|
|
|
def _attn_wrapper(origin_attn):
|
|
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
|
attn_output, attn_weights = origin_attn(self,
|
|
query=query.to(key.dtype),
|
|
key=key,
|
|
value=value,
|
|
attention_mask=attention_mask,
|
|
head_mask=head_mask)
|
|
if query.device.type == 'xpu' and 1 < query.numel() // query.size(-1) <= 64:
|
|
attn_output = attn_output.clone()
|
|
return attn_output, attn_weights
|
|
return _attn
|
|
|
|
|
|
def _concat_layer_past(layer_past: torch.Tensor,
                       key_value: torch.Tensor) -> torch.Tensor:
    """Prepend the cached ``layer_past`` to ``key_value`` along dim -2.

    If ``layer_past``'s last dimension is narrower than ``key_value``'s, it
    is first right-padded with zeros along that last dimension so the two
    tensors can be concatenated. Only the last two dimensions are addressed,
    so this works for tensors of any rank >= 2 (e.g. 3-D multi-query caches
    and 4-D per-head caches). A cache that is *wider* than ``key_value``
    yields a negative pad size, and ``torch.zeros`` raises.
    """
    pad = key_value.shape[-1] - layer_past.shape[-1]
    if pad != 0:
        fill_zeros = torch.zeros(*layer_past.shape[:-1], pad,
                                 dtype=layer_past.dtype,
                                 device=layer_past.device)
        layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
    return torch.cat((layer_past, key_value), dim=-2)


def gptbigcode_attention_forward(
    self,
    hidden_states: torch.Tensor,
    layer_past: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
) -> Tuple[Optional[torch.Tensor], ...]:
    """Compute GPTBigCode-style self- or cross-attention for module ``self``.

    ``self`` must provide the following attributes:

    * always: ``c_attn``, ``c_proj``, ``resid_dropout``, ``_attn``,
      ``multi_query``, ``embed_dim``, ``kv_dim``, ``head_dim``, ``num_heads``;
    * for cross-attention: also ``q_attn`` and ``is_cross_attention``.

    Projection layouts:

    * Cross-attention (``encoder_hidden_states`` given): ``query`` is
      ``q_attn(hidden_states)``, ``key_value`` is
      ``c_attn(encoder_hidden_states)``, and ``encoder_attention_mask``
      replaces ``attention_mask``.
    * ``multi_query``: ``c_attn(hidden_states)`` is split along dim 2 into a
      query of width ``embed_dim`` and a key/value of width ``2 * kv_dim``.
    * Otherwise: ``c_attn(hidden_states)`` is viewed as
      ``(batch, seq, num_heads, 3 * head_dim)`` (assuming ``hidden_states``
      is ``(batch, seq, embed_dim)`` -- TODO confirm). The heads are then
      moved to dim 1, and the last dim is split into query (``head_dim``)
      and key/value (``2 * head_dim``).

    If ``layer_past`` is given, it is prepended to ``key_value`` along dim -2
    by ``_concat_layer_past``. The combined tensor is then split on its last
    dimension into key and value, each ``head_dim`` wide, and passed to
    ``self._attn``. The result goes through ``c_proj`` and ``resid_dropout``.

    Args:
        hidden_states: input activations.
        layer_past: cached key/value from previous steps, or ``None``.
        attention_mask: mask forwarded to ``self._attn`` (self-attention).
        head_mask: head mask forwarded to ``self._attn``.
        encoder_hidden_states: if given, selects the cross-attention path.
        encoder_attention_mask: mask used instead of ``attention_mask`` on
            the cross-attention path.
        use_cache: if true, return the combined key/value as ``present``.
        output_attentions: if true, append the attention weights.

    Returns:
        ``(attn_output, present)``, with ``attn_weights`` appended when
        ``output_attentions`` is true. ``present`` is the (possibly
        cache-extended) key/value tensor if ``use_cache`` is true, else
        ``None``. For ``multi_query``, the returned weights are
        ``self._attn``'s weights transposed on dims 1 and 2.
    """
    if encoder_hidden_states is not None:
        if not hasattr(self, "q_attn") or not self.is_cross_attention:
            from ipex_llm.utils.common import invalidInputError
            invalidInputError(
                False,
                "If class is used as cross attention, " +
                "the weights `q_attn` have to be defined. " +
                "Please make sure to instantiate class with " +
                "`GPTBigCodeAttention(..., is_cross_attention=True)`."
            )

        query = self.q_attn(hidden_states)
        key_value = self.c_attn(encoder_hidden_states)
        attention_mask = encoder_attention_mask
    elif self.multi_query:
        query, key_value = self.c_attn(hidden_states).split(
            (self.embed_dim, 2 * self.kv_dim), dim=2)
    else:
        query, key_value = (
            self.c_attn(hidden_states)
            .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
            .transpose(1, 2)
            .split((self.head_dim, 2 * self.head_dim), dim=3)
        )

    if layer_past is not None:
        # Compares/pads the *last* dim and concatenates on dim -2, so both
        # 3-D (multi-query) and 4-D (per-head) caches are handled.
        key_value = _concat_layer_past(layer_past, key_value)
    present = key_value if use_cache else None

    key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

    attn_output, attn_weights = self._attn(query, key.transpose(-1, -2),
                                           value, attention_mask, head_mask)

    if not self.multi_query:
        # Move heads back next to head_dim and merge them into the input
        # shape.
        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        if self.multi_query:
            attn_weights = attn_weights.transpose(1, 2)
        outputs += (attn_weights,)

    return outputs
|