add rwkv time shift optimization (#10032)
parent f57d0fda8b
commit 7dfa6dbe46
3 changed files with 148 additions and 10 deletions
@@ -958,9 +958,25 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.rwkv4 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv4 import rwkv_ffn_forward
         convert_forward(model,
                         module.RwkvSelfAttention,
                         rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
+    elif model.config.model_type == "rwkv5":
+        # rwkv v5
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
+        convert_forward(model,
+                        module.RwkvSelfAttention,
+                        rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
     elif model.config.model_type == "deci":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -974,14 +990,6 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.DeciLMAttention,
                         decilm_attention_forward_4_35_2, )
-    elif model.config.model_type == "rwkv5":
-        # rwkv v5
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
-        convert_forward(model,
-                        module.RwkvSelfAttention,
-                        rwkv_attention_forward)
     elif model.config.model_type == "gpt_bigcode":
         # starcoder
         modeling_module_name = model.__class__.__module__
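For orientation, convert_forward is the helper used throughout _optimize_post to swap the optimized forwards into an already loaded model. Its implementation is not part of this diff; the following is only a minimal sketch of the patching pattern it is assumed to follow, with _convert_forward_sketch as a hypothetical stand-in, not the actual bigdl-llm code.

import types
import torch.nn as nn

def _convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Hypothetical stand-in for convert_forward: walk the module tree and
    # rebind forward on every submodule that is an instance of target_cls,
    # so e.g. RwkvFeedForward instances would pick up rwkv_ffn_forward.
    for submodule in model.modules():
        if isinstance(submodule, target_cls):
            submodule.forward = types.MethodType(new_forward, submodule)
    return model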
bigdl/llm/transformers/models/rwkv4.py

@@ -37,6 +37,37 @@ import torch
 from typing import List
 
 
+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[1][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[1][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = torch.sigmoid(self.receptance(receptance))
+    if state is not None:
+        state[1][:, :, self.layer_id] = hidden[:, -1]
+    return receptance, key, value, state
+
+
 def rwkv_linear_attention_xpu(
     time_decay: torch.Tensor,
     time_first: torch.Tensor,
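The new module-level extract_key_value replaces three separate elementwise time-mix blends with a single fused linear_q4_0.rwkv_time_shift call over the concatenated mix vectors, which are cached once on the module as mixed_mix. linear_q4_0 is bigdl-llm's native XPU extension, so the kernel body is not visible in this diff; the unfused reference below is an assumption about what it computes, based on the stock transformers RWKV time-mix formula, with each time_mix_* parameter assumed to have shape (1, 1, hidden_size).

import torch

def rwkv_time_shift_reference(hidden, shifted, mixed_mix, n_outputs=3):
    # Assumed unfused equivalent of linear_q4_0.rwkv_time_shift: split the
    # concatenated (n_outputs, 1, hidden_size) mix vectors back apart and
    # blend the current hidden states with the time-shifted ones.
    mixes = torch.chunk(mixed_mix, n_outputs, dim=0)
    return [hidden * mix + shifted * (1 - mix) for mix in mixes]

As suggested by the commit title, the point of fusing the blends into one call appears to be cutting the number of elementwise kernel launches and intermediate tensors during decoding.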
@@ -84,7 +115,7 @@ def rwkv_attention_forward(
     state: List[torch.Tensor]=None,
     use_cache: bool=False,
 ):
-    receptance, key, value, state = self.extract_key_value(hidden, state=state)
+    receptance, key, value, state = extract_key_value(self, hidden, state=state)
     layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
 
     if hidden.device.type == "xpu":
@@ -113,3 +144,35 @@ def rwkv_attention_forward(
         state[4][:, :, self.layer_id] = layer_state[2]
 
     return self.output(receptance * rwkv), state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
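The FFN path applies the same fusion with only the key and receptance mixes. During single-token decoding (hidden.size(1) == 1 with a populated state), the shifted input is read directly from the cached FFN state instead of recomputing self.time_shift. The small shape sketch below uses made-up sizes and assumes the (batch, hidden_size, num_layers) state layout implied by the state[0][:, :, self.layer_id] indexing above.

import torch

batch, hidden_size, num_layers, layer_id = 1, 2048, 24, 0
hidden = torch.randn(batch, 1, hidden_size)               # one freshly decoded token
ffn_state = torch.zeros(batch, hidden_size, num_layers)   # assumed layout of state[0]
shifted = ffn_state[:, :, layer_id]                        # (batch, hidden_size)
shifted = shifted.unsqueeze(1).contiguous()                # (batch, 1, hidden_size)
assert shifted.shape == hidden.shape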
bigdl/llm/transformers/models/rwkv5.py

@@ -38,6 +38,41 @@ import torch.nn.functional as F
 from typing import List
 
 
+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+            self.time_mix_gate.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance, gate = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = self.receptance(receptance)
+    gate = F.silu(self.gate(gate))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance, key, value, gate, state
+
+
 def rwkv_linear_attention_xpu(
     B: int,
     H: int,
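In the RWKV v5 variant the fused call covers four blends, with the extra gate branch activated by SiLU afterwards. For reference, this is the unfused equivalent that the single kernel call is assumed to replace, with the mix parameters assumed to broadcast over (batch, seq_len, hidden_size) tensors.

def rwkv5_time_mix_reference(hidden, shifted, time_mix_key, time_mix_value,
                             time_mix_receptance, time_mix_gate):
    # Assumed unfused equivalent of the four-way fused time shift above;
    # the gate output is only blended here and is passed through SiLU by
    # the caller, matching the order of operations in extract_key_value.
    key = hidden * time_mix_key + shifted * (1 - time_mix_key)
    value = hidden * time_mix_value + shifted * (1 - time_mix_value)
    receptance = hidden * time_mix_receptance + shifted * (1 - time_mix_receptance)
    gate = hidden * time_mix_gate + shifted * (1 - time_mix_gate)
    return key, value, receptance, gate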
@@ -98,7 +133,7 @@ def rwkv_attention_forward(
     S = hidden.shape[-1] // H
     T = hidden.shape[1]
 
-    receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
+    receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
     layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
 
    if hidden.device.type == "xpu":
@@ -144,3 +179,35 @@ def rwkv_attention_forward(
         state[1][:, :, :, :, self.layer_id] = layer_state
 
     return rwkv, state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[2][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[2][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[2][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
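Finally, a hypothetical end-to-end usage sketch; the model path and prompt are placeholders, and it assumes the standard bigdl-llm low-bit loading flow, in which from_pretrained applies the _optimize_post patching shown above before the model is moved to an XPU device.

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM
import intel_extension_for_pytorch as ipex  # assumed needed to register the xpu device

model_path = "path/to/an/rwkv/checkpoint"  # placeholder
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids.to("xpu")
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))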