Optimize rwkv v5 rest token again (#10043)

This commit is contained in:
Yishuo Wang 2024-01-31 10:01:11 +08:00 committed by GitHub
parent b1ff28ceb6
commit 53a5140eff
2 changed files with 175 additions and 65 deletions

View file

@ -973,13 +973,19 @@ def _optimize_post(model, lightweight_bmm=False):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward_wrapper
from bigdl.llm.transformers.models.rwkv5 import rwkv_model_forward_wrapper
convert_forward(model,
module.RwkvSelfAttention,
rwkv_attention_forward)
rwkv_ffn_forward = rwkv_ffn_forward_wrapper(module.RwkvFeedForward.forward)
convert_forward(model,
module.RwkvFeedForward,
rwkv_ffn_forward)
rwkv_model_forward = rwkv_model_forward_wrapper(module.Rwkv5Model.forward)
convert_forward(model,
module.Rwkv5Model,
rwkv_model_forward)
elif model.config.model_type == "deci":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)

View file

@ -35,20 +35,19 @@
import torch
import torch.nn.functional as F
from typing import List
from typing import List, Optional
def extract_key_value(self, hidden, state=None):
# Mix hidden with the previous timestep to produce key, value, receptance
if hidden.size(1) == 1 and state is not None:
shifted = state[0][:, :, self.layer_id]
shifted = state[0][self.layer_id]
else:
shifted = self.time_shift(hidden)
if state is not None:
shifted[:, 0] = state[0][:, :, self.layer_id]
shifted[:, 0] = state[0][self.layer_id]
if len(shifted.size()) == 2:
shifted = shifted.unsqueeze(1)
shifted = shifted.contiguous()
if not hasattr(self, "mixed_mix"):
self.mixed_mix = torch.cat([
@ -68,7 +67,7 @@ def extract_key_value(self, hidden, state=None):
gate = F.silu(self.gate(gate))
if state is not None:
state[0][:, :, self.layer_id] = hidden[:, -1]
state[0][self.layer_id] = hidden[:, -1]
return receptance, key, value, gate, state
@ -97,9 +96,7 @@ def rwkv_linear_attention_xpu(
time_decay = torch.exp(-torch.exp(time_decay.float()))
time_first = time_first.float()
state = state.contiguous().float()
# `state` will be modified during this call
# `state` will be updated inplaced during this call
import linear_q4_0
out = linear_q4_0.rwkv_linear_attention_v5(
time_decay,
@ -118,6 +115,50 @@ def rwkv_linear_attention_xpu(
out = out.to(dtype=hidden.dtype) * gate
# out = out @ ow
out = ow(out)
return out
def rwkv_linear_attention_cpu(
B,
H,
S,
T,
n_head,
hidden,
time_decay,
time_first,
receptance,
key,
value,
gate,
lxw,
lxb,
ow,
state,
):
key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
lxw = lxw.float()
lxb = lxb.float()
out = torch.zeros_like(key).reshape(B, T, H, S)
for t in range(T):
rt = receptance[:, :, t:t + 1, :]
kt = key[:, :, :, t:t + 1]
vt = value[:, :, t:t + 1, :]
at = kt @ vt
out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
with torch.no_grad():
state = at + time_decay * state
out = out.reshape(B * T, H * S)
out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
out = out.to(dtype=hidden.dtype) * gate
# out = out @ ow
out = ow(out) # fix this
return out, state
@ -133,15 +174,29 @@ def rwkv_attention_forward(
S = hidden.shape[-1] // H
T = hidden.shape[1]
receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
if hidden.device.type == "xpu":
rwkv, layer_state = rwkv_linear_attention_xpu(
B,
H,
S,
T,
receptance, key, value, gate, state = extract_key_value(self, hidden, state)
# `state`` will be updated inplaced when running on GPU
rwkv = rwkv_linear_attention_xpu(
B, H, S, T,
hidden,
self.time_decay,
self.time_faaaa,
receptance,
key,
value,
gate,
self.ln_x.weight,
self.ln_x.bias,
self.output,
state=state[1][self.layer_id],
)
else:
receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state)
layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
rwkv, layer_state = rwkv_linear_attention_cpu(
B, H, S, T,
self.num_attention_heads,
hidden,
self.time_decay,
self.time_faaaa,
@ -154,47 +209,27 @@ def rwkv_attention_forward(
self.output,
state=layer_state,
)
else:
from transformers.models.rwkv.modeling_rwkv import rwkv_linear_attention_cpu
rwkv, layer_state = rwkv_linear_attention_cpu(
B,
H,
S,
T,
self.num_attention_heads,
hidden,
self.time_decay,
self.time_faaaa,
receptance,
key,
value,
gate,
self.ln_x.weight,
self.ln_x.bias,
self.output.weight.t(),
state=layer_state,
)
if layer_state is not None:
state[1][:, :, :, :, self.layer_id] = layer_state
return rwkv, state
def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
def rwkv_ffn_forward(
self,
hidden: torch.Tensor,
state: List[torch.Tensor]=None,
):
if hidden.device.type == "xpu":
if hidden.size(1) == 1 and state is not None:
shifted = state[2][:, :, self.layer_id]
shifted = state[2][self.layer_id]
else:
shifted = self.time_shift(hidden)
if state is not None:
shifted[:, 0] = state[2][:, :, self.layer_id]
shifted[:, 0] = state[2][self.layer_id]
if len(shifted.size()) == 2:
shifted = shifted.unsqueeze(1)
shifted = shifted.contiguous()
if not hasattr(self, "mixed_mix"):
self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
@ -208,6 +243,75 @@ def rwkv_ffn_forward(
receptance = torch.sigmoid(self.receptance(receptance))
if state is not None:
state[2][:, :, self.layer_id] = hidden[:, -1]
state[2][self.layer_id] = hidden[:, -1]
return receptance * value, state
else:
return origin_rwkv_ffn_forward(self, hidden, state)
return rwkv_ffn_forward
def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
def rwkv_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
state: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
# change `state` layout and put `num_hidden_layers` to the highest dim
if input_ids.device.type == "xpu" and use_cache and state is None:
state = []
batch_size = input_ids.size(0)
hidden_size = self.config.hidden_size
num_hidden_layers = self.config.num_hidden_layers
num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
state.append(
torch.zeros(
(num_hidden_layers, batch_size, hidden_size),
dtype=self.embeddings.weight.dtype,
requires_grad=False,
device=input_ids.device,
).contiguous()
)
state.append(
torch.zeros(
(
num_hidden_layers,
batch_size,
num_attention_heads,
self.config.hidden_size // num_attention_heads,
self.config.hidden_size // num_attention_heads,
),
dtype=torch.float32,
requires_grad=False,
device=input_ids.device,
).contiguous()
)
state.append(
torch.zeros(
(num_hidden_layers, batch_size, hidden_size),
dtype=self.embeddings.weight.dtype,
requires_grad=False,
device=input_ids.device,
).contiguous()
)
return origin_rwkv_model_forward(
self=self,
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
state=state,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return rwkv_model_forward