diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 36950230..b7614dd5 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -973,13 +973,19 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
-        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward_wrapper
+        from bigdl.llm.transformers.models.rwkv5 import rwkv_model_forward_wrapper
         convert_forward(model,
                         module.RwkvSelfAttention,
                         rwkv_attention_forward)
+        rwkv_ffn_forward = rwkv_ffn_forward_wrapper(module.RwkvFeedForward.forward)
         convert_forward(model,
                         module.RwkvFeedForward,
                         rwkv_ffn_forward)
+        rwkv_model_forward = rwkv_model_forward_wrapper(module.Rwkv5Model.forward)
+        convert_forward(model,
+                        module.Rwkv5Model,
+                        rwkv_model_forward)
     elif model.config.model_type == "deci":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/rwkv5.py b/python/llm/src/bigdl/llm/transformers/models/rwkv5.py
index e22a4995..7d4d6d03 100644
--- a/python/llm/src/bigdl/llm/transformers/models/rwkv5.py
+++ b/python/llm/src/bigdl/llm/transformers/models/rwkv5.py
@@ -35,20 +35,19 @@
 import torch
 import torch.nn.functional as F
 
-from typing import List
+from typing import List, Optional
 
 
 def extract_key_value(self, hidden, state=None):
     # Mix hidden with the previous timestep to produce key, value, receptance
     if hidden.size(1) == 1 and state is not None:
-        shifted = state[0][:, :, self.layer_id]
+        shifted = state[0][self.layer_id]
     else:
         shifted = self.time_shift(hidden)
         if state is not None:
-            shifted[:, 0] = state[0][:, :, self.layer_id]
+            shifted[:, 0] = state[0][self.layer_id]
     if len(shifted.size()) == 2:
         shifted = shifted.unsqueeze(1)
-    shifted = shifted.contiguous()
 
     if not hasattr(self, "mixed_mix"):
         self.mixed_mix = torch.cat([
@@ -68,7 +67,7 @@ def extract_key_value(self, hidden, state=None):
     gate = F.silu(self.gate(gate))
 
     if state is not None:
-        state[0][:, :, self.layer_id] = hidden[:, -1]
+        state[0][self.layer_id] = hidden[:, -1]
 
     return receptance, key, value, gate, state
 
@@ -97,9 +96,7 @@ def rwkv_linear_attention_xpu(
     time_decay = torch.exp(-torch.exp(time_decay.float()))
     time_first = time_first.float()
 
-    state = state.contiguous().float()
-
-    # `state` will be modified during this call
+    # `state` will be updated in place during this call
     import linear_q4_0
     out = linear_q4_0.rwkv_linear_attention_v5(
         time_decay,
@@ -118,6 +115,50 @@ def rwkv_linear_attention_xpu(
     out = out.to(dtype=hidden.dtype) * gate
     # out = out @ ow
     out = ow(out)
+    return out
+
+
+def rwkv_linear_attention_cpu(
+    B,
+    H,
+    S,
+    T,
+    n_head,
+    hidden,
+    time_decay,
+    time_first,
+    receptance,
+    key,
+    value,
+    gate,
+    lxw,
+    lxb,
+    ow,
+    state,
+):
+    key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
+    value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
+    receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
+    time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
+    time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
+    lxw = lxw.float()
+    lxb = lxb.float()
+    out = torch.zeros_like(key).reshape(B, T, H, S)
+    for t in range(T):
+        rt = receptance[:, :, t:t + 1, :]
+        kt = key[:, :, :, t:t + 1]
+        vt = value[:, :, t:t + 1, :]
+        at = kt @ vt
+        out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
+        with torch.no_grad():
+            state = at + time_decay * state
+
+    out = out.reshape(B * T, H * S)
+    out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
+    out = out.to(dtype=hidden.dtype) * gate
+    # out = out @ ow
+    out = ow(out)
+
     return out, state
 
 
@@ -133,15 +174,29 @@ def rwkv_attention_forward(
     S = hidden.shape[-1] // H
     T = hidden.shape[1]
 
-    receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
-    layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
-
     if hidden.device.type == "xpu":
-        rwkv, layer_state = rwkv_linear_attention_xpu(
-            B,
-            H,
-            S,
-            T,
+        receptance, key, value, gate, state = extract_key_value(self, hidden, state)
+        # `state` will be updated in place when running on GPU
+        rwkv = rwkv_linear_attention_xpu(
+            B, H, S, T,
+            hidden,
+            self.time_decay,
+            self.time_faaaa,
+            receptance,
+            key,
+            value,
+            gate,
+            self.ln_x.weight,
+            self.ln_x.bias,
+            self.output,
+            state=state[1][self.layer_id],
+        )
+    else:
+        receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state)
+        layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
+        rwkv, layer_state = rwkv_linear_attention_cpu(
+            B, H, S, T,
+            self.num_attention_heads,
             hidden,
             self.time_decay,
             self.time_faaaa,
@@ -154,60 +209,109 @@ def rwkv_attention_forward(
             self.output,
             state=layer_state,
         )
-    else:
-        from transformers.models.rwkv.modeling_rwkv import rwkv_linear_attention_cpu
-        rwkv, layer_state = rwkv_linear_attention_cpu(
-            B,
-            H,
-            S,
-            T,
-            self.num_attention_heads,
-            hidden,
-            self.time_decay,
-            self.time_faaaa,
-            receptance,
-            key,
-            value,
-            gate,
-            self.ln_x.weight,
-            self.ln_x.bias,
-            self.output.weight.t(),
-            state=layer_state,
-        )
-
-    if layer_state is not None:
-        state[1][:, :, :, :, self.layer_id] = layer_state
+        if layer_state is not None:
+            state[1][:, :, :, :, self.layer_id] = layer_state
 
     return rwkv, state
 
 
-def rwkv_ffn_forward(
-    self,
-    hidden: torch.Tensor,
-    state: List[torch.Tensor]=None,
-):
-    if hidden.size(1) == 1 and state is not None:
-        shifted = state[2][:, :, self.layer_id]
-    else:
-        shifted = self.time_shift(hidden)
-        if state is not None:
-            shifted[:, 0] = state[2][:, :, self.layer_id]
-    if len(shifted.size()) == 2:
-        shifted = shifted.unsqueeze(1)
-    shifted = shifted.contiguous()
+def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
+    def rwkv_ffn_forward(
+        self,
+        hidden: torch.Tensor,
+        state: List[torch.Tensor]=None,
+    ):
+        if hidden.device.type == "xpu":
+            if hidden.size(1) == 1 and state is not None:
+                shifted = state[2][self.layer_id]
+            else:
+                shifted = self.time_shift(hidden)
+                if state is not None:
+                    shifted[:, 0] = state[2][self.layer_id]
+            if len(shifted.size()) == 2:
+                shifted = shifted.unsqueeze(1)
 
-    if not hasattr(self, "mixed_mix"):
-        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+            if not hasattr(self, "mixed_mix"):
+                self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
 
-    import linear_q4_0
-    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
-    key, receptance = mixed_result
+            import linear_q4_0
+            mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
+            key, receptance = mixed_result
 
-    key = torch.square(torch.relu(self.key(key)))
-    value = self.value(key)
-    receptance = torch.sigmoid(self.receptance(receptance))
+            key = torch.square(torch.relu(self.key(key)))
+            value = self.value(key)
+            receptance = torch.sigmoid(self.receptance(receptance))
 
-    if state is not None:
-        state[2][:, :, self.layer_id] = hidden[:, -1]
+            if state is not None:
+                state[2][self.layer_id] = hidden[:, -1]
 
-    return receptance * value, state
+            return receptance * value, state
+        else:
+            return origin_rwkv_ffn_forward(self, hidden, state)
+
+    return rwkv_ffn_forward
+
+
+def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
+    def rwkv_model_forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,  # noqa
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        state: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        # change `state` layout and move `num_hidden_layers` to the leading dim
+        if input_ids.device.type == "xpu" and use_cache and state is None:
+            state = []
+            batch_size = input_ids.size(0)
+            hidden_size = self.config.hidden_size
+            num_hidden_layers = self.config.num_hidden_layers
+            num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
+            state.append(
+                torch.zeros(
+                    (num_hidden_layers, batch_size, hidden_size),
+                    dtype=self.embeddings.weight.dtype,
+                    requires_grad=False,
+                    device=input_ids.device,
+                ).contiguous()
+            )
+            state.append(
+                torch.zeros(
+                    (
+                        num_hidden_layers,
+                        batch_size,
+                        num_attention_heads,
+                        self.config.hidden_size // num_attention_heads,
+                        self.config.hidden_size // num_attention_heads,
+                    ),
+                    dtype=torch.float32,
+                    requires_grad=False,
+                    device=input_ids.device,
+                ).contiguous()
+            )
+            state.append(
+                torch.zeros(
+                    (num_hidden_layers, batch_size, hidden_size),
+                    dtype=self.embeddings.weight.dtype,
+                    requires_grad=False,
+                    device=input_ids.device,
+                ).contiguous()
+            )
+        return origin_rwkv_model_forward(
+            self=self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            state=state,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+    return rwkv_model_forward
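
A quick illustration of the state-layout change (a sketch, not part of the patch): `rwkv_model_forward_wrapper` moves `num_hidden_layers` to the leading dimension of every cache tensor, so `state[i][self.layer_id]` is a contiguous per-layer view that the `linear_q4_0` kernels can update in place, which is also why the explicit `.contiguous()` calls disappear from the hot path. The sizes below are examples, and the `(batch, hidden, num_layers)` layout of the unpatched cache is inferred from the removed `state[0][:, :, self.layer_id]` indexing:

    import torch

    batch_size, hidden_size, num_layers, layer_id = 1, 2048, 24, 3  # example sizes

    # Unpatched layout (inferred): layers on the last dim, so a per-layer slice is
    # strided and needed an extra .contiguous() before being handed to the XPU kernel.
    old_state0 = torch.zeros(batch_size, hidden_size, num_layers)
    print(old_state0[:, :, layer_id].is_contiguous())  # False

    # Patched layout from rwkv_model_forward_wrapper: layers on the leading dim, so
    # state[0][layer_id] is already a contiguous (batch, hidden) view and the kernel
    # can write the updated state back without an extra copy.
    new_state0 = torch.zeros(num_layers, batch_size, hidden_size).contiguous()
    print(new_state0[layer_id].is_contiguous())  # True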