ipex-llm/python/llm/src/ipex_llm/transformers/models/gemma.py
2024-04-19 13:57:48 +08:00

269 lines
12 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
from torch import nn
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
from ipex_llm.transformers.models.utils import use_decoding_fast_path
import os
# Extra length added on top of the current kv length when a larger KV cache is requested.
# Overridable through the KV_CACHE_ALLOC_BLOCK_LENGTH environment variable; defaults to 256.
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.getenv("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Rotate the query and key tensors with the supplied cos/sin tables.

    `cos` and `sin` each get one extra axis inserted at `unsqueeze_dim` before they are
    multiplied with `q` and `k`. `position_ids` is accepted but not read here.
    Returns the rotated `(q, k)` pair.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # standard rotary form: x * cos + rotate_half(x) * sin
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head `n_rep` times along the head axis.

    Gives the same result as torch.repeat_interleave(x, dim=1, repeats=n_rep): an input of
    shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep` is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # insert a repeat axis right after the head axis, broadcast it, then fold it into heads
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
def should_use_fuse_rope(self, hidden_states, position_ids):
    """
    Decide whether the fused rotary-embedding path may be used.

    True only when `hidden_states` lives on an "xpu" device, the module is not in a
    training step that needs gradients for `hidden_states`, and `position_ids` is given.
    """
    if hidden_states.device.type != "xpu":
        return False
    if self.training and hidden_states.requires_grad:
        return False
    return position_ids is not None
def gemma_rms_norm_forward(self, hidden_states):
    """
    RMS-normalize `hidden_states` over the last axis and scale by `(1 + self.weight)`.

    On an "xpu" device (unless training with gradients required) the work is delegated to
    `linear_q4_0.rms_norm` on a contiguous 2-D view; otherwise the normalization is computed
    in float32 and cast back to the input dtype before the scale is applied.
    """
    if hidden_states.device.type == "xpu":
        if not (self.training and hidden_states.requires_grad):
            import linear_q4_0
            flat = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
            normed_flat = linear_q4_0.rms_norm(self.weight + 1, flat, self.eps)
            return normed_flat.reshape(hidden_states.shape)

    # generic fallback: do the statistics in float32, then return to the caller's dtype
    orig_dtype = hidden_states.dtype
    states_fp32 = hidden_states.to(torch.float32)
    mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
    normed = states_fp32 * torch.rsqrt(mean_square + self.eps)
    return (1 + self.weight) * normed.to(orig_dtype)
def gemma_mlp_forward(
    self,
    x: torch.Tensor,
    residual=None
) -> torch.Tensor:
    """
    Gated MLP forward with an optional fused kernel.

    `x` is viewed as 2-D (rows, last_dim). When `mlp_fusion_check` accepts that view and
    `self.down_proj.enable_xetla` is falsy, gate/up projection and GELU run inside
    `linear_q4_0.mlp_forward_xpu` and the result goes through `self.down_proj`. Otherwise
    `down_proj(act_fn(gate_proj(x)) * up_proj(x))` is evaluated on the original `x`.
    If `residual` is given it is added to the result.
    """
    x_2d = x.view(-1, x.shape[-1])
    n_rows, n_cols = x_2d.shape
    qtype = getattr(self.gate_proj, "qtype", None)

    fused = mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla
    if fused:
        import linear_q4_0
        # the fused kernel is handed a contiguous buffer
        x_2d = x_2d if x_2d.is_contiguous() else x_2d.contiguous()
        gated = linear_q4_0.mlp_forward_xpu(
            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
            n_rows, n_cols, self.gate_proj.out_len,
            GELU, qtype
        )
        out = self.down_proj(gated)
    else:
        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

    return out if residual is None else out + residual
def gemma_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor]=None,
    position_ids: Optional[torch.LongTensor]=None,
    past_key_value: Optional[Tuple[torch.Tensor]]=None,
    output_attentions: bool=False,
    use_cache: bool=False,
    cache_position: Optional[torch.Tensor]=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Gemma attention forward with an optional fused decoding path.

    Two ways of producing q/k/v exist:
      * decoding fast path (chosen by `use_decoding_fast_path`): one call to
        `linear_q4_0.forward_qkv` that is handed the projection weights, `position_ids`
        and the layer's cached key/value tensors;
      * generic path: separate q/k/v projections, rotary embedding, then an update of
        `past_key_value` (append for a new layer, or write into the pre-allocated cache,
        growing it first when there is not enough room).
    Afterwards scaled dot-product attention with an fp32 softmax is computed.

    Args:
        hidden_states: input of shape (bsz, q_len, hidden_size).
        attention_mask: additive mask added to the attention scores, if given.
        position_ids: positions used for the rotary embedding.
        past_key_value: cache object exposing `key_cache`, `value_cache`, `seen_tokens`
            and `get_usable_length`; updated in place.
        output_attentions: when False the returned attention weights are None.
        use_cache: accepted for interface compatibility; not read in this function.
        cache_position: if given, used to select rows of `attention_mask`.

    Returns:
        `(attn_output, attn_weights, past_key_value)` where `attn_output` is cast back to
        the dtype of the incoming `hidden_states`.

    Errors are reported through `invalidInputError` when a cache is supplied but
    `self.layer_idx` is None, or when the attention output has an unexpected shape.
    """
    # the unpack also enforces a 3-D input; the hidden size itself is not needed below
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device
    # remembered so the final output can be cast back to the caller's dtype
    original_dtype = hidden_states.dtype

    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
    decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                use_fuse_rope,
                                                enough_kv_room,
                                                bsz * q_len)
    if decoding_fast_path:
        hidden_states = hidden_states.view(1, -1)
        cache_k = past_key_value.key_cache[self.layer_idx]
        cache_v = past_key_value.value_cache[self.layer_idx]
        kv_seq_len = cache_k.shape[-2]
        import linear_q4_0
        # NOTE(review): forward_qkv receives position_ids and the cache tensors, so it
        # presumably applies RoPE and appends to the cache internally - confirm in the kernel.
        query_states, key_states, value_states = linear_q4_0.forward_qkv(hidden_states,
                                                                         self.q_proj.weight,
                                                                         self.k_proj.weight,
                                                                         self.v_proj.weight,
                                                                         position_ids,
                                                                         cache_k, cache_v,
                                                                         self.q_proj.weight.qtype,
                                                                         self.v_proj.weight.qtype,
                                                                         kv_seq_len,
                                                                         self.head_dim)
        kv_seq_len += 1

        # update past_key_value's seen_tokens and kv caches.
        if self.layer_idx == 0:
            past_key_value.seen_tokens = kv_seq_len
        past_key_value.key_cache[self.layer_idx] = key_states
        past_key_value.value_cache[self.layer_idx] = value_states
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len,
                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len,
                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                invalidInputError(False,
                                  "The cache structure has changed since version v4.36. "
                                  f"If you are using {self.__class__.__name__} for "
                                  "auto-regressive decoding with k/v caching, please make sure "
                                  "to initialize the attention class with a layer index.")
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # cos/sin are obtained the same way for both the fused and the plain rotary path
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        if use_fuse_rope:
            # the fused helper takes sin before cos
            query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
                                                                           sin, cos, "gemma")
        else:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                            cos, sin, None)

        if past_key_value is not None:
            # update the number of seen tokens
            if self.layer_idx == 0:
                past_key_value.seen_tokens += key_states.shape[-2]

            # reuse k, v, self_attention
            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
            if len(past_key_value.key_cache) <= self.layer_idx:
                past_key_value.key_cache.append(key_states)
                past_key_value.value_cache.append(value_states)
            else:
                cache_k = past_key_value.key_cache[self.layer_idx]
                cache_v = past_key_value.value_cache[self.layer_idx]

                if not enough_kv_room:
                    # allocate a larger cache and copy the existing entries into it
                    new_c_k, new_c_v = extend_kv_cache(bsz,
                                                       self.num_key_value_heads,  # Support GQA
                                                       self.head_dim,
                                                       cache_k.size(2),
                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
                                                       dtype=cache_k.dtype,
                                                       device=device)
                    new_c_k[:] = cache_k
                    new_c_v[:] = cache_v
                    cache_k = new_c_k
                    cache_v = new_c_v

                key_states, value_states = append_kv_cache(cache_k, cache_v,
                                                           key_states, value_states)

                # update past_key_value
                past_key_value.key_cache[self.layer_idx] = key_states
                past_key_value.value_cache[self.layer_idx] = value_states

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

    if attention_mask is not None:  # no matter the length, we just slice it
        if cache_position is not None:
            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
        else:
            causal_mask = attention_mask
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                         training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads * head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output.to(original_dtype), attn_weights, past_key_value