LLM: add fp8 sdp for chatglm2/3 (#10411)
* add fp8 sdp for chatglm2
* fix style
parent fe8976a00f
commit b036205be2

1 changed file with 17 additions and 19 deletions
@@ -97,7 +97,7 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
-        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).contiguous()
+        x_2d = hidden_states.reshape(-1, hidden_states.size(-1)).to(self.weight.dtype).contiguous()
         output = linear_q4_0.rms_norm(self.weight, x_2d, self.eps)
         if 1 < x_2d.size(0) <= 64:   # may use XMX, need copy
             output = output.clone()
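The only change in this hunk is the added .to(self.weight.dtype) cast, which makes the flattened activations match the weight's dtype before they are handed to the fused linear_q4_0.rms_norm kernel. For reference, a minimal plain-PyTorch sketch of the equivalent RMSNorm computation, assuming the standard RMSNorm formula (the helper name rms_norm_reference is illustrative and not part of the patch):

import torch

def rms_norm_reference(weight: torch.Tensor, x_2d: torch.Tensor, eps: float) -> torch.Tensor:
    # Mirror the added cast: compute in the weight's dtype.
    x_2d = x_2d.to(weight.dtype)
    # RMSNorm: scale each row by the reciprocal of its root-mean-square, then by the learned weight.
    variance = x_2d.pow(2).mean(dim=-1, keepdim=True)
    return x_2d * torch.rsqrt(variance + eps) * weight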
@@ -260,7 +260,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
                                                  n_kv_head,
                                                  seq_len,
                                                  head_dim,
-                                                 query_layer.device)
+                                                 query_layer.device,
+                                                 new_layout=True)
             k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
     else:
         k_cache, v_cache = kv_cache
@@ -268,31 +269,28 @@ def chatglm2_quantized_attention_forward_8eb45c(
         v_cache = v_cache.permute(1, 2, 0, 3)
         # k_cache, v_cache's shape: [bs, n_kv_head, seq_len, head_dim]
 
-        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer)
+        k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_layer, value_layer,
+                                               new_layout=True)
 
         if seq_len != 1:
             key, value = restore_fp8_kv_cache(k_cache, v_cache, query_layer.dtype)
             key, value = repeat_kv(key, value, n_head)
             attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
+            if attention_mask is not None:
+                attention_mask = ~attention_mask
+                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                        device=query_layer.device)
+                if attention_mask.dtype == torch.bool:
+                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                else:
+                    attn_bias += attention_mask
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1, dtype=torch.float32)
+            context_layer = torch.matmul(attn.to(value.dtype), value)
         else:
             key, value = k_cache, v_cache
             import linear_q4_0
-            attn = linear_q4_0.query_key_fp8_matmul(query_layer, key) / math.sqrt(head_dim)
-        if attention_mask is not None:
-            attention_mask = ~attention_mask
-            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                    device=query_layer.device)
-            if attention_mask.dtype == torch.bool:
-                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-            else:
-                attn_bias += attention_mask
-            attn += attn_bias
-        attn = F.softmax(attn, dim=-1, dtype=torch.float32)
-        if seq_len != 1:
-            context_layer = torch.matmul(attn.to(value.dtype), value)
-        else:
-            import linear_q4_0
-            context_layer = linear_q4_0.attn_value_fp8_matmul(attn, value.transpose(-1, -2))
+            context_layer = linear_q4_0.sdp_fp8(query_layer, key, value)
 
     # context_layer's shape: [bs, n_head, seq_len, head_dim] -> [seq_len, bs, n_head * head_dim]
     context_layer = context_layer.permute(2, 0, 1, 3).contiguous().view(seq_len, batch_size, -1)
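On the decode path (seq_len == 1), the old two-kernel sequence query_key_fp8_matmul followed by attn_value_fp8_matmul is replaced by a single fused linear_q4_0.sdp_fp8 call over the fp8 KV cache, so the Python-side mask and softmax handling now lives only in the prefill branch. As a reference for the scaled-dot-product math that branch spells out, here is a minimal plain-PyTorch sketch; the helper name reference_sdp and the exact mask semantics (True = keep) are illustrative assumptions, not part of the patch:

import math
import torch
import torch.nn.functional as F

def reference_sdp(query_layer, key, value, attention_mask=None):
    # query_layer: [bs, n_head, seq_len, head_dim]; key/value: [bs, n_head, kv_len, head_dim]
    head_dim = query_layer.size(-1)
    attn = torch.matmul(query_layer, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        # Convert the mask into an additive bias, as in the prefill branch above:
        # boolean masks become 0 / -inf, float masks are added as-is.
        attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
                                device=query_layer.device)
        if attention_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attention_mask
        attn += attn_bias
    attn = F.softmax(attn, dim=-1, dtype=torch.float32)
    return torch.matmul(attn.to(value.dtype), value)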