ipex-llm/python/llm/src/ipex_llm/transformers/models/gptbigcode.py
Wang, Jian4 9df70d95eb
Refactor bigdl.llm to ipex_llm (#24)
* Rename bigdl/llm to ipex_llm

* rm python/llm/src/bigdl

* from bigdl.llm to from ipex_llm
2024-03-22 15:41:21 +08:00

101 lines
4.2 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple, Union
import torch
def _attn_wrapper(origin_attn):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_output, attn_weights = origin_attn(self,
query=query.to(key.dtype),
key=key,
value=value,
attention_mask=attention_mask,
head_mask=head_mask)
if query.device.type == 'xpu' and 1 < query.numel() // query.size(-1) <= 64:
attn_output = attn_output.clone()
return attn_output, attn_weights
return _attn
def _concat_layer_past(layer_past: torch.Tensor,
                       key_value: torch.Tensor) -> torch.Tensor:
    """Append ``key_value`` after the cached ``layer_past`` along dim -2.

    If the cache's last dimension is narrower than ``key_value``'s, the cache
    is first right-padded with zeros on its last dimension so the two tensors
    can be concatenated.

    Args:
        layer_past: Cached packed key/value tensor.
        key_value: Newly projected packed key/value tensor; must have the
            same rank as ``layer_past``.

    Returns:
        ``torch.cat((padded_layer_past, key_value), dim=-2)``.
    """
    # Compare widths (last dim), not dim -2: the sequence lengths of the cache
    # and of the new tokens legitimately differ on every decode step.
    pad = key_value.shape[-1] - layer_past.shape[-1]
    if pad > 0:
        # Build the pad from ``layer_past.shape[:-1]`` so it works for both the
        # 3-D multi-query cache and the 4-D multi-head cache.
        fill_zeros = torch.zeros(*layer_past.shape[:-1], pad,
                                 dtype=layer_past.dtype,
                                 device=layer_past.device)
        layer_past = torch.cat([layer_past, fill_zeros], dim=-1)
    return torch.cat((layer_past, key_value), dim=-2)


def gptbigcode_attention_forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False):
    """Attention forward pass for GPTBigCode attention modules.

    Projects ``hidden_states`` (or, for cross attention,
    ``encoder_hidden_states``) into a query and a packed key/value tensor,
    prepends the cached ``layer_past`` when given, runs ``self._attn`` and
    projects the result through ``self.c_proj`` then ``self.resid_dropout``.

    Args:
        self: The attention module. Attributes read: ``c_attn``, ``c_proj``,
            ``resid_dropout``, ``_attn``, ``multi_query``, ``embed_dim``,
            ``kv_dim``, ``num_heads``, ``head_dim`` and, for cross attention,
            ``q_attn`` / ``is_cross_attention``.
        hidden_states: Input activations. Assumed (batch, seq, embed_dim) -
            TODO confirm; the multi-head path views it by its first two dims.
        layer_past: Cached packed key/value (presumably the ``present``
            returned by an earlier call), concatenated before the new
            key/value along dim -2.
        attention_mask: Forwarded to ``self._attn``; replaced by
            ``encoder_attention_mask`` in cross attention.
        head_mask: Forwarded to ``self._attn``.
        encoder_hidden_states: If given, switches to cross attention: the
            query comes from ``self.q_attn(hidden_states)`` and key/value from
            ``self.c_attn(encoder_hidden_states)``.
        encoder_attention_mask: Mask used in place of ``attention_mask`` for
            cross attention.
        use_cache: If truthy, the (concatenated) key/value is returned as
            ``present``; otherwise ``present`` is ``None``.
        output_attentions: If truthy, the attention weights from
            ``self._attn`` are appended to the output tuple.

    Returns:
        ``(attn_output, present)``, or ``(attn_output, present, attn_weights)``
        when ``output_attentions`` is set.
    """
    if encoder_hidden_states is not None:
        if not hasattr(self, "q_attn") or not self.is_cross_attention:
            from ipex_llm.utils.common import invalidInputError
            invalidInputError(
                False,
                "If class is used as cross attention, " +
                "the weights `q_attn` have to be defined. " +
                "Please make sure to instantiate class with " +
                "`GPTBigCodeAttention(..., is_cross_attention=True)`."
            )
        query = self.q_attn(hidden_states)
        key_value = self.c_attn(encoder_hidden_states)
        attention_mask = encoder_attention_mask
    elif self.multi_query:
        # Multi-query: one projection split on dim 2 into the query
        # (embed_dim wide) and a shared packed key/value (2 * kv_dim wide).
        query, key_value = self.c_attn(hidden_states).split(
            (self.embed_dim, 2 * self.kv_dim), dim=2)
    else:
        # Multi-head: view as (d0, d1, num_heads, 3 * head_dim), move heads to
        # dim 1, then split each head into query and packed key/value.
        query, key_value = (
            self.c_attn(hidden_states)
            .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
            .transpose(1, 2)
            .split((self.head_dim, 2 * self.head_dim), dim=3)
        )

    if layer_past is not None:
        key_value = _concat_layer_past(layer_past, key_value)
    present = key_value if use_cache else None

    # Key and value are packed side by side on the last dim.
    key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
    attn_output, attn_weights = self._attn(query, key.transpose(-1, -2),
                                           value, attention_mask, head_mask)

    if not self.multi_query:
        # Undo the head transpose and restore the input's shape.
        attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        if self.multi_query:
            attn_weights = attn_weights.transpose(1, 2)
        outputs += (attn_weights,)
    return outputs