Optimize rwkv v5 rest token again (#10043)
This commit is contained in:
		
							parent
							
								
									b1ff28ceb6
								
							
						
					
					
						commit
						53a5140eff
					
				
					 2 changed files with 175 additions and 65 deletions
				
			
		| 
						 | 
				
			
			@ -973,13 +973,19 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from bigdl.llm.transformers.models.rwkv5 import rwkv_attention_forward
 | 
			
		||||
        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward
 | 
			
		||||
        from bigdl.llm.transformers.models.rwkv5 import rwkv_ffn_forward_wrapper
 | 
			
		||||
        from bigdl.llm.transformers.models.rwkv5 import rwkv_model_forward_wrapper
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.RwkvSelfAttention,
 | 
			
		||||
                        rwkv_attention_forward)
 | 
			
		||||
        rwkv_ffn_forward = rwkv_ffn_forward_wrapper(module.RwkvFeedForward.forward)
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.RwkvFeedForward,
 | 
			
		||||
                        rwkv_ffn_forward)
 | 
			
		||||
        rwkv_model_forward = rwkv_model_forward_wrapper(module.Rwkv5Model.forward)
 | 
			
		||||
        convert_forward(model,
 | 
			
		||||
                        module.Rwkv5Model,
 | 
			
		||||
                        rwkv_model_forward)
 | 
			
		||||
    elif model.config.model_type == "deci":
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -35,20 +35,19 @@
 | 
			
		|||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
from typing import List
 | 
			
		||||
