refactor gemma to reduce old fuse rope usage (#12215)

Yishuo Wang 2024-10-16 17:40:28 +08:00 committed by GitHub
parent 9104a168f6
commit a4a758656a
2 changed files with 150 additions and 348 deletions


@@ -1017,6 +1017,10 @@ def _optimize_pre(model, qtype=None):
         model.apply(pre_process_attn_and_mlp)
     if model.config.model_type == "internvl_chat":
         _optimize_pre(model.language_model, qtype=qtype)
+    if model.config.model_type == "gemma":
+        from ipex_llm.transformers.models.gemma import merge_qkv, pre_compute_inv_freq
+        model.apply(merge_qkv)
+        model.apply(pre_compute_inv_freq)
     if model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
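
Note: `_optimize_pre` installs these rewrites through `torch.nn.Module.apply`, which visits every submodule recursively, so `merge_qkv` and `pre_compute_inv_freq` (defined in the gemma module further down) only need an `isinstance` check on the module type they care about. A minimal sketch of that dispatch pattern, with a hypothetical `MyAttention` class standing in for `GemmaAttention`:

```python
import torch
from torch import nn


class MyAttention(nn.Module):  # hypothetical stand-in for GemmaAttention
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16, bias=False)


def tag_attention(module: nn.Module):
    # called once per submodule; act only on the targeted type,
    # mirroring how merge_qkv / pre_compute_inv_freq are written
    if isinstance(module, MyAttention):
        module.tagged = True


model = nn.Sequential(MyAttention(), nn.Linear(16, 16))
model.apply(tag_attention)                 # recursive visit, like model.apply(merge_qkv)
print(getattr(model[0], "tagged", False))  # True
```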
@@ -1846,32 +1850,16 @@ def _optimize_post(model, lightweight_bmm=False):
                             module.MistralMLP,
                             llama_mlp_forward)
     elif model.config.model_type == "gemma":
-        invalidInputError(version.parse(trans_version) >= version.parse("4.38.0"),
-                          "Please upgrade transformers to 4.38.0 or higher version "
-                          "to run Mixtral models.")
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
-        if version.parse(trans_version) >= version.parse("4.39.0"):
-            from ipex_llm.transformers.models.gemma import gemma_attention_forward_4_39
-            convert_forward(model,
-                            module.GemmaAttention,
-                            gemma_attention_forward_4_39
-                            )
-        else:
-            from ipex_llm.transformers.models.gemma import gemma_attention_forward
-            convert_forward(model,
-                            module.GemmaAttention,
-                            gemma_attention_forward,
-                            )
+        from ipex_llm.transformers.models.gemma import gemma_model_forward
+        from ipex_llm.transformers.models.gemma import gemma_attention_forward
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
-        from ipex_llm.transformers.models.gemma import gemma_mlp_forward
-        convert_forward(model,
-                        module.GemmaRMSNorm,
-                        gemma_rms_norm_forward)
-        convert_forward(model,
-                        module.GemmaMLP,
-                        gemma_mlp_forward)
+        from ipex_llm.transformers.models.common import mlp_gelu_forward
+        convert_forward(model, module.GemmaModel, gemma_model_forward)
+        convert_forward(model, module.GemmaAttention, gemma_attention_forward)
+        convert_forward(model, module.GemmaRMSNorm, gemma_rms_norm_forward)
+        convert_forward(model, module.GemmaMLP, mlp_gelu_forward)
     elif model.config.model_type == "gemma2":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
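
Note: `convert_forward` is defined elsewhere in ipex-llm and is not part of this diff; the sketch below only illustrates the kind of method rebinding such a helper presumably performs. It is an assumption about the behaviour, not the actual implementation:

```python
import types
from torch import nn


def convert_forward_sketch(model: nn.Module, target_class, new_forward):
    # Assumed behaviour: rebind `forward` on every submodule of the target class,
    # e.g. convert_forward_sketch(model, module.GemmaMLP, mlp_gelu_forward)
    for submodule in model.modules():
        if isinstance(submodule, target_class):
            submodule.forward = types.MethodType(new_forward, submodule)
```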


@@ -31,50 +31,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
-from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU
-from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
-from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
-import os
-
-KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
-    to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
-                                                           n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def should_use_fuse_rope(self, hidden_states, position_ids):
-    use_fuse_rope = hidden_states.device.type == "xpu"
-    use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
-    use_fuse_rope = use_fuse_rope and position_ids is not None
-    return use_fuse_rope
+from ipex_llm.transformers.kv import DynamicNormalCache
+from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
+from ipex_llm.transformers.models.utils import should_use_fuse_rope
+
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.gemma.modeling_gemma import apply_rotary_pos_emb, repeat_kv
+from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, GemmaAttention
+
+
+def merge_qkv(module: torch.nn.Module):
+    merge_qkv_base(module, GemmaAttention)
+
+
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if isinstance(module, GemmaRotaryEmbedding):
+        module.inv_freq = 1.0 / (
+            module.base **
+            (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim)
+        )
 
 
 def gemma_rms_norm_forward(self, hidden_states):
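
Note: the precomputed `inv_freq` is the standard RoPE inverse-frequency table, `inv_freq[i] = 1 / base**(2*i/dim)`; storing it on each `GemmaRotaryEmbedding` up front lets the fused-rope kernel (`xe_addons.rotary_half_inplaced` in the attention forward below) read it directly. A quick standalone check of the formula, using head_dim=256 and base=10000 as assumed, typical Gemma values:

```python
import torch

dim, base = 256, 10000.0  # assumed Gemma rotary settings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

print(inv_freq.shape)       # torch.Size([128]): one frequency per channel pair
print(inv_freq[0].item())   # 1.0 -> the first pair rotates fastest
print(inv_freq[-1].item())  # ~1.07e-4 -> the last pair rotates slowest
```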
@@ -91,29 +72,110 @@ def gemma_rms_norm_forward(self, hidden_states):
     return (1 + self.weight) * hidden_states.to(input_dtype)
 
 
-def gemma_mlp_forward(
-    self,
-    x: torch.Tensor,
-    residual=None
-) -> torch.Tensor:
-    x_2d = x.view(-1, x.shape[-1])
-    bsz, hidden_size = x_2d.shape
-    qtype = getattr(self.gate_proj, "qtype", None)
-    if mlp_fusion_check(x_2d, qtype, self.training) and not self.down_proj.enable_xetla:
-        import xe_linear
-        if not x_2d.is_contiguous():
-            x_2d = x_2d.contiguous()
-        out = self.down_proj(xe_linear.mlp_forward_xpu(
-            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
-            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
-            GELU, qtype
-        ))
-    else:
-        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-    if residual is not None:
-        return out + residual
-    else:
-        return out
+def gemma_model_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = (
+        output_attentions if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    # IPEX-LLM OPT start: kv cache and quantize kv cache
+    if use_cache and not isinstance(past_key_values, DynamicNormalCache):
+        past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+    # IPEX-LLM OPT end
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
+                      "You cannot specify both input_ids and inputs_embeds at the same time, "
+                      "and must specify either one")
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    # IPEX-LLM changes start: support both transformers 4.38.1 and 4.39
+    try:
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
+        causal_mask = causal_mask[:, :, cache_position, :]
+    except TypeError as _e:
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+    # IPEX-LLM changes end
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # normalized
+    hidden_states = hidden_states * (self.config.hidden_size**0.5)
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                     if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
 
 
 def gemma_attention_forward(
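
Note: `DynamicNormalCache` is ipex-llm's KV-cache class; the model forward above normalises whatever legacy tuple-style `past_key_values` the caller passes into that object, so the per-layer `past_key_value.update(...)` call in the attention forward below always sees one interface. The snippet sketches the equivalent behaviour with the stock `transformers` `DynamicCache`, which exposes the same `from_legacy_cache`/`update` API (transformers >= 4.38 assumed); it is an analogy, not the ipex-llm class itself:

```python
import torch
from transformers.cache_utils import DynamicCache

# legacy format: one (key, value) pair per layer, shape (bsz, n_kv_heads, seq, head_dim)
legacy = ((torch.zeros(1, 1, 4, 8), torch.zeros(1, 1, 4, 8)),)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap once at the top of the model forward
new_k = torch.zeros(1, 1, 1, 8)
new_v = torch.zeros(1, 1, 1, 8)
k, v = cache.update(new_k, new_v, layer_idx=0)  # per-layer update inside attention
print(k.shape)  # torch.Size([1, 1, 5, 8]): 4 cached tokens + 1 new token
```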
@@ -126,111 +188,27 @@ def gemma_attention_forward(
     use_cache: bool=False,
     cache_position: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, hidden_size = hidden_states.size()
-    device = hidden_states.device
-    # for flash attention
-    original_dtype = hidden_states.dtype
-
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
-                                                use_fuse_rope,
-                                                enough_kv_room,
-                                                bsz * q_len)
-    if decoding_fast_path:
-        hidden_states = hidden_states.view(1, -1)
-        cache_k = past_key_value.key_cache[self.layer_idx]
-        cache_v = past_key_value.value_cache[self.layer_idx]
-        kv_seq_len = cache_k.shape[-2]
-        import xe_linear
-        query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
-                                                                       self.q_proj.weight,
-                                                                       self.k_proj.weight,
-                                                                       self.v_proj.weight,
-                                                                       position_ids,
-                                                                       cache_k, cache_v,
-                                                                       self.q_proj.weight.qtype,
-                                                                       self.v_proj.weight.qtype,
-                                                                       kv_seq_len,
-                                                                       self.head_dim)
-        kv_seq_len += 1
-        # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
-            past_key_value.seen_tokens = kv_seq_len
-        past_key_value.key_cache[self.layer_idx] = key_states
-        past_key_value.value_cache[self.layer_idx] = value_states
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len,
-                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len,
-                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                invalidInputError(False,
-                                  "The cache structure has changed since version v4.36. "
-                                  f"If you are using {self.__class__.__name__} for "
-                                  "auto-regressive decodingwith k/v caching, please make sure "
-                                  "to initialize the attention class with a layer index.")
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        if use_fuse_rope:
-            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-            query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
-                                                                           sin, cos, "gemma")
-        else:
-            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                            cos, sin, None)
-
-        if past_key_value is not None:
-            # update the number of seen tokens
-            if self.layer_idx == 0:
-                past_key_value.seen_tokens += key_states.shape[-2]
-
-            # reuse k, v, self_attention
-            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-            if len(past_key_value.key_cache) <= self.layer_idx:
-                past_key_value.key_cache.append(key_states)
-                past_key_value.value_cache.append(value_states)
-            else:
-                cache_k = past_key_value.key_cache[self.layer_idx]
-                cache_v = past_key_value.value_cache[self.layer_idx]
-
-                if not enough_kv_room:
-                    # allocate new
-                    new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                       self.num_key_value_heads,  # Support GQA
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=device)
-                    new_c_k[:] = cache_k
-                    new_c_v[:] = cache_v
-                    cache_k = new_c_k
-                    cache_v = new_c_v
-
-                key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                           key_states, value_states)
-
-                # update past_key_value
-                past_key_value.key_cache[self.layer_idx] = key_states
-                past_key_value.value_cache[self.layer_idx] = value_states
+    bsz, q_len, _ = hidden_states.size()
+
+    qkv = self.qkv_proj(hidden_states)
+    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+    qkv = qkv.transpose(1, 2)
+    query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                        self.num_key_value_heads,
+                                                        self.num_key_value_heads], dim=1)
+
+    if should_use_fuse_rope(hidden_states, position_ids, self.training):
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
+    else:
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, None)
+
+    if past_key_value is not None:
+        key_states, value_states = past_key_value.update(key_states, value_states,
+                                                         self.layer_idx, None)
 
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
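
Note: the fused `qkv_proj` created by `merge_qkv` emits `num_heads + 2 * num_key_value_heads` head slices from a single matmul, and the `split` recovers q/k/v. A shape walk-through with Gemma-2B-like sizes (8 query heads, 1 KV head, head_dim 256, assumed here purely for illustration):

```python
import torch

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 1, 256

# stand-in for self.qkv_proj(hidden_states): (bsz, q_len, (num_heads + 2*num_kv_heads) * head_dim)
qkv = torch.randn(bsz, q_len, (num_heads + 2 * num_kv_heads) * head_dim)

qkv = qkv.view(bsz, q_len, num_heads + 2 * num_kv_heads, head_dim).transpose(1, 2)
q, k, v = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 8, 5, 256]) torch.Size([2, 1, 5, 256]) torch.Size([2, 1, 5, 256])
```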
@@ -238,26 +216,15 @@ def gemma_attention_forward(
     attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
     if attention_mask is not None:  # no matter the length, we just slice it
-        if cache_position is not None:
-            causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
-        else:
-            causal_mask = attention_mask
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
     # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
+    attn_weights = attention_softmax(attn_weights, self.training)
     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                          training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
-
-    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-        invalidInputError(
-            False,
-            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-            f" {attn_output.size()}"
-        )
-
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.view(bsz, q_len, -1)
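
Note: `attention_softmax` folds the inline fp32-upcast softmax into a shared helper. A minimal sketch of what such a helper presumably does; the real ipex-llm helper may additionally dispatch to a fused in-place XPU kernel during inference, which the `training` flag would gate (that part is an assumption):

```python
import torch


def attention_softmax_sketch(attn_weights: torch.Tensor, training: bool) -> torch.Tensor:
    # `training` presumably selects between a fused inference kernel and this
    # reference path in the real helper; only the reference path is shown here.
    return torch.nn.functional.softmax(attn_weights, dim=-1,
                                       dtype=torch.float32).to(attn_weights.dtype)
```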
@@ -266,157 +233,4 @@ def gemma_attention_forward(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output.to(original_dtype), attn_weights, past_key_value
+    return attn_output, attn_weights, past_key_value
-
-
-def gemma_attention_forward_4_39(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor]=None,
-    position_ids: Optional[torch.LongTensor]=None,
-    past_key_value: Optional[Tuple[torch.Tensor]]=None,
-    output_attentions: bool=False,
-    use_cache: bool=False,
-    cache_position: Optional[torch.Tensor]=None,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, hidden_size = hidden_states.size()
-    device = hidden_states.device
-    # for flash attention
-    original_dtype = hidden_states.dtype
-
-    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
-                                                use_fuse_rope,
-                                                enough_kv_room,
-                                                bsz * q_len)
-    if decoding_fast_path:
-        hidden_states = hidden_states.view(1, -1)
-        cache_k = past_key_value.key_cache[self.layer_idx]
-        cache_v = past_key_value.value_cache[self.layer_idx]
-        kv_seq_len = cache_k.shape[-2]
-        import xe_linear
-        query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
-                                                                       self.q_proj.weight,
-                                                                       self.k_proj.weight,
-                                                                       self.v_proj.weight,
-                                                                       position_ids,
-                                                                       cache_k, cache_v,
-                                                                       self.q_proj.weight.qtype,
-                                                                       self.v_proj.weight.qtype,
-                                                                       kv_seq_len,
-                                                                       self.head_dim)
-        kv_seq_len += 1
-        # update past_key_value's seem_tokens and kv caches.
-        if self.layer_idx == 0:
-            past_key_value._seen_tokens = kv_seq_len
-        past_key_value.key_cache[self.layer_idx] = key_states
-        past_key_value.value_cache[self.layer_idx] = value_states
-    else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len,
-                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len,
-                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                invalidInputError(False,
-                                  "The cache structure has changed since version v4.36. "
-                                  f"If you are using {self.__class__.__name__} for "
-                                  "auto-regressive decodingwith k/v caching, please make sure "
-                                  "to initialize the attention class with a layer index.")
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-        if use_fuse_rope:
-            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-            query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
-                                                                           sin, cos, "gemma")
-        else:
-            cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                            cos, sin, None)
-
-        if past_key_value is not None:
-            # update the number of seen tokens
-            if self.layer_idx == 0:
-                past_key_value._seen_tokens += key_states.shape[-2]
-
-            # reuse k, v, self_attention
-            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
-            if len(past_key_value.key_cache) <= self.layer_idx:
-                past_key_value.key_cache.append(key_states)
-                past_key_value.value_cache.append(value_states)
-            else:
-                cache_k = past_key_value.key_cache[self.layer_idx]
-                cache_v = past_key_value.value_cache[self.layer_idx]
-
-                if not enough_kv_room:
-                    # allocate new
-                    new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                       self.num_key_value_heads,  # Support GQA
-                                                       self.head_dim,
-                                                       cache_k.size(2),
-                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                       dtype=cache_k.dtype,
-                                                       device=device)
-                    new_c_k[:] = cache_k
-                    new_c_v[:] = cache_v
-                    cache_k = new_c_k
-                    cache_v = new_c_v
-
-                key_states, value_states = append_kv_cache(cache_k, cache_v,
-                                                           key_states, value_states)
-
-                # update past_key_value
-                past_key_value.key_cache[self.layer_idx] = key_states
-                past_key_value.value_cache[self.layer_idx] = value_states
-
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-    if attention_mask is not None:  # no matter the length, we just slice it
-        if cache_position is not None:
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        else:
-            causal_mask = attention_mask
-        attn_weights = attn_weights + causal_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
-                                         training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-
-    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-        invalidInputError(
-            False,
-            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-            f" {attn_output.size()}"
-        )
-
-    attn_output = attn_output.transpose(1, 2).contiguous()
-    attn_output = attn_output.view(bsz, q_len, -1)
-
-    attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output.to(original_dtype), attn_weights, past_key_value
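
Note: an end-to-end usage sketch for context. The rewrites above are applied automatically when a Gemma checkpoint is loaded through ipex-llm's `AutoModelForCausalLM` wrapper; the model id, dtype, and generation settings below are illustrative rather than taken from this commit:

```python
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_id = "google/gemma-2b-it"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id,
                                             load_in_4bit=True,  # low-bit load runs _optimize_pre/_optimize_post
                                             trust_remote_code=True)
model = model.half().to("xpu")   # the fused rope / softmax paths target XPU devices
tokenizer = AutoTokenizer.from_pretrained(model_id)

with torch.inference_mode():
    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```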