Optimize rwkv v5 rest token again (#10043)
This commit is contained in: parent b1ff28ceb6 · commit 53a5140eff
2 changed files with 175 additions and 65 deletions
@@ -973,13 +973,19 @@ def _optimize_post(model, lightweight_bmm=False):
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
-       from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
+       from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward_wrapper
+       from bigdl.llm.transformers.models.rwkv5 import rwkv_model_forward_wrapper
        convert_forward(model,
                        module.RwkvSelfAttention,
                        rwkv_attention_forward)
+       rwkv_ffn_forward = rwkv_ffn_forward_wrapper(module.RwkvFeedForward.forward)
        convert_forward(model,
                        module.RwkvFeedForward,
                        rwkv_ffn_forward)
+       rwkv_model_forward = rwkv_model_forward_wrapper(module.Rwkv5Model.forward)
+       convert_forward(model,
+                       module.Rwkv5Model,
+                       rwkv_model_forward)
    elif model.config.model_type == "deci":
        modeling_module_name = model.__class__.__module__
        module = importlib.import_module(modeling_module_name)
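For context, `convert_forward` is a helper defined elsewhere in this repo (not shown in this diff) that swaps the `forward` method of every matching module class. A minimal sketch of that monkey-patching pattern, with the helper name reused purely for illustration (the real helper in bigdl-llm may differ in details):

import types
import torch

def convert_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Rebind `forward` on every submodule that is an instance of `target_cls`.
    # This mirrors the usage above: convert_forward(model, module.Rwkv5Model, rwkv_model_forward).
    for submodule in model.modules():
        if isinstance(submodule, target_cls):
            submodule.forward = types.MethodType(new_forward, submodule)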
@@ -35,20 +35,19 @@
import torch
import torch.nn.functional as F

-from typing import List
+from typing import List, Optional


def extract_key_value(self, hidden, state=None):
    # Mix hidden with the previous timestep to produce key, value, receptance
    if hidden.size(1) == 1 and state is not None:
-        shifted = state[0][:, :, self.layer_id]
+        shifted = state[0][self.layer_id]
    else:
        shifted = self.time_shift(hidden)
        if state is not None:
-            shifted[:, 0] = state[0][:, :, self.layer_id]
+            shifted[:, 0] = state[0][self.layer_id]
    if len(shifted.size()) == 2:
        shifted = shifted.unsqueeze(1)
-    shifted = shifted.contiguous()

    if not hasattr(self, "mixed_mix"):
        self.mixed_mix = torch.cat([
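The layout change is what lets the extra copy go away: slicing the leading (per-layer) dimension yields a contiguous view, while slicing the trailing dimension does not. A small illustrative check with arbitrary toy shapes:

import torch

batch, hidden, layers = 2, 8, 4

old_state = torch.zeros(batch, hidden, layers)   # old layout: layer index last
new_state = torch.zeros(layers, batch, hidden)   # new layout: layer index first

layer_id = 1
print(old_state[:, :, layer_id].is_contiguous())  # False -> needed .contiguous() (a copy) per layer
print(new_state[layer_id].is_contiguous())        # True  -> can be used directly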
@@ -68,7 +67,7 @@ def extract_key_value(self, hidden, state=None):
    gate = F.silu(self.gate(gate))

    if state is not None:
-        state[0][:, :, self.layer_id] = hidden[:, -1]
+        state[0][self.layer_id] = hidden[:, -1]

    return receptance, key, value, gate, state
@@ -97,9 +96,7 @@ def rwkv_linear_attention_xpu(
    time_decay = torch.exp(-torch.exp(time_decay.float()))
    time_first = time_first.float()

-    state = state.contiguous().float()
-
-    # `state` will be modified during this call
+    # `state` will be updated in place during this call
    import linear_q4_0
    out = linear_q4_0.rwkv_linear_attention_v5(
        time_decay,
@@ -118,6 +115,50 @@ def rwkv_linear_attention_xpu(
    out = out.to(dtype=hidden.dtype) * gate
    # out = out @ ow
    out = ow(out)
    return out


+def rwkv_linear_attention_cpu(
+    B,
+    H,
+    S,
+    T,
+    n_head,
+    hidden,
+    time_decay,
+    time_first,
+    receptance,
+    key,
+    value,
+    gate,
+    lxw,
+    lxb,
+    ow,
+    state,
+):
+    key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
+    value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
+    receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
+    time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
+    time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
+    lxw = lxw.float()
+    lxb = lxb.float()
+    out = torch.zeros_like(key).reshape(B, T, H, S)
+    for t in range(T):
+        rt = receptance[:, :, t:t + 1, :]
+        kt = key[:, :, :, t:t + 1]
+        vt = value[:, :, t:t + 1, :]
+        at = kt @ vt
+        out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
+        with torch.no_grad():
+            state = at + time_decay * state
+
+    out = out.reshape(B * T, H * S)
+    out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
+    out = out.to(dtype=hidden.dtype) * gate
+    # out = out @ ow
+    out = ow(out)  # fix this
+
+    return out, state
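The loop above is the RWKV v5 per-head linear-attention recurrence: each timestep contributes an outer product of key and value, the running state is decayed per channel by `time_decay`, and the current token additionally reads its own contribution scaled by `time_first`. A minimal, self-contained sketch of the same recurrence for a single head, with toy shapes and random data:

import torch

head_size, T = 4, 3
torch.manual_seed(0)
r = torch.randn(T, head_size)          # receptance per token
k = torch.randn(T, head_size)          # key per token
v = torch.randn(T, head_size)          # value per token
w = torch.rand(head_size, 1)           # per-channel decay in (0, 1), i.e. exp(-exp(time_decay))
u = torch.randn(head_size, 1)          # "time_first" bonus applied to the current token

state = torch.zeros(head_size, head_size)
outputs = []
for t in range(T):
    a = k[t].unsqueeze(1) @ v[t].unsqueeze(0)             # outer product k_t^T v_t
    outputs.append(r[t].unsqueeze(0) @ (u * a + state))   # read state, with current-token bonus
    state = a + w * state                                 # decay and accumulate
out = torch.cat(outputs)                                  # (T, head_size)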
@@ -133,15 +174,29 @@ def rwkv_attention_forward(
    S = hidden.shape[-1] // H
    T = hidden.shape[1]

-    receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
-    layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
    if hidden.device.type == "xpu":
-        rwkv, layer_state = rwkv_linear_attention_xpu(
-            B,
-            H,
-            S,
-            T,
+        receptance, key, value, gate, state = extract_key_value(self, hidden, state)
+        # `state` will be updated in place when running on GPU
+        rwkv = rwkv_linear_attention_xpu(
+            B, H, S, T,
+            hidden,
+            self.time_decay,
+            self.time_faaaa,
+            receptance,
+            key,
+            value,
+            gate,
+            self.ln_x.weight,
+            self.ln_x.bias,
+            self.output,
+            state=state[1][self.layer_id],
+        )
+    else:
+        receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state)
+        layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
+        rwkv, layer_state = rwkv_linear_attention_cpu(
+            B, H, S, T,
+            self.num_attention_heads,
            hidden,
            self.time_decay,
            self.time_faaaa,
@@ -154,47 +209,27 @@ def rwkv_attention_forward(
            self.output,
            state=layer_state,
        )
-    else:
-        from transformers.models.rwkv.modeling_rwkv import rwkv_linear_attention_cpu
-        rwkv, layer_state = rwkv_linear_attention_cpu(
-            B,
-            H,
-            S,
-            T,
-            self.num_attention_heads,
-            hidden,
-            self.time_decay,
-            self.time_faaaa,
-            receptance,
-            key,
-            value,
-            gate,
-            self.ln_x.weight,
-            self.ln_x.bias,
-            self.output.weight.t(),
-            state=layer_state,
-        )

    if layer_state is not None:
        state[1][:, :, :, :, self.layer_id] = layer_state

    return rwkv, state


+def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
    def rwkv_ffn_forward(
        self,
        hidden: torch.Tensor,
        state: List[torch.Tensor]=None,
    ):
+        if hidden.device.type == "xpu":
            if hidden.size(1) == 1 and state is not None:
-                shifted = state[2][:, :, self.layer_id]
+                shifted = state[2][self.layer_id]
            else:
                shifted = self.time_shift(hidden)
                if state is not None:
-                    shifted[:, 0] = state[2][:, :, self.layer_id]
+                    shifted[:, 0] = state[2][self.layer_id]
            if len(shifted.size()) == 2:
                shifted = shifted.unsqueeze(1)
            shifted = shifted.contiguous()

            if not hasattr(self, "mixed_mix"):
                self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
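For reference, the `time_shift` used on the multi-token path is commonly a zero-padded shift of the hidden states by one position (e.g. `nn.ZeroPad2d((0, 0, 1, -1))` in Hugging Face's RWKV block; the exact module used by this model is not shown in the diff), which is why only position 0 needs to be patched with the cached last hidden state. A rough equivalent:

import torch
import torch.nn as nn

hidden = torch.randn(1, 5, 8)                 # (batch, seq_len, hidden_size)

# Token shift: every position sees the previous position's hidden state,
# and position 0 sees zeros (later overwritten from `state` when a cache exists).
time_shift = nn.ZeroPad2d((0, 0, 1, -1))      # pad one step at the front, drop the last step
shifted = time_shift(hidden)

assert torch.equal(shifted[:, 1:], hidden[:, :-1])
assert torch.equal(shifted[:, 0], torch.zeros_like(hidden[:, 0]))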
@@ -208,6 +243,75 @@ def rwkv_ffn_forward(
            receptance = torch.sigmoid(self.receptance(receptance))

            if state is not None:
-                state[2][:, :, self.layer_id] = hidden[:, -1]
+                state[2][self.layer_id] = hidden[:, -1]

            return receptance * value, state
+        else:
+            return origin_rwkv_ffn_forward(self, hidden, state)
+
+    return rwkv_ffn_forward
+
+
+def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
+    def rwkv_model_forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,  # noqa
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        state: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        # change the `state` layout: put `num_hidden_layers` on the leading dim
+        if input_ids.device.type == "xpu" and use_cache and state is None:
+            state = []
+            batch_size = input_ids.size(0)
+            hidden_size = self.config.hidden_size
+            num_hidden_layers = self.config.num_hidden_layers
+            num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
+            state.append(
+                torch.zeros(
+                    (num_hidden_layers, batch_size, hidden_size),
+                    dtype=self.embeddings.weight.dtype,
+                    requires_grad=False,
+                    device=input_ids.device,
+                ).contiguous()
+            )
+            state.append(
+                torch.zeros(
+                    (
+                        num_hidden_layers,
+                        batch_size,
+                        num_attention_heads,
+                        self.config.hidden_size // num_attention_heads,
+                        self.config.hidden_size // num_attention_heads,
+                    ),
+                    dtype=torch.float32,
+                    requires_grad=False,
+                    device=input_ids.device,
+                ).contiguous()
+            )
+            state.append(
+                torch.zeros(
+                    (num_hidden_layers, batch_size, hidden_size),
+                    dtype=self.embeddings.weight.dtype,
+                    requires_grad=False,
+                    device=input_ids.device,
+                ).contiguous()
+            )
+        return origin_rwkv_model_forward(
+            self=self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            state=state,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    return rwkv_model_forward
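Putting the pieces together, the wrapper above pre-allocates the cache with the layer index leading, so every `state[i][layer_id]` handed to the attention/FFN paths is a dense per-layer block. A small sketch of that layout with toy dimensions standing in for the config values (illustrative only):

import torch

num_hidden_layers, batch_size, hidden_size = 24, 1, 2048
head_size = 64
n_heads = hidden_size // head_size

# The same three buffers the wrapper allocates, with the layer index as the leading dim:
state = [
    torch.zeros(num_hidden_layers, batch_size, hidden_size),                    # attention token-shift cache
    torch.zeros(num_hidden_layers, batch_size, n_heads, head_size, head_size),  # per-head linear-attention state
    torch.zeros(num_hidden_layers, batch_size, hidden_size),                    # FFN token-shift cache
]

layer_id = 3
layer_kv_state = state[1][layer_id]        # dense (batch, n_heads, head_size, head_size) block
assert layer_kv_state.is_contiguous()      # can be handed straight to the native kernel
layer_kv_state += 1                        # in-place updates are visible through state[1]
assert torch.equal(state[1][layer_id], layer_kv_state)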