add rwkv time shift optimization (#10032)
parent f57d0fda8b
commit 7dfa6dbe46
3 changed files with 148 additions and 10 deletions
@@ -958,9 +958,25 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.rwkv4 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv4 import rwkv_ffn_forward
         convert_forward(model,
                         module.RwkvSelfAttention,
                         rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
+    elif model.config.model_type == "rwkv5":
+        # rwkv v5
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
+        convert_forward(model,
+                        module.RwkvSelfAttention,
+                        rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
     elif model.config.model_type == "deci":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -974,14 +990,6 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.DeciLMAttention,
                         decilm_attention_forward_4_35_2, )
-    elif model.config.model_type == "rwkv5":
-        # rwkv v5
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
-        convert_forward(model,
-                        module.RwkvSelfAttention,
-                        rwkv_attention_forward)
     elif model.config.model_type == "gpt_bigcode":
         # starcoder
         modeling_module_name = model.__class__.__module__
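Note: the dispatch above relies on convert_forward to swap the forward method of matching transformers submodules for the optimized versions. A minimal sketch of that monkey-patching pattern is shown below; the helper name and traversal are illustrative assumptions, not the exact bigdl-llm implementation.

# Illustrative sketch only -- assumes convert_forward rebinds forward on every
# matching submodule; the real helper in bigdl-llm may differ in detail.
import types


def convert_forward_sketch(model, target_cls, new_forward):
    for submodule in model.modules():
        if isinstance(submodule, target_cls):
            # Bind the optimized function as this instance's forward method.
            submodule.forward = types.MethodType(new_forward, submodule)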
@@ -37,6 +37,37 @@ import torch
 from typing import List


+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[1][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[1][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = torch.sigmoid(self.receptance(receptance))
+    if state is not None:
+        state[1][:, :, self.layer_id] = hidden[:, -1]
+    return receptance, key, value, state
+
+
 def rwkv_linear_attention_xpu(
     time_decay: torch.Tensor,
     time_first: torch.Tensor,
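The optimization here replaces the separate per-channel token-shift interpolations (three for attention, two for the FFN) with one fused linear_q4_0.rwkv_time_shift call over the concatenated mix coefficients. That kernel ships only with the XPU build, so the following pure-PyTorch reference is a sketch of what it is assumed to compute, based on the standard RWKV token-shift formula key = hidden * time_mix_key + shifted * (1 - time_mix_key) (and likewise for value/receptance).

# Pure-PyTorch reference for the fused kernel -- an assumption about its
# semantics, not the actual linear_q4_0 implementation.
import torch


def rwkv_time_shift_ref(hidden: torch.Tensor,
                        shifted: torch.Tensor,
                        mixed_mix: torch.Tensor):
    # mixed_mix is the concatenation of N per-channel mix vectors
    # (key/value/receptance for attention, key/receptance for the FFN),
    # each originally shaped (1, 1, hidden_size).
    n = mixed_mix.numel() // hidden.shape[-1]
    outputs = []
    for mix in mixed_mix.chunk(n):
        mix = mix.reshape(1, 1, -1)
        # Standard RWKV token shift: per-channel lerp between the current
        # token and the previous token.
        outputs.append(hidden * mix + shifted * (1 - mix))
    # Returned as a sequence so callers can unpack it like the fused result.
    return outputs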
@@ -84,7 +115,7 @@ def rwkv_attention_forward(
     state: List[torch.Tensor]=None,
     use_cache: bool=False,
 ):
-    receptance, key, value, state = self.extract_key_value(hidden, state=state)
+    receptance, key, value, state = extract_key_value(self, hidden, state=state)
     layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None

     if hidden.device.type == "xpu":
@@ -113,3 +144,35 @@ def rwkv_attention_forward(
         state[4][:, :, self.layer_id] = layer_state[2]

     return self.output(receptance * rwkv), state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
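The patched FFN keeps the upstream RwkvFeedForward math (squared-ReLU key projection, value projection, sigmoid receptance gate) and only fuses the two token-shift lerps. An unfused, eager-mode equivalent of that step, reusing the rwkv_time_shift_ref sketch above and assumed to run inside rwkv_ffn_forward after hidden and shifted are prepared:

# Eager-mode equivalent of the fused FFN mixing step (assumption, see the
# rwkv_time_shift_ref sketch above); fragment, not a standalone function.
key_mixed, receptance_mixed = rwkv_time_shift_ref(
    hidden, shifted,
    torch.cat([self.time_mix_key.data, self.time_mix_receptance.data]))
key = torch.square(torch.relu(self.key(key_mixed)))
value = self.value(key)
receptance = torch.sigmoid(self.receptance(receptance_mixed))
out = receptance * value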
@@ -38,6 +38,41 @@ import torch.nn.functional as F
 from typing import List


+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+            self.time_mix_gate.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance, gate = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = self.receptance(receptance)
+    gate = F.silu(self.gate(gate))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance, key, value, gate, state
+
+
 def rwkv_linear_attention_xpu(
     B: int,
     H: int,
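The RWKV-v5 variant follows the same pattern with a fourth mix channel for the gate. Under the same assumption about the fused kernel's lerp semantics, the eager form of that extra channel would be:

# Hypothetical eager form of the extra gate channel (assumption, not the kernel).
gate = hidden * self.time_mix_gate + shifted * (1 - self.time_mix_gate)
gate = F.silu(self.gate(gate))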
@@ -98,7 +133,7 @@ def rwkv_attention_forward(
     S = hidden.shape[-1] // H
     T = hidden.shape[1]

-    receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
+    receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
     layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None

     if hidden.device.type == "xpu":
@@ -144,3 +179,35 @@ def rwkv_attention_forward(
         state[1][:, :, :, :, self.layer_id] = layer_state

     return rwkv, state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[2][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[2][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[2][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
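A hedged example of how these optimized paths are reached in practice: the optimization applies when bigdl-llm's low-bit loading patches an RWKV model and inference runs on an Intel XPU. Model id and prompt below are placeholders.

# Illustrative usage only -- model id and prompt are placeholders; the fused
# time-shift path requires the XPU build of bigdl-llm (which provides
# linear_q4_0) and an Intel GPU.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (enables the "xpu" device)
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "RWKV/rwkv-4-1b5-pile"  # placeholder RWKV checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer("The RWKV architecture", return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(inputs.input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))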