* temp * add batch * fix style * update package name * fix style * add workflow * use temp version to run uts * trigger performance test * trigger win igpu perf * revert workflow & setup
321 lines
10 KiB
Python
321 lines
10 KiB
Python
#
|
|
# Copyright 2016 The BigDL Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
# Some parts of this file is adapted from
|
|
# https://huggingface.co/RWKV/rwkv-5-world-3b/blob/main/modeling_rwkv5.py
|
|
# which is licensed under Apache License 2.0:
|
|
#
|
|
# Copyright 2023 Bo Peng and HuggingFace Inc. team.
|
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from typing import List, Optional
|
|
from ipex_llm.utils.common.log4Error import invalidInputError
|
|
|
|
|
|
def extract_key_value(self, hidden, state=None):
    """Token-shift ``hidden`` and project it into receptance/key/value/gate (XPU path).

    The previous-timestep input is read from ``state[0][self.layer_id]`` when a
    cache is supplied. Otherwise it comes from ``self.time_shift(hidden)``. The
    current and previous inputs are mixed by the fused ``xe_linear.rwkv_time_shift``
    kernel. Its result is unpacked into four tensors that feed ``self.key``,
    ``self.value``, ``self.receptance`` and ``self.gate``; the gate goes through SiLU.

    Args:
        hidden: layer input; dim 1 is treated as the time axis
            (``hidden.size(1)``, ``hidden[:, -1]``).
        state: optional cache list. Only ``state[0]`` is read and written here,
            indexed by ``self.layer_id`` first.

    Returns:
        ``(receptance, key, value, gate, state)``. ``state`` is the same object that
        was passed in. When it is not None, ``state[0][self.layer_id]`` has been
        overwritten with ``hidden[:, -1]``.
    """
    has_state = state is not None
    if has_state and hidden.size(1) == 1:
        # Single-token step: the previous token comes straight from the cache.
        # This is a view into state[0], so it must be consumed by the kernel
        # before the cache slot is overwritten at the end of this function.
        prev = state[0][self.layer_id]
    else:
        prev = self.time_shift(hidden)
        if has_state:
            # Seed the first position with the last token of the previous call.
            prev[:, 0] = state[0][self.layer_id]
    if prev.dim() == 2:
        prev = prev.unsqueeze(1)

    if not hasattr(self, "mixed_mix"):
        # Concatenate the four mixing coefficients once and cache them on the
        # module so later calls reuse the same tensor.
        mix_params = (
            self.time_mix_key,
            self.time_mix_value,
            self.time_mix_receptance,
            self.time_mix_gate,
        )
        self.mixed_mix = torch.cat([p.data for p in mix_params]).to(dtype=hidden.dtype)

    import xe_linear
    mixed_k, mixed_v, mixed_r, mixed_g = xe_linear.rwkv_time_shift(hidden, prev, self.mixed_mix)

    key = self.key(mixed_k)
    value = self.value(mixed_v)
    receptance = self.receptance(mixed_r)
    gate = F.silu(self.gate(mixed_g))

    if has_state:
        # Remember this call's last token for the next step's time shift.
        state[0][self.layer_id] = hidden[:, -1]

    return receptance, key, value, gate, state
|
|
|
|
|
|
def rwkv_linear_attention_xpu(
    B: int,
    H: int,
    S: int,
    T: int,
    hidden: torch.Tensor,
    time_decay: torch.Tensor,
    time_first: torch.Tensor,
    receptance: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    gate: torch.Tensor,
    lxw: torch.Tensor,
    lxb: torch.Tensor,
    ow: torch.nn.Linear,
    state: torch.Tensor,
):
    """RWKV-5 linear attention via the fused ``xe_linear`` kernel, plus output path.

    Args:
        B, H, S, T: batch size, head count, head size and sequence length used to
            view ``receptance``/``key``/``value`` as ``(B, T, H, S)``.
        hidden: only its dtype is used, to cast the normalized output back.
        time_decay: raw decay parameter; converted to ``exp(-exp(time_decay))``.
        time_first: per-head "first" parameter, passed to the kernel in fp32.
        receptance, key, value: projections produced by ``extract_key_value``.
        gate: multiplicative gate applied after group normalization.
        lxw, lxb: group-norm weight and bias (``num_groups=H``).
        ow: output projection, called on the gated result.
        state: attention state buffer. ``xe_linear.rwkv_linear_attention_v5``
            updates it in place.

    Returns:
        The projected attention output, reshaped to ``(B, T, H * S)`` before ``ow``.
    """
    def to_heads(x):
        # fp32, viewed as (B, T, H, S) then transposed to (B, H, T, S).
        return x.float().view(B, T, H, S).transpose(1, 2)

    k = to_heads(key)
    v = to_heads(value)
    r = to_heads(receptance)

    decay = torch.exp(-torch.exp(time_decay.float()))
    first = time_first.float()

    # NOTE: the kernel mutates `state` in place; nothing is returned for it.
    import xe_linear
    attn = xe_linear.rwkv_linear_attention_v5(decay, first, r, k, v, state)

    normed = F.group_norm(
        attn.reshape(B * T, H * S),
        num_groups=H,
        weight=lxw.float(),
        bias=lxb.float(),
    ).reshape(B, T, H * S)
    gated = normed.to(dtype=hidden.dtype) * gate
    return ow(gated)