from typing import List, Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def extract_key_value(self, hidden, state=None):
 | 
			
		||||
    # Mix hidden with the previous timestep to produce key, value, receptance
 | 
			
		||||
    if hidden.size(1) == 1 and state is not None:
 | 
			
		||||
        shifted = state[0][:, :, self.layer_id]
 | 
			
		||||
        shifted = state[0][self.layer_id]
 | 
			
		||||
    else:
 | 
			
		||||
        shifted = self.time_shift(hidden)
 | 
			
		||||
        if state is not None:
 | 
			
		||||
            shifted[:, 0] = state[0][:, :, self.layer_id]
 | 
			
		||||
            shifted[:, 0] = state[0][self.layer_id]
 | 
			
		||||
    if len(shifted.size()) == 2:
 | 
			
		||||
        shifted = shifted.unsqueeze(1)
 | 
			
		||||
    shifted = shifted.contiguous()
 | 
			
		||||
 | 
			
		||||
    if not hasattr(self, "mixed_mix"):
 | 
			
		||||
        self.mixed_mix = torch.cat([
 | 
			
		||||
| 
						 | 
				
			
			@ -68,7 +67,7 @@ def extract_key_value(self, hidden, state=None):
 | 
			
		|||
    gate = F.silu(self.gate(gate))
 | 
			
		||||
 | 
			
		||||
    if state is not None:
 | 
			
		||||
        state[0][:, :, self.layer_id] = hidden[:, -1]
 | 
			
		||||
        state[0][self.layer_id] = hidden[:, -1]
 | 
			
		||||
 | 
			
		||||
    return receptance, key, value, gate, state
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -97,9 +96,7 @@ def rwkv_linear_attention_xpu(
 | 
			
		|||
    time_decay = torch.exp(-torch.exp(time_decay.float()))
 | 
			
		||||
    time_first = time_first.float()
 | 
			
		||||
 | 
			
		||||
    state = state.contiguous().float()
 | 
			
		||||
 | 
			
		||||
    # `state` will be modified during this call
 | 
			
		||||
    # `state` will be updated inplaced during this call
 | 
			
		||||
    import linear_q4_0
 | 
			
		||||
    out = linear_q4_0.rwkv_linear_attention_v5(
 | 
			
		||||
        time_decay,
 | 
			
		||||
| 
						 | 
				
			
			@ -118,6 +115,50 @@ def rwkv_linear_attention_xpu(
 | 
			
		|||
    out = out.to(dtype=hidden.dtype) * gate
 | 
			
		||||
    # out = out @ ow
 | 
			
		||||
    out = ow(out)
 | 
			
		||||
    return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rwkv_linear_attention_cpu(
 | 
			
		||||
    B,
 | 
			
		||||
    H,
 | 
			
		||||
    S,
 | 
			
		||||
    T,
 | 
			
		||||
    n_head,
 | 
			
		||||
    hidden,
 | 
			
		||||
    time_decay,
 | 
			
		||||
    time_first,
 | 
			
		||||
    receptance,
 | 
			
		||||
    key,
 | 
			
		||||
    value,
 | 
			
		||||
    gate,
 | 
			
		||||
    lxw,
 | 
			
		||||
    lxb,
 | 
			
		||||
    ow,
 | 
			
		||||
    state,
 | 
			
		||||
):
 | 
			
		||||
    key = key.to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
 | 
			
		||||
    value = value.to(torch.float32).view(B, T, H, S).transpose(1, 2)
 | 
			
		||||
    receptance = receptance.to(torch.float32).view(B, T, H, S).transpose(1, 2)
 | 
			
		||||
    time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(n_head, -1, 1)
 | 
			
		||||
    time_first = time_first.float().reshape(-1, 1, 1).reshape(n_head, -1, 1)
 | 
			
		||||
    lxw = lxw.float()
 | 
			
		||||
    lxb = lxb.float()
 | 
			
		||||
    out = torch.zeros_like(key).reshape(B, T, H, S)
 | 
			
		||||
    for t in range(T):
 | 
			
		||||
        rt = receptance[:, :, t:t + 1, :]
 | 
			
		||||
        kt = key[:, :, :, t:t + 1]
 | 
			
		||||
        vt = value[:, :, t:t + 1, :]
 | 
			
		||||
        at = kt @ vt
 | 
			
		||||
        out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            state = at + time_decay * state
 | 
			
		||||
 | 
			
		||||
    out = out.reshape(B * T, H * S)
 | 
			
		||||
    out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H * S)
 | 
			
		||||
    out = out.to(dtype=hidden.dtype) * gate
 | 
			
		||||
    # out = out @ ow
 | 
			
		||||
    out = ow(out)   # fix this
 | 
			
		||||
 | 
			
		||||
    return out, state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -133,15 +174,29 @@ def rwkv_attention_forward(
 | 
			
		|||
    S = hidden.shape[-1] // H
 | 
			
		||||
    T = hidden.shape[1]
 | 
			
		||||
 | 
			
		||||
    receptance, key, value, gate, state = extract_key_value(self, hidden, state=state)
 | 
			
		||||
    layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
 | 
			
		||||
 | 
			
		||||
    if hidden.device.type == "xpu":
 | 
			
		||||
        rwkv, layer_state = rwkv_linear_attention_xpu(
 | 
			
		||||
            B,
 | 
			
		||||
            H,
 | 
			
		||||
            S,
 | 
			
		||||
            T,
 | 
			
		||||
        receptance, key, value, gate, state = extract_key_value(self, hidden, state)
 | 
			
		||||
        # `state`` will be updated inplaced when running on GPU
 | 
			
		||||
        rwkv = rwkv_linear_attention_xpu(
 | 
			
		||||
            B, H, S, T,
 | 
			
		||||
            hidden,
 | 
			
		||||
            self.time_decay,
 | 
			
		||||
            self.time_faaaa,
 | 
			
		||||
            receptance,
 | 
			
		||||
            key,
 | 
			
		||||
            value,
 | 
			
		||||
            gate,
 | 
			
		||||
            self.ln_x.weight,
 | 
			
		||||
            self.ln_x.bias,
 | 
			
		||||
            self.output,
 | 
			
		||||
            state=state[1][self.layer_id],
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state)
 | 
			
		||||
        layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
 | 
			
		||||
        rwkv, layer_state = rwkv_linear_attention_cpu(
 | 
			
		||||
            B, H, S, T,
 | 
			
		||||
            self.num_attention_heads,
 | 
			
		||||
            hidden,
 | 
			
		||||
            self.time_decay,
 | 
			
		||||
            self.time_faaaa,
 | 
			
		||||
| 
						 | 
				
			
			@ -154,60 +209,109 @@ def rwkv_attention_forward(
 | 
			
		|||
            self.output,
 | 
			
		||||
            state=layer_state,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        from transformers.models.rwkv.modeling_rwkv import rwkv_linear_attention_cpu
 | 
			
		||||
        rwkv, layer_state = rwkv_linear_attention_cpu(
 | 
			
		||||
            B,
 | 
			
		||||
            H,
 | 
			
		||||
            S,
 | 
			
		||||
            T,
 | 
			
		||||
            self.num_attention_heads,
 | 
			
		||||
            hidden,
 | 
			
		||||
            self.time_decay,
 | 
			
		||||
            self.time_faaaa,
 | 
			
		||||
            receptance,
 | 
			
		||||
            key,
 | 
			
		||||
            value,
 | 
			
		||||
            gate,
 | 
			
		||||
            self.ln_x.weight,
 | 
			
		||||
            self.ln_x.bias,
 | 
			
		||||
            self.output.weight.t(),
 | 
			
		||||
            state=layer_state,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    if layer_state is not None:
 | 
			
		||||
        state[1][:, :, :, :, self.layer_id] = layer_state
 | 
			
		||||
        if layer_state is not None:
 | 
			
		||||
            state[1][:, :, :, :, self.layer_id] = layer_state
 | 
			
		||||
 | 
			
		||||
    return rwkv, state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rwkv_ffn_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden: torch.Tensor,
 | 
			
		||||
    state: List[torch.Tensor]=None,
 | 
			
		||||
):
 | 
			
		||||
    if hidden.size(1) == 1 and state is not None:
 | 
			
		||||
        shifted = state[2][:, :, self.layer_id]
 | 
			
		||||
    else:
 | 
			
		||||
        shifted = self.time_shift(hidden)
 | 
			
		||||
        if state is not None:
 | 
			
		||||
            shifted[:, 0] = state[2][:, :, self.layer_id]
 | 
			
		||||
    if len(shifted.size()) == 2:
 | 
			
		||||
        shifted = shifted.unsqueeze(1)
 | 
			
		||||
    shifted = shifted.contiguous()
 | 
			
		||||
def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
 | 
			
		||||
    def rwkv_ffn_forward(
 | 
			
		||||
        self,
 | 
			
		||||
        hidden: torch.Tensor,
 | 
			
		||||
        state: List[torch.Tensor]=None,
 | 
			
		||||
    ):
 | 
			
		||||
        if hidden.device.type == "xpu":
 | 
			
		||||
            if hidden.size(1) == 1 and state is not None:
 | 
			
		||||
                shifted = state[2][self.layer_id]
 | 
			
		||||
            else:
 | 
			
		||||
                shifted = self.time_shift(hidden)
 | 
			
		||||
                if state is not None:
 | 
			
		||||
                    shifted[:, 0] = state[2][self.layer_id]
 | 
			
		||||
            if len(shifted.size()) == 2:
 | 
			
		||||
                shifted = shifted.unsqueeze(1)
 | 
			
		||||
 | 
			
		||||
    if not hasattr(self, "mixed_mix"):
 | 
			
		||||
        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
 | 
			
		||||
            if not hasattr(self, "mixed_mix"):
 | 
			
		||||
                self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
 | 
			
		||||
 | 
			
		||||
    import linear_q4_0
 | 
			
		||||
    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
			
		||||
    key, receptance = mixed_result
 | 
			
		||||
            import linear_q4_0
 | 
			
		||||
            mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
			
		||||
            key, receptance = mixed_result
 | 
			
		||||
 | 
			
		||||
    key = torch.square(torch.relu(self.key(key)))
 | 
			
		||||
    value = self.value(key)
 | 
			
		||||
    receptance = torch.sigmoid(self.receptance(receptance))
 | 
			
		||||
            key = torch.square(torch.relu(self.key(key)))
 | 
			
		||||
            value = self.value(key)
 | 
			
		||||
            receptance = torch.sigmoid(self.receptance(receptance))
 | 
			
		||||
 | 
			
		||||
    if state is not None:
 | 
			
		||||
        state[2][:, :, self.layer_id] = hidden[:, -1]
 | 
			
		||||
            if state is not None:
 | 
			
		||||
                state[2][self.layer_id] = hidden[:, -1]
 | 
			
		||||
 | 
			
		||||
    return receptance * value, state
 | 
			
		||||
            return receptance * value, state
 | 
			
		||||
        else:
 | 
			
		||||
            return origin_rwkv_ffn_forward(self, hidden, state)
 | 
			
		||||
 | 
			
		||||
    return rwkv_ffn_forward
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
 | 
			
		||||
    def rwkv_model_forward(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.LongTensor] = None,  # noqa
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        state: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
        # change `state` layout and put `num_hidden_layers` to the highest dim
 | 
			
		||||
        if input_ids.device.type == "xpu" and use_cache and state is None:
 | 
			
		||||
            state = []
 | 
			
		||||
            batch_size = input_ids.size(0)
 | 
			
		||||
            hidden_size = self.config.hidden_size
 | 
			
		||||
            num_hidden_layers = self.config.num_hidden_layers
 | 
			
		||||
            num_attention_heads = self.config.hidden_size // self.config.num_attention_heads
 | 
			
		||||
            state.append(
 | 
			
		||||
                torch.zeros(
 | 
			
		||||
                    (num_hidden_layers, batch_size, hidden_size),
 | 
			
		||||
                    dtype=self.embeddings.weight.dtype,
 | 
			
		||||
                    requires_grad=False,
 | 
			
		||||
                    device=input_ids.device,
 | 
			
		||||
                ).contiguous()
 | 
			
		||||
            )
 | 
			
		||||
            state.append(
 | 
			
		||||
                torch.zeros(
 | 
			
		||||
                    (
 | 
			
		||||
                        num_hidden_layers,
 | 
			
		||||
                        batch_size,
 | 
			
		||||
                        num_attention_heads,
 | 
			
		||||
                        self.config.hidden_size // num_attention_heads,
 | 
			
		||||
                        self.config.hidden_size // num_attention_heads,
 | 
			
		||||
                    ),
 | 
			
		||||
                    dtype=torch.float32,
 | 
			
		||||
                    requires_grad=False,
 | 
			
		||||
                    device=input_ids.device,
 | 
			
		||||
                ).contiguous()
 | 
			
		||||
            )
 | 
			
		||||
            state.append(
 | 
			
		||||
                torch.zeros(
 | 
			
		||||
                    (num_hidden_layers, batch_size, hidden_size),
 | 
			
		||||
                    dtype=self.embeddings.weight.dtype,
 | 
			
		||||
                    requires_grad=False,
 | 
			
		||||
                    device=input_ids.device,
 | 
			
		||||
                ).contiguous()
 | 
			
		||||
            )
 | 
			
		||||
        return origin_rwkv_model_forward(
 | 
			
		||||
            self=self,
 | 
			
		||||
            input_ids=input_ids,
 | 
			
		||||
            attention_mask=attention_mask,
 | 
			
		||||
            inputs_embeds=inputs_embeds,
 | 
			
		||||
            state=state,
 | 
			
		||||
            use_cache=use_cache,
 | 
			
		||||
            output_attentions=output_attentions,
 | 
			
		||||
            output_hidden_states=output_hidden_states,
 | 
			
		||||
            return_dict=return_dict,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return rwkv_model_forward
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue