add rwkv time shift optimization (#10032)

Yishuo Wang 2024-01-30 14:10:55 +08:00 committed by GitHub
parent f57d0fda8b
commit 7dfa6dbe46
3 changed files with 148 additions and 10 deletions
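
In short: the stock HF RWKV forward interpolates key, value, receptance (and gate for v5) against the time-shifted hidden state one tensor at a time. The patched extract_key_value and the new rwkv_ffn_forward concatenate the time_mix_* weights once into a cached mixed_mix buffer and compute every mix in a single fused linear_q4_0.rwkv_time_shift call. The rwkv5 branch of _optimize_post also moves up next to the rwkv4 branch, and both branches now register the new FFN forward via convert_forward.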

bigdl/llm/transformers/convert.py

@@ -958,9 +958,25 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.rwkv4 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv4 import rwkv_ffn_forward
         convert_forward(model,
                         module.RwkvSelfAttention,
                         rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
+    elif model.config.model_type == "rwkv5":
+        # rwkv v5
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
+        convert_forward(model,
+                        module.RwkvSelfAttention,
+                        rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
     elif model.config.model_type == "deci":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -974,14 +990,6 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.DeciLMAttention,
                         decilm_attention_forward_4_35_2, )
-    elif model.config.model_type == "rwkv5":
-        # rwkv v5
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
-        convert_forward(model,
-                        module.RwkvSelfAttention,
-                        rwkv_attention_forward)
     elif model.config.model_type == "gpt_bigcode":
         # starcoder
         modeling_module_name = model.__class__.__module__
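
For context, convert_forward rebinds the forward method on every submodule of the given class; a minimal sketch of that patching pattern (an illustration of the idea, not the actual bigdl.llm helper):

```python
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward):
    # Walk the module tree and rebind `forward` on each instance of
    # target_cls; __get__ turns the plain function into a bound method.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, target_cls)
```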

bigdl/llm/transformers/models/rwkv4.py

@@ -37,6 +37,37 @@ import torch
 
 from typing import List
 
+
+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[1][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[1][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = torch.sigmoid(self.receptance(receptance))
+    if state is not None:
+        state[1][:, :, self.layer_id] = hidden[:, -1]
+    return receptance, key, value, state
+
+
 def rwkv_linear_attention_xpu(
     time_decay: torch.Tensor,
     time_first: torch.Tensor,
@@ -84,7 +115,7 @@ def rwkv_attention_forward(
     state: List[torch.Tensor]=None,
     use_cache: bool=False,
 ):
-    receptance, key, value, state = self.extract_key_value(hidden, state=state)
+    receptance, key, value, state = extract_key_value(self, hidden, state=state)
     layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
 
     if hidden.device.type == "xpu":
@@ -113,3 +144,35 @@ def rwkv_attention_forward(
         state[4][:, :, self.layer_id] = layer_state[2]
 
     return self.output(receptance * rwkv), state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
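
Both new forwards delegate the mixing to linear_q4_0.rwkv_time_shift, a kernel in bigdl-llm's native extension. A rough PyTorch reference for what the fused call is expected to compute, matching the stock HF RWKV interpolation; the reference function name and shape handling here are illustrative assumptions:

```python
import torch

def rwkv_time_shift_reference(hidden, shifted, mixed_mix):
    # hidden, shifted: (batch, seq, n). mixed_mix holds k concatenated
    # time_mix_* vectors of n channels each (k = 2 for the FFN,
    # 3 for v4 attention, 4 for v5 attention).
    n = hidden.size(-1)
    mixes = mixed_mix.reshape(-1, 1, 1, n)
    # Lerp between the current and time-shifted hidden states,
    # producing all k mixed tensors in one pass.
    return list(mixes * hidden + (1 - mixes) * shifted)
```

The hasattr guard means mixed_mix is concatenated once on the first forward pass and reused for every subsequent token.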

bigdl/llm/transformers/models/rwkv5.py

@@ -38,6 +38,41 @@ import torch.nn.functional as F
 
 from typing import List
 
+
+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+            self.time_mix_gate.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance, gate = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = self.receptance(receptance)
+    gate = F.silu(self.gate(gate))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+    return receptance, key, value, gate, state
+
+
 def rwkv_linear_attention_xpu(
     B: int,
     H: int,
@@ -98,7 +133,7 @@ def rwkv_attention_forward(
     S = hidden.shape[-1] // H
     T = hidden.shape[1]
 
-    receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
+    receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
     layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
 
     if hidden.device.type == "xpu":
@@ -144,3 +179,35 @@ def rwkv_attention_forward(
         state[1][:, :, :, :, self.layer_id] = layer_state
 
     return rwkv, state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[2][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[2][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[2][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
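
The patched forwards take effect when a model is loaded through bigdl-llm's transformers wrapper, which runs the optimization pass internally; a typical invocation (the checkpoint path is a placeholder):

```python
import intel_extension_for_pytorch as ipex  # registers the "xpu" device
from bigdl.llm.transformers import AutoModelForCausalLM

# load_in_4bit triggers the optimization pass that swaps in the patched
# RWKV attention/FFN forwards for the rwkv model types
model = AutoModelForCausalLM.from_pretrained(
    "path/to/rwkv-checkpoint",  # placeholder: an HF RWKV v4/v5 checkpoint
    load_in_4bit=True,
    trust_remote_code=True,
)
model = model.to("xpu")  # the fused time-shift kernel targets XPU
```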