|
|
|
|
|
|
def rwkv_linear_attention_cpu(
|
|
B,
|
|
H,
|
|
S,
|
|
T,
|
|
n_head,
|
|
hidden,
|
|
time_decay,
|
|
time_first,
|
|
receptance,
|
|
key,
|
|
value,
|
|
gate,
|
|
lxw,
|
|
lxb,
|
|
ow,
|
|
state,
|
|
):
|
|
key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
|
|
value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
|
|
receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
|
|
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
|
|
time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
|
|
lxw = lxw.float()
|
|
lxb = lxb.float()
|
|
out = torch.zeros_like(key).reshape(B, T, H, S)
|
|
for t in range(T):
|
|
rt = receptance[:, :, t:t + 1, :]
|
|
kt = key[:, :, :, t:t + 1]
|
|
vt = value[:, :, t:t + 1, :]
|
|
at = kt @ vt
|
|
out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
|
|
with torch.no_grad():
|
|
state = at + time_decay * state
|
|
|
|
out = out.reshape(B * T, H * S)
|
|
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
|
|
out = out.to(dtype=hidden.dtype) * gate
|
|
# out = out @ ow
|
|
out = ow(out) # fix this
|
|
|
|
return out, state
|
|
|
|
|
|
def rwkv_attention_forward(
    self,
    hidden: torch.Tensor,
    state: List[torch.Tensor]=None,
    use_cache: bool=False,
    seq_mode: bool=True,
):
    """Forward pass for an RWKV-5 attention (time-mix) module.

    On XPU, the token shift and projections go through ``extract_key_value`` and
    the recurrence through ``rwkv_linear_attention_xpu``. There the cache is
    indexed layer-first (``state[1][self.layer_id]``) and updated in place. On
    other devices, ``self.extract_key_value`` and ``rwkv_linear_attention_cpu`` are
    used. In that case the cache is indexed with the layer in the last dimension
    (``state[1][:, :, :, :, self.layer_id]``) and written back afterwards.

    Args:
        hidden: layer input. ``B = hidden.shape[0]``, ``T = hidden.shape[1]``, and
            ``S = hidden.shape[-1] // H``, where ``H = self.time_decay.shape[0]``.
        state: optional cache list, or None.
        use_cache, seq_mode: accepted for interface compatibility; not used here.

    Returns:
        ``(output, state)``, where ``state`` is the (possibly updated) cache list.
    """
    B = hidden.shape[0]
    H = self.time_decay.shape[0]
    S = hidden.shape[-1] // H
    T = hidden.shape[1]

    if hidden.device.type == "xpu":
        receptance, key, value, gate, state = extract_key_value(self, hidden, state)
        if state is not None:
            # The XPU kernel updates this per-layer slice of the cache in place.
            layer_state = state[1][self.layer_id]
        else:
            # No cache supplied (e.g. use_cache=False). The kernel still needs a
            # writable fp32 (B, H, S, S) buffer to update, so use a throwaway one
            # instead of indexing into None.
            layer_state = torch.zeros(
                (B, H, S, S), dtype=torch.float32, device=hidden.device
            )
        rwkv = rwkv_linear_attention_xpu(
            B, H, S, T,
            hidden,
            self.time_decay,
            self.time_faaaa,
            receptance,
            key,
            value,
            gate,
            self.ln_x.weight,
            self.ln_x.bias,
            self.output,
            state=layer_state,
        )
    else:
        receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state)
        layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
        rwkv, layer_state = rwkv_linear_attention_cpu(
            B, H, S, T,
            self.num_attention_heads,
            hidden,
            self.time_decay,
            self.time_faaaa,
            receptance,
            key,
            value,
            gate,
            self.ln_x.weight,
            self.ln_x.bias,
            self.output,
            state=layer_state,
        )
        if layer_state is not None:
            # The CPU path returns a new state tensor; write it back into the cache.
            state[1][:, :, :, :, self.layer_id] = layer_state

    return rwkv, state
