LLM: add esimd sdp support for chatglm3 (#10205)
* add esimd sdp support
* fix style
This commit is contained in:

parent 7cbc2429a6
commit 34ee1aa91f

1 changed file with 25 additions and 16 deletions
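In short: the else branch of core_attn_forward_8eb45c (the path that previously always ran the explicit matmul/softmax attention) now first asks use_esimd_sdp whether the fused ESIMD scaled-dot-product kernel can handle the current shapes, and only falls back to the explicit computation otherwise. Below is a minimal sketch of that dispatch; the [batch, num_heads, seq_len, head_dim] layout, the fp16-on-XPU requirement, and the helper's exact eligibility rules are assumptions here (attention-mask handling omitted), while the calls themselves mirror the diff:

import math
import torch
import torch.nn.functional as F
from bigdl.llm.transformers.models.utils import use_esimd_sdp

def core_attn_sketch(query_layer, key_layer, value_layer):
    # Sketch only: tensors assumed [batch, num_heads, seq_len, head_dim], fp16 on an Intel XPU.
    if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
                     query_layer.shape[-1], query_layer):
        # Fused ESIMD kernel: softmax(Q @ K^T / sqrt(d)) @ V in a single call.
        import linear_fp16_esimd
        attn_output = linear_fp16_esimd.sdp_forward(query_layer, key_layer, value_layer)
        context_layer = attn_output.view(query_layer.shape)
    else:
        # Fallback: the explicit attention that was previously the only path
        # (attention_mask handling omitted in this sketch).
        head_dim = query_layer.size(-1)
        attn = torch.matmul(query_layer.to(key_layer.dtype),
                            key_layer.transpose(2, 3)) / math.sqrt(head_dim)
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value_layer.dtype)
        context_layer = torch.matmul(attn, value_layer)
    return context_layer

The view back to query_layer.shape keeps the fused output in the same layout that the surrounding permute/reshape code in the diff expects.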
@@ -25,6 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
+from bigdl.llm.transformers.models.utils import use_esimd_sdp


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -515,23 +516,31 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
             context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                            key_layer,
                                                            value_layer,
-                                                           is_causal=True)
+                                                           is_causal=True).to(key_layer.dtype)
         else:
-            head_dim = query_layer.size(-1)
-            attn = torch.matmul(query_layer.to(key_layer.dtype),
-                                key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-            if attention_mask is not None:
-                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                        device=query_layer.device)
-                attention_mask = ~attention_mask
-                if attention_mask.dtype == torch.bool:
-                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-                else:
-                    attn_bias += attention_mask
-                attn += attn_bias
-            attn = F.softmax(attn, dim=-1,
-                             dtype=torch.float32).to(value_layer.dtype)
-            context_layer = torch.matmul(attn, value_layer)
+            if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
+                             query_layer.shape[-1], query_layer):
+                import linear_fp16_esimd
+                attn_output = linear_fp16_esimd.sdp_forward(query_layer,
+                                                            key_layer,
+                                                            value_layer)
+                context_layer = attn_output.view(query_layer.shape)
+            else:
+                head_dim = query_layer.size(-1)
+                attn = torch.matmul(query_layer.to(key_layer.dtype),
+                                    key_layer.transpose(2, 3)) / math.sqrt(head_dim)
+                if attention_mask is not None:
+                    attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                            device=query_layer.device)
+                    attention_mask = ~attention_mask
+                    if attention_mask.dtype == torch.bool:
+                        attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+                    else:
+                        attn_bias += attention_mask
+                    attn += attn_bias
+                attn = F.softmax(attn, dim=-1,
+                                 dtype=torch.float32).to(value_layer.dtype)
+                context_layer = torch.matmul(attn, value_layer)
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (-1,)
         context_layer = context_layer.reshape(*new_context_layer_shape)
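A rough way to sanity-check the fused kernel against PyTorch's reference SDPA is sketched below. This is a hypothetical snippet: the single-query decode shape, the head size of 128, and the fp16-on-XPU setup are assumptions about when use_esimd_sdp would pick the kernel; only linear_fp16_esimd.sdp_forward and the final view come from the diff.

import torch
import torch.nn.functional as F
import intel_extension_for_pytorch  # registers the "xpu" device
import linear_fp16_esimd

# Decode-style shapes (assumed): one new query token against a 256-token KV cache.
q = torch.randn(1, 32, 1, 128, dtype=torch.float16, device="xpu")
k = torch.randn(1, 32, 256, 128, dtype=torch.float16, device="xpu")
v = torch.randn(1, 32, 256, 128, dtype=torch.float16, device="xpu")

fused = linear_fp16_esimd.sdp_forward(q, k, v).view(q.shape)
reference = F.scaled_dot_product_attention(q, k, v)

# The two results should agree up to fp16 rounding.
print(torch.allclose(fused.float(), reference.float(), atol=1e-2))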