LLM: integrate sdp kernel for FP16 rest token inference on GPU [DG2/ATSM] (#9633)
* integrate sdp
* update api
* fix style
* meet code review
* fix
* distinguish mtl from arc
* small fix
parent 5b0e7e308c
commit dc5b1d7e9d

2 changed files with 39 additions and 2 deletions
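For orientation, "rest token" here means the decode steps after the prompt has been processed, where the query length is 1; the new ESIMD SDP kernel targets exactly that case on FP16 GPU tensors with head_dim == 128 (see the use_esimd_sdp checks in the diff below). A minimal plain-PyTorch sketch of the shapes involved, with batch size, head count, and cache length as illustrative assumptions:

import torch

bsz, num_heads, head_dim, kv_seq_len = 1, 32, 128, 512

# Prompt (first token) pass: q_len == prompt length, handled by the existing paths.
# Rest token (decode) pass: q_len == 1, which is what the ESIMD SDP kernel covers.
query_states = torch.randn(bsz, num_heads, 1, head_dim, dtype=torch.float16)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim, dtype=torch.float16)
value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim, dtype=torch.float16)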
				
			
@@ -249,8 +249,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                             # convert here
                             m, n = module.weight.data.shape
-                            trans_weight = module.weight.data.reshape(m//16, 16, n)
-                            trans_weight = trans_weight.transpose(1, 2).contiguous()
+                            if module.in_features == 11008:
+                                trans_weight = module.weight.data.reshape(m//8, 8, n)
+                                trans_weight = trans_weight.transpose(1, 2).contiguous()
+                            elif module.in_features == 4096:
+                                trans_weight = module.weight.data.reshape(m//16, 16, n)
+                                trans_weight = trans_weight.transpose(1, 2).contiguous()
                             new_linear._parameters['weight'] = nn.Parameter(trans_weight)
                             if module.bias is not None:
                                 new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
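The hunk above changes the FP16 weight pre-pack: instead of always grouping output rows into blocks of 16, layers with in_features == 11008 now use blocks of 8. A minimal sketch of what the reshape/transpose does, using a toy weight tensor whose sizes are illustrative assumptions:

import torch

m, n = 32, 4096                     # out_features, in_features of a toy linear layer
weight = torch.randn(m, n, dtype=torch.float16)

# Same transform as the in_features == 4096 branch: group the output rows into
# blocks of 16, then transpose each block so that, for a given input column,
# the 16 rows of a block sit next to each other in memory.
trans_weight = weight.reshape(m // 16, 16, n).transpose(1, 2).contiguous()
print(trans_weight.shape)           # torch.Size([2, 4096, 16])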
@@ -219,6 +219,13 @@ def llama_attention_forward_4_31(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
+    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states.contiguous(),
+                                                    value_states.contiguous())
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     else:
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
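For reference, a hedged sketch of what the fused linear_fp16_esimd.sdp_forward call is assumed to compute for the q_len == 1 decode step: standard scaled dot-product attention, with no mask needed since the single new token may attend to the whole KV cache. This is plain PyTorch and can be used to sanity-check the ESIMD path against the native one:

import math
import torch

def reference_sdp(query, key, value):
    # query: [bsz, num_heads, 1, head_dim]; key/value: [bsz, num_heads, kv_seq_len, head_dim]
    scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.size(-1))
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(probs, value)   # [bsz, num_heads, 1, head_dim]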
@@ -266,6 +273,32 @@ def check_flash_attention_available(query):
     return True
 
 
+def use_esimd_sdp(q_len, head_dim, query_states):
+    if head_dim != 128:
+        # esimd_sdp only support head_dim = 128 now
+        return False
+    elif q_len != 1:
+        # esimd_sdp only support rest token now
+        return False
+    elif query_states.device.type != "xpu":
+        # esimd_sdp only support GPU now
+        return False
+    elif query_states.dtype != torch.float16:
+        # esimd_sdp only has optimization for FP16 now
+        return False
+    else:
+        device_name = torch.xpu.get_device_name(query_states.device.index)
+        if device_name.startswith("Intel(R) Arc(TM) A") or \
+                device_name.startswith("Intel(R) Data Center GPU Flex"):
+            import linear_fp16_esimd
+            if hasattr(linear_fp16_esimd, "sdp_forward"):
+                return True
+            else:
+                return False
+        else:
+            return False
+
+
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
     attn_weights = torch.matmul(query,
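A quick example of how the new use_esimd_sdp gate behaves, with the function above in scope; the query tensor is a toy value, and on a build without an XPU device the device-type check alone returns False:

import torch

q = torch.randn(1, 32, 1, 128, dtype=torch.float16)    # [bsz, num_heads, q_len, head_dim]

use_esimd_sdp(q_len=1, head_dim=64, query_states=q)     # False: only head_dim == 128 is supported
use_esimd_sdp(q_len=8, head_dim=128, query_states=q)    # False: only the q_len == 1 rest-token step
use_esimd_sdp(q_len=1, head_dim=128, query_states=q)    # True only on a supported Arc/Flex XPU
                                                        # where linear_fp16_esimd.sdp_forward exists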