diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 41b66bff..c11a6830 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -958,9 +958,25 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.rwkv4 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv4 import rwkv_ffn_forward
         convert_forward(model,
                         module.RwkvSelfAttention,
                         rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
+    elif model.config.model_type == "rwkv5":
+        # rwkv v5
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
+        convert_forward(model,
+                        module.RwkvSelfAttention,
+                        rwkv_attention_forward)
+        convert_forward(model,
+                        module.RwkvFeedForward,
+                        rwkv_ffn_forward)
     elif model.config.model_type == "deci":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -974,14 +990,6 @@
         convert_forward(model,
                         module.DeciLMAttention,
                         decilm_attention_forward_4_35_2, )
-    elif model.config.model_type == "rwkv5":
-        # rwkv v5
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
-        convert_forward(model,
-                        module.RwkvSelfAttention,
-                        rwkv_attention_forward)
     elif model.config.model_type == "gpt_bigcode":
         # starcoder
         modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/bigdl/llm/transformers/models/rwkv4.py b/python/llm/src/bigdl/llm/transformers/models/rwkv4.py
index 84256371..b29af810 100644
--- a/python/llm/src/bigdl/llm/transformers/models/rwkv4.py
+++ b/python/llm/src/bigdl/llm/transformers/models/rwkv4.py
@@ -37,6 +37,37 @@ import torch
 from typing import List
 
 
+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[1][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[1][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = torch.sigmoid(self.receptance(receptance))
+    if state is not None:
+        state[1][:, :, self.layer_id] = hidden[:, -1]
+    return receptance, key, value, state
+
+
 def rwkv_linear_attention_xpu(
     time_decay: torch.Tensor,
     time_first: torch.Tensor,
@@ -84,7 +115,7 @@ def rwkv_attention_forward(
     state: List[torch.Tensor]=None,
     use_cache: bool=False,
 ):
-    receptance, key, value, state = self.extract_key_value(hidden, state=state)
+    receptance, key, value, state = extract_key_value(self, hidden, state=state)
     layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
 
     if hidden.device.type == "xpu":
@@ -113,3 +144,35 @@ def rwkv_attention_forward(
             state[4][:, :, self.layer_id] = layer_state[2]
 
     return self.output(receptance * rwkv), state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
diff --git a/python/llm/src/bigdl/llm/transformers/models/rwkv5.py b/python/llm/src/bigdl/llm/transformers/models/rwkv5.py
index cff32638..e22a4995 100644
--- a/python/llm/src/bigdl/llm/transformers/models/rwkv5.py
+++ b/python/llm/src/bigdl/llm/transformers/models/rwkv5.py
@@ -38,6 +38,41 @@ import torch.nn.functional as F
 from typing import List
 
 
+def extract_key_value(self, hidden, state=None):
+    # Mix hidden with the previous timestep to produce key, value, receptance
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[0][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[0][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([
+            self.time_mix_key.data,
+            self.time_mix_value.data,
+            self.time_mix_receptance.data,
+            self.time_mix_gate.data,
+        ])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, value, receptance, gate = mixed_result
+
+    key = self.key(key)
+    value = self.value(value)
+    receptance = self.receptance(receptance)
+    gate = F.silu(self.gate(gate))
+
+    if state is not None:
+        state[0][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance, key, value, gate, state
+
+
 def rwkv_linear_attention_xpu(
     B: int,
     H: int,
@@ -98,7 +133,7 @@ def rwkv_attention_forward(
     S = hidden.shape[-1] // H
     T = hidden.shape[1]
 
-    receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
+    receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
     layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
 
     if hidden.device.type == "xpu":
@@ -144,3 +179,35 @@ def rwkv_attention_forward(
         state[1][:, :, :, :, self.layer_id] = layer_state
 
     return rwkv, state
+
+
+def rwkv_ffn_forward(
+    self,
+    hidden: torch.Tensor,
+    state: List[torch.Tensor]=None,
+):
+    if hidden.size(1) == 1 and state is not None:
+        shifted = state[2][:, :, self.layer_id]
+    else:
+        shifted = self.time_shift(hidden)
+        if state is not None:
+            shifted[:, 0] = state[2][:, :, self.layer_id]
+    if len(shifted.size()) == 2:
+        shifted = shifted.unsqueeze(1)
+    shifted = shifted.contiguous()
+
+    if not hasattr(self, "mixed_mix"):
+        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+
+    import linear_q4_0
+    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+    key, receptance = mixed_result
+
+    key = torch.square(torch.relu(self.key(key)))
+    value = self.value(key)
+    receptance = torch.sigmoid(self.receptance(receptance))
+
+    if state is not None:
+        state[2][:, :, self.layer_id] = hidden[:, -1]
+
+    return receptance * value, state
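
Note for reviewers: the fused linear_q4_0.rwkv_time_shift call replaces the per-tensor token-shift mixing that the upstream Hugging Face RWKV code performs as hidden * time_mix + shifted * (1 - time_mix) for each time_mix_key/value/receptance(/gate) parameter. The kernel's exact semantics are not visible in this diff, so the pure-PyTorch sketch below is only an assumed reference based on that upstream math; the function name and the (N, 1, hidden_size) layout of mixed_mix are illustrative.

import torch

def rwkv_time_shift_reference(hidden, shifted, mixed_mix):
    # hidden, shifted: (batch, seq_len, hidden_size)
    # mixed_mix: N time_mix_* vectors stacked by torch.cat, assumed (N, 1, hidden_size)
    outputs = []
    for mix in mixed_mix:  # each (1, hidden_size) slice broadcasts over hidden/shifted
        outputs.append(hidden * mix + shifted * (1 - mix))
    return outputs  # unpacked as key, value, receptance(, gate) by the callers above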
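extract_key_value and rwkv_ffn_forward are written as plain functions taking self explicitly because convert_forward rebinds them as the forward of the matching submodules, which is also why the attention forward now calls extract_key_value(self, hidden, state=state) instead of self.extract_key_value(...). The sketch below shows the assumed rebinding pattern; convert_forward_sketch is an illustrative name, not the implementation in convert.py.

import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Walk the module tree and rebind `forward` on every instance of
    # `target_cls`, so the free functions above receive the module as `self`.
    for _, sub_module in model.named_children():
        if isinstance(sub_module, target_cls):
            bound = new_forward.__get__(sub_module, sub_module.__class__)
            setattr(sub_module, "forward", bound)
        convert_forward_sketch(sub_module, target_cls, new_forward)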
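For context, _optimize_post runs when a model is loaded through bigdl-llm with low-bit optimization, so the new RWKV v4/v5 paths should be exercised by a flow like the one below; the checkpoint id is a placeholder, and the xpu branch in the patched forwards is taken once the model is moved to an Intel GPU.

import intel_extension_for_pytorch as ipex  # registers the "xpu" device
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-3b",  # placeholder id
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.to("xpu")  # hidden.device.type == "xpu" selects the fused kernels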