ipex-llm/python/llm/src/ipex_llm/transformers/models/qwen2.py

331 lines
14 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/qwen2/modeling_qwen2.py
# which is licensed under Apache License 2.0:
#
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from typing import Optional, Tuple, Union, List
import torch
from torch.nn import CrossEntropyLoss
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.models.common import scaled_dot_product_attention
from ipex_llm.transformers.models.utils import SILU, mlp_fusion_check
from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, \
should_use_compresskv, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \
DynamicCompressCache, DynamicCompressFp8Cache
from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, Qwen2Attention, Qwen2MLP
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.cache_utils import Cache
def qwen2_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Choose the ipex-llm KV-cache container, then defer to ``Qwen2Model.forward``.

    When caching is enabled, ``past_key_values`` is converted (through the
    target class's ``from_legacy_cache``) into one of ``DynamicCompressFp8Cache``,
    ``DynamicCompressCache``, ``DynamicFp8Cache`` or ``DynamicNormalCache``,
    depending on whether compressed and/or quantized KV caching is selected.
    Every other argument is forwarded unchanged.

    Returns:
        Whatever ``Qwen2Model.forward`` returns for the forwarded arguments.
    """
    # IPEX-LLM OPT start: kv cache and quantize kv cache
    inputs = inputs_embeds if input_ids is None else input_ids

    if use_cache is None:
        use_cache = self.config.use_cache
    if inputs.device.type == "xpu":
        # Caching is always switched on for XPU inputs.
        use_cache = True

    # Models with hidden_size == 3584 never use the quantized KV cache; the
    # helper is not even consulted for them.
    # NOTE(review): 3584 presumably identifies one specific Qwen2 size -- confirm.
    use_quantize_kv = self.config.hidden_size != 3584 and use_quantize_kv_cache(
        self.layers[0].mlp.down_proj, inputs,
        self.config.num_attention_heads, self.config.num_key_value_heads
    )
    # A cache that is already a compress cache keeps compression enabled.
    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
        isinstance(past_key_values, DynamicCompressCache)

    if use_cache:
        # The three cases are mutually exclusive; a cache that already has the
        # expected type is passed through untouched.
        if use_compress_kv:
            if not isinstance(past_key_values, DynamicCompressCache):
                if use_quantize_kv:
                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
                        past_key_values)
                else:
                    past_key_values = DynamicCompressCache.from_legacy_cache(
                        past_key_values)
        elif use_quantize_kv:
            if not isinstance(past_key_values, DynamicFp8Cache):
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
        else:
            if not isinstance(past_key_values, DynamicNormalCache):
                past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    # ipex-llm changes end

    # `cache_position` is required after transformers 4.42; only forward it
    # when the caller actually supplied one.
    kwargs = {"cache_position": cache_position} if cache_position is not None else {}

    return Qwen2Model.forward(
        self=self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        **kwargs
    )
def qwen2_causal_lm_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,  # for transformers >= 4.42
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the decoder and the LM head, optionally computing the next-token loss.

    ``output_attentions``, ``output_hidden_states`` and ``return_dict`` fall
    back to the corresponding ``self.config`` values when left as ``None``.

    Returns:
        A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy. Otherwise
        a tuple ``(logits, *decoder_outputs[1:])``, prefixed with the loss when
        ``labels`` was given.
    """
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    # ipex-llm changes start: remove `logits.float()` to reduce memory usage with long input
    # (the logits keep whatever dtype the LM head produces)
    logits = self.lm_head(outputs[0])
    # ipex-llm changes end

    loss = None
    if labels is not None:
        # Position t predicts token t + 1: drop the last logit and the first
        # label, then flatten both for the token-level cross entropy.
        pred_logits = logits[..., :-1, :].contiguous().view(-1, self.config.vocab_size)
        target_ids = labels[..., 1:].contiguous().view(-1)
        # Enable model parallelism: labels must sit on the logits' device.
        target_ids = target_ids.to(pred_logits.device)
        loss = CrossEntropyLoss()(pred_logits, target_ids)

    if return_dict:
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    output = (logits,) + outputs[1:]
    if loss is None:
        return output
    return (loss,) + output
def merge_qkv(module: torch.nn.Module):
    """Apply ``merge_qkv_base`` to ``module`` and optionally drop rotary caches.

    After the merge step, if ``module`` is a ``Qwen2Attention`` and the
    environment variable ``IPEX_LLM_LOW_MEM`` equals ``"1"``, the
    ``cos_cached`` and ``sin_cached`` attributes of its rotary embedding are
    deleted.
    """
    # NOTE(review): presumably fuses the q/k/v projections into the `qkv_proj`
    # that the attention forward in this file looks for -- confirm in
    # merge_qkv_base.
    merge_qkv_base(module, Qwen2Attention)

    if not isinstance(module, Qwen2Attention):
        return
    if os.environ.get("IPEX_LLM_LOW_MEM") != "1":
        return

    rotary = module.rotary_emb
    del rotary.cos_cached
    del rotary.sin_cached
def _pad_linear_weight(linear: torch.nn.Module, dim: int, padded_size: int,
                       size_attr: str) -> None:
    """Zero-pad ``linear.weight`` along ``dim`` up to ``padded_size``, in place.

    The weight is replaced by a new non-trainable ``Parameter``. If the layer
    itself exposes ``size_attr`` (``in_features`` / ``out_features``), that
    attribute is updated to ``padded_size`` so it stays in sync with the weight.
    """
    weight = linear.weight.data
    padded_shape = list(weight.shape)
    original_size = padded_shape[dim]
    padded_shape[dim] = padded_size

    padded = torch.zeros(padded_shape, dtype=weight.dtype, device=weight.device)
    # Copy the original values into the leading slice; the tail stays zero.
    padded.narrow(dim, 0, original_size).copy_(weight)

    # Check the layer being modified, not a sibling layer.
    if hasattr(linear, size_attr):
        setattr(linear, size_attr, padded_size)
    linear.weight = torch.nn.Parameter(padded, requires_grad=False)


def padding_mlp(module: torch.nn.Module):
    """Pad a ``Qwen2MLP``'s intermediate dimension up to a multiple of 256.

    Modules that are not ``Qwen2MLP``, or whose intermediate size is already a
    multiple of 256, are left untouched. Otherwise the ``gate_proj`` and
    ``up_proj`` weights gain zero rows and the ``down_proj`` weight gains zero
    columns.
    """
    # for qwen 1.5 14B
    if not isinstance(module, Qwen2MLP):
        return

    intermediate_size = module.gate_proj.weight.shape[0]
    if intermediate_size % 256 == 0:
        return
    padding_intermediate_size = (intermediate_size + 256 - 1) // 256 * 256

    # gate/up weights are [intermediate, hidden]: pad the output rows.
    _pad_linear_weight(module.gate_proj, 0, padding_intermediate_size, "out_features")
    _pad_linear_weight(module.up_proj, 0, padding_intermediate_size, "out_features")
    # down weight is [hidden, intermediate]: pad the input columns.
    _pad_linear_weight(module.down_proj, 1, padding_intermediate_size, "in_features")
def qwen2_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Attention forward for Qwen2 using ipex-llm's fused kernels and KV caches.

    Projects ``hidden_states`` to per-head query/key/value tensors (via the
    merged ``qkv_proj`` when present, otherwise the separate projections),
    applies rotary embeddings, updates ``past_key_value`` when one is given,
    and runs ``scaled_dot_product_attention`` followed by ``o_proj``.

    Returns:
        ``(attn_output, None, past_key_value)``. Attention weights are never
        produced, regardless of ``output_attentions``.
    """
    bsz, q_len, _ = hidden_states.size()
    device = hidden_states.device
    n_heads = self.num_heads
    n_kv_heads = self.num_key_value_heads
    head_dim = self.head_dim

    # [CompressKV]
    from ipex_llm.transformers.kv import DynamicCompressCache
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)

    merged_proj = getattr(self, 'qkv_proj', None)
    if merged_proj is not None:
        # Single projection, then split along the head axis into q / k / v.
        merged = merged_proj(hidden_states)
        merged = merged.view(bsz, q_len, n_heads + 2 * n_kv_heads, head_dim)
        merged = merged.transpose(1, 2)
        queries, keys, values = merged.split([n_heads, n_kv_heads, n_kv_heads], dim=1)
    else:
        # when quant_method is 'gptq'
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)
        queries = queries.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        keys = keys.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
        values = values.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    # Total key/value length = new tokens + whatever the cache reports usable.
    kv_seq_len = keys.shape[-2]
    if past_key_value is not None:
        kv_seq_len = kv_seq_len + past_key_value.get_usable_length(kv_seq_len,
                                                                   self.layer_idx)

    if should_use_fuse_rope(hidden_states, position_ids, self.training):
        import xe_addons
        # Results are not rebound, so the kernel presumably rotates the query
        # and key tensors in place -- confirm against xe_addons.
        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
                                       queries, keys)
    else:
        cos, sin = self.rotary_emb(values, seq_len=kv_seq_len)
        cos = cos.to(device)
        sin = sin.to(device)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)

    if past_key_value is not None:
        if use_compresskv:
            # [CompressKV] the compress cache takes extra context for its update.
            # NOTE(review): the meaning of the trailing 256 is not visible
            # here -- see DynamicCompressCache.update.
            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
                                                          q_len)
            keys, values = past_key_value.update(
                keys, values, self.layer_idx,
                queries, attention_mask, self.num_key_value_groups,
                self.config, enough_kv_room, 256)
        else:
            keys, values = past_key_value.update(keys, values, self.layer_idx, None)

    # The last argument is True only when no earlier tokens are cached
    # (query length equals total key/value length).
    context = scaled_dot_product_attention(
        queries, keys, values,
        attention_mask, q_len == kv_seq_len
    )
    context = context.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(context)

    # Attention weights are not materialised on this path.
    return attn_output, None, past_key_value
def qwen2_mlp_forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    """Gated-MLP forward with three implementations chosen at run time.

    1. If ``mlp_fusion_check`` accepts the flattened input, the gate/up/SiLU
       computation is done by ``xe_linear.mlp_forward_xpu``.
    2. Otherwise, on an XPU device outside training, the gate and up
       projections are combined by ``xe_addons.mlp_silu_mul_inplaced``.
    3. Otherwise the plain ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.
    """
    flat = x.view(-1, x.shape[-1])
    qtype = getattr(self.gate_proj, "qtype", None)

    if mlp_fusion_check(flat, qtype, self.training):
        import xe_linear
        n_rows, n_cols = flat.shape[0], flat.shape[1]
        fused = xe_linear.mlp_forward_xpu(
            flat, self.gate_proj.weight.data, self.up_proj.weight.data,
            n_rows, n_cols, self.gate_proj.out_len,
            SILU, qtype
        )
        # The fused kernel was fed the flattened view of x.
        return self.down_proj(fused)

    if x.device.type == "xpu" and not self.training:
        import xe_addons
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        # `gate` is what gets fed to down_proj, so the kernel presumably
        # writes silu(gate) * up into `gate` -- confirm against xe_addons.
        xe_addons.mlp_silu_mul_inplaced(gate, up)
        return self.down_proj(gate)

    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))