# ipex-llm/python/llm/src/ipex_llm/transformers/models/common.py
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from typing import List


def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
    if hasattr(linears[0], "weight"):
        # For GPTQ model, it might be qweight
        new_weight = torch.cat(list(linear.weight.data for linear in linears), dim=0)
        if linears[0].bias is not None:
            new_linear = torch.nn.Linear(0, 0, bias=True)
            new_bias = torch.cat(list(linear.bias.data for linear in linears), dim=0)
            new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
        else:
            new_linear = torch.nn.Linear(0, 0, bias=False)
        new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
        new_linear.in_features = new_weight.size(1)
        new_linear.out_features = new_weight.size(0)
        return new_linear
    else:
        return None


def merge_qkv_base(module: torch.nn.Module, attention_class):
    if (
        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
        or not isinstance(attention_class, str) and isinstance(module, attention_class)
    ):
        qkv_proj = merge_linear([
            module.q_proj,
            module.k_proj,
            module.v_proj,
        ])
        if qkv_proj is not None:
            module.qkv_proj = qkv_proj
            del module.q_proj, module.k_proj, module.v_proj
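
# Usage sketch (illustrative, not part of this module): `merge_qkv_base` is
# typically applied to every submodule of a loaded model, with the attention
# class (or its name as a string) bound in advance, e.g.:
#
#     from functools import partial
#     merge_qkv = partial(merge_qkv_base, attention_class="LlamaAttention")
#     model.apply(merge_qkv)
#
# Matching attention modules then expose a single fused `qkv_proj` linear layer
# instead of separate `q_proj` / `k_proj` / `v_proj`.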


def padding_linear_hd(linear: torch.nn.Linear,
                      old_head_dim: int, new_head_dim: int) -> torch.nn.Linear:
    in_features, out_features = linear.in_features, linear.out_features

    weight = linear.weight.data
    weight = weight.view(-1, old_head_dim, in_features)
    new_weight = torch.empty([weight.size(0), new_head_dim, in_features],
                             dtype=weight.dtype, device=weight.device)
    new_weight[:, :old_head_dim, :] = weight
    new_weight[:, old_head_dim:, :] = 0
    new_weight = new_weight.view(-1, in_features)

    if linear.bias is not None:
        bias = linear.bias.data
        bias = bias.view(-1, old_head_dim)
        new_bias = torch.empty([bias.size(0), new_head_dim],
                               dtype=bias.dtype, device=bias.device)
        new_bias[:, :old_head_dim] = bias
        new_bias[:, old_head_dim:] = 0
        new_bias = new_bias.flatten()

        new_linear = torch.nn.Linear(0, 0, bias=True)
        new_linear.bias = torch.nn.Parameter(new_bias, requires_grad=False)
    else:
        new_linear = torch.nn.Linear(0, 0, bias=False)
    new_linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
    new_linear.in_features = new_weight.size(1)
    new_linear.out_features = new_weight.size(0)
    return new_linear


def padding_attention_hd_base(module: torch.nn.Module, attention_class,
                              old_head_dim: int, new_head_dim: int):
    if (
        isinstance(attention_class, str) and module.__class__.__name__ == attention_class
        or not isinstance(attention_class, str) and isinstance(module, attention_class)
    ) and module.head_dim == old_head_dim:
        module.q_proj = padding_linear_hd(module.q_proj, old_head_dim, new_head_dim)
        module.k_proj = padding_linear_hd(module.k_proj, old_head_dim, new_head_dim)
        module.v_proj = padding_linear_hd(module.v_proj, old_head_dim, new_head_dim)
        module.head_dim = new_head_dim
        module.old_head_dim = old_head_dim


def padding_states_hd(states: torch.Tensor, old_head_dim: int, new_head_dim: int):
    bsz, num_heads, seq_len, head_dim = states.size()
    if head_dim == old_head_dim and old_head_dim < new_head_dim:
        new_states = torch.empty([bsz, num_heads, seq_len, new_head_dim],
                                 dtype=states.dtype, device=states.device)
        new_states[:, :, :, :old_head_dim] = states
        new_states[:, :, :, old_head_dim:] = 0
        return new_states
    return states


def padding_qkv_hd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   old_head_dim: int, new_head_dim: int):
    return (
        padding_states_hd(q, old_head_dim, new_head_dim),
        padding_states_hd(k, old_head_dim, new_head_dim),
        padding_states_hd(v, old_head_dim, new_head_dim),
    )
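
# Shape sketch (illustrative numbers): `padding_qkv_hd` zero-pads the last
# dimension of [bsz, num_heads, seq_len, head_dim] query/key/value states,
# e.g. head_dim 96 -> 128, so that they match kernels that only support
# certain head sizes:
#
#     q, k, v = padding_qkv_hd(q, k, v, 96, 128)   # each gains 32 zero lanes
#
# States whose head_dim does not equal `old_head_dim` are returned unchanged.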


def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
    from ipex_llm.transformers.models.utils import mlp_fusion_check
    qtype = getattr(module.gate_proj, "qtype", None)
    if mlp_fusion_check(x, qtype, module.training):
        import xe_linear
        x_2d = x.contiguous().view(-1, x.size(-1))
        output = module.down_proj(
            xe_linear.mlp_forward_xpu(
                x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
                x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
                act, qtype
            )
        )
        return output.view(x.shape)
    else:
        return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))


def mlp_silu_forward(self, x: torch.Tensor):
    from ipex_llm.transformers.models.utils import SILU
    return fuse_mlp_base(self, SILU, x)


def mlp_gelu_forward(self, x: torch.Tensor):
    from ipex_llm.transformers.models.utils import GELU
    return fuse_mlp_base(self, GELU, x)
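
# Note (sketch): `mlp_silu_forward` / `mlp_gelu_forward` are meant to replace
# an MLP module's `forward`, fusing gate_proj + activation + up_proj into one
# `xe_linear.mlp_forward_xpu` call whenever `mlp_fusion_check` allows it
# (XPU input, supported quantization type, not training). Otherwise the plain
# down_proj(act_fn(gate_proj(x)) * up_proj(x)) path is used, so the fused
# wrapper is numerically a drop-in replacement.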


def attention_softmax(attn_weights: torch.Tensor):
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
        import xe_addons
        xe_addons.attn_softmax_inplaced(attn_weights)
    else:
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                                   dtype=torch.float32).to(attn_weights.dtype)
    return attn_weights


def rms_norm_forward(self, hidden_states: torch.Tensor):
    weight = self.weight
    if hasattr(self, "variance_epsilon"):
        eps = self.variance_epsilon
    else:
        eps = self.epsilon

    if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
        import xe_addons
        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
        output = xe_addons.rms_norm(weight, x_2d, eps)
        return output.reshape(hidden_states.shape)
    else:
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + eps)
        return weight * hidden_states.to(input_dtype)


def layer_norm_forward(self, hidden_states: torch.Tensor):
    if hidden_states.device.type == 'xpu' and hidden_states.dtype in [torch.float, torch.half]:
        import xe_addons
        hidden_size = math.prod(self.normalized_shape)
        x_2d = hidden_states.reshape(-1, hidden_size).contiguous()
        output = xe_addons.layer_norm(x_2d, self.weight, self.bias, self.eps)
        return output.reshape(hidden_states.shape)
    else:
        return torch.nn.functional.layer_norm(
            hidden_states, self.normalized_shape,
            self.weight, self.bias, self.eps
        )
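
# Note (illustrative, assumes an existing norm module): `rms_norm_forward` and
# `layer_norm_forward` are written as drop-in `forward` replacements -- `self`
# is expected to expose the attributes used above (`weight` plus
# `variance_epsilon`/`epsilon` for RMSNorm, or `normalized_shape`/`weight`/
# `bias`/`eps` for LayerNorm), so they can be bound onto a module, e.g.:
#
#     import types
#     norm_module.forward = types.MethodType(rms_norm_forward, norm_module)
#
# On non-XPU devices or unsupported dtypes they fall back to the plain PyTorch
# computation, so the binding is also safe on CPU.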


def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
    max_kvs = 128
    padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
    if mask is None:
        if is_causal:
            mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
                              dtype=dtype, device=device)
            mask.triu_(1)
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
        elif seq_length != kv_length and seq_length <= 32:
            mask = None
        else:
            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
            mask[:, :, :, kv_length:padding_kv_length] = torch.finfo(dtype).min
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
    else:
        if seq_length != kv_length and seq_length <= 32:
            mask = mask[..., :seq_length, :kv_length]
            mask = mask.expand([bsz, n_heads, seq_length, kv_length])
        elif mask.size(3) != padding_kv_length:
            new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
                                   dtype=dtype, device=device)
            new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
            new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
            new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
            mask.set_(new_mask)  # modify `mask` in place
        else:
            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
    return mask
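
# Padding sketch (illustrative numbers): `prepare_mask` rounds the kv length up
# to a multiple of `max_kvs` (128) and fills the padded tail with the dtype's
# minimum value, so padded positions receive ~zero attention probability:
#
#     kv_length = 300  ->  padding_kv_length = (300 + 127) // 128 * 128 = 384
#
# The returned mask is expanded (not copied) to
# [bsz, n_heads, seq_length, padding_kv_length]; for short non-causal decode
# steps where no mask is needed, it returns None.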


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
                                 value: torch.Tensor, mask: torch.Tensor = None,
                                 is_causal: bool = False, scale: float = None) -> torch.Tensor:
    bsz, n_heads, seq_length, head_dim = query.shape
    _, n_kv_heads, kv_length, _ = key.shape
    dtype, device = query.dtype, query.device

    if (
        device.type == "xpu"
        and dtype in [torch.float, torch.half]
        and head_dim in [64, 80, 96, 128]
    ):
        # prepare scale
        scale = 1 / math.sqrt(head_dim) if scale is None else scale

        # prepare mask
        mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)

        # compute
        # import xe_addons
        # if is_causal:
        #     if key.dtype == torch.uint8:
        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
        #     else:
        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
        # elif seq_length != kv_length and seq_length <= 32:
        #     # todo: add scale support
        #     if key.dtype == torch.uint8:
        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
        #     else:
        #         attn_output = xe_addons.sdp(query, key, value, mask)
        # else:
        #     if key.dtype == torch.uint8:
        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
        #     else:
        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
        import xe_addons
        if is_causal:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
            else:
                attn_output = xe_addons.sdp_causal(query, key, value, mask)
        elif seq_length != kv_length and seq_length <= 32:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
            else:
                attn_output = xe_addons.sdp(query, key, value, mask)
        else:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
            else:
                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
        return attn_output
    else:
        mask = mask[..., :seq_length, :kv_length] if mask is not None else None

        from ipex_llm.transformers.models.utils import repeat_kv
        if n_heads != n_kv_heads:
            key = repeat_kv(key, n_heads // n_kv_heads)
            value = repeat_kv(value, n_heads // n_kv_heads)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, mask, is_causal=is_causal, scale=scale
        )
        attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
        return attn_output
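
# Usage sketch (illustrative shapes): the fused XPU path expects 4-D
# [bsz, n_heads, seq_len, head_dim] tensors; key/value may come from an
# ipex-llm quantized KV cache (uint8 -> the fp8 kernels above). A call looks
# the same as torch.nn.functional.scaled_dot_product_attention:
#
#     # query: [1, 32, 1, 128], key/value: [1, 8, 512, 128]
#     attn_output = scaled_dot_product_attention(query, key, value,
#                                                mask=None, is_causal=False)
#
# On non-XPU devices or unsupported head_dim it falls back to PyTorch's SDPA,
# with repeat_kv applied first for grouped-query attention.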