|
|
|
|
|
|
def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
    """Wrap an RWKV feed-forward (channel-mix) ``forward`` with a fused XPU path.

    Args:
        origin_rwkv_ffn_forward: the original forward, called as
            ``origin_rwkv_ffn_forward(self, hidden, state)`` for non-XPU inputs.

    Returns:
        A ``rwkv_ffn_forward(self, hidden, state=None)`` function. On XPU it returns
        ``(sigmoid(receptance(r)) * value(relu(key(k)) ** 2), state)`` and keeps the
        last token in ``state[2][self.layer_id]``.
    """
    def rwkv_ffn_forward(
        self,
        hidden: torch.Tensor,
        state: List[torch.Tensor]=None,
    ):
        # Anything not on XPU goes through the original implementation unchanged.
        if hidden.device.type != "xpu":
            return origin_rwkv_ffn_forward(self, hidden, state)

        cached = state is not None
        if cached and hidden.size(1) == 1:
            # Single-token step: a view into state[2]; it is consumed by the
            # kernel before the cache slot is overwritten below.
            prev = state[2][self.layer_id]
        else:
            prev = self.time_shift(hidden)
            if cached:
                prev[:, 0] = state[2][self.layer_id]
        if prev.dim() == 2:
            prev = prev.unsqueeze(1)

        if not hasattr(self, "mixed_mix"):
            # Cache the concatenated key/receptance mixing coefficients once.
            self.mixed_mix = torch.cat(
                [self.time_mix_key.data, self.time_mix_receptance.data]
            ).to(dtype=hidden.dtype)

        import xe_linear
        mixed_k, mixed_r = xe_linear.rwkv_time_shift(hidden, prev, self.mixed_mix)

        activated = torch.square(torch.relu(self.key(mixed_k)))
        value = self.value(activated)
        receptance = torch.sigmoid(self.receptance(mixed_r))

        if cached:
            state[2][self.layer_id] = hidden[:, -1]

        return receptance * value, state

    return rwkv_ffn_forward
|
|
|
|
|
|
def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
    """Wrap the RWKV model ``forward`` so XPU runs start with a layer-major cache.

    Args:
        origin_rwkv_model_forward: the original model forward, called with keyword
            arguments after the optional cache pre-allocation.

    Returns:
        A ``rwkv_model_forward`` function with the same keyword interface.
    """
    def rwkv_model_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        state: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Pre-allocate a layer-first ``state`` on XPU, then defer to the original.

        When running on XPU with ``use_cache`` and no ``state``, three zero tensors
        are allocated with ``num_hidden_layers`` as the leading dimension. That
        matches the ``state[i][self.layer_id]`` indexing of the XPU attention and
        FFN paths. Otherwise ``state`` is passed through unchanged.

        Raises (via ``invalidInputError``) if the embedding weights are not fp32.
        """
        invalidInputError(self.embeddings.weight.dtype == torch.float,
                          "Only fp32 is supported for now, fp16 and bf16 are not supported")
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # Callers may supply either `input_ids` or `inputs_embeds`. Take the device
        # and batch size from whichever one is present instead of assuming
        # `input_ids`. If neither is given, skip pre-allocation and let the original
        # forward handle the missing input.
        ref_input = input_ids if input_ids is not None else inputs_embeds

        # change `state` layout and put `num_hidden_layers` to the highest dim
        if (ref_input is not None and ref_input.device.type == "xpu"
                and use_cache and state is None):
            batch_size = ref_input.size(0)
            device = ref_input.device
            act_dtype = self.embeddings.weight.dtype
            hidden_size = self.config.hidden_size
            num_hidden_layers = self.config.num_hidden_layers
            # NOTE(review): the local head count is hidden_size divided by
            # `config.num_attention_heads`; confirm this naming against the config.
            num_attention_heads = hidden_size // self.config.num_attention_heads
            head_size = hidden_size // num_attention_heads
            state = [
                # state[0]: previous token for the attention time shift.
                torch.zeros(
                    (num_hidden_layers, batch_size, hidden_size),
                    dtype=act_dtype,
                    requires_grad=False,
                    device=device,
                ),
                # state[1]: per-layer fp32 attention state, one S x S matrix per head.
                torch.zeros(
                    (num_hidden_layers, batch_size, num_attention_heads,
                     head_size, head_size),
                    dtype=torch.float32,
                    requires_grad=False,
                    device=device,
                ),
                # state[2]: previous token for the feed-forward time shift.
                torch.zeros(
                    (num_hidden_layers, batch_size, hidden_size),
                    dtype=act_dtype,
                    requires_grad=False,
                    device=device,
                ),
            ]
        return origin_rwkv_model_forward(
            self=self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            state=state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    return rwkv_model_forward
|