LLM: optimize CPU speculative decoding of chatglm3 (#9928)
* update
* fix style
* meet code review
parent 967714bac8
commit bf37b3a670

1 changed file with 6 additions and 6 deletions

@@ -367,17 +367,17 @@ def chatglm2_attention_forward_8eb45c(
 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
-    if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
+    if pytorch_major_version >= 2:
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
         if attention_mask is None and L == S:
-            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
-                                                                             key_layer,
-                                                                             value_layer,
-                                                                             is_causal=True)
+            context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
+                                                           key_layer,
+                                                           value_layer,
+                                                           is_causal=True)
         else:
             head_dim = query_layer.size(-1)
-            attn = torch.matmul(query_layer,
+            attn = torch.matmul(query_layer.to(key_layer.dtype),
                                 key_layer.transpose(2, 3)) / math.sqrt(head_dim)
             if attention_mask is not None:
                 attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
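For context, here is a minimal runnable sketch of how this attention path reads after the change: the torch >= 2 gate is simplified so CPU takes the same path, and the query is cast to the KV-cache dtype before both the SDPA call and the manual matmul. The function name, the boolean mask semantics, and the softmax tail are illustrative assumptions (the hunk above is truncated right after the attn_bias allocation), not the repository's exact code.

    import math
    import torch
    import torch.nn.functional as F

    def core_attn_sketch(query_layer, key_layer, value_layer, attention_mask=None):
        # chatglm2/3 layout: query arrives as (seq_len, batch, heads, head_dim);
        # key/value are already (batch, heads, kv_len, head_dim).
        query_layer = query_layer.permute(1, 2, 0, 3)
        L, S = query_layer.shape[2], key_layer.shape[2]
        if attention_mask is None and L == S:
            # Prefill: a single causal SDPA call. Casting the query to the
            # KV dtype avoids a dtype-mismatch error when the cache is kept
            # in lower precision than the query.
            return F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                  key_layer, value_layer,
                                                  is_causal=True)
        # Masked / decode case: manual attention. The tail below is an assumed
        # completion, since the diff cuts off mid-branch.
        head_dim = query_layer.size(-1)
        attn = torch.matmul(query_layer.to(key_layer.dtype),
                            key_layer.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            # Assumed semantics: True marks positions to hide.
            attn = attn.masked_fill(attention_mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        return torch.matmul(attn, value_layer.to(attn.dtype))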
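And a quick demonstration of the failure mode the .to(key_layer.dtype) casts guard against: during CPU speculative decoding, the drafted query tokens can arrive in a different precision than the KV cache, and F.scaled_dot_product_attention raises on mixed query/key/value dtypes. Shapes and dtypes below are illustrative assumptions, not values from this commit.

    import torch
    import torch.nn.functional as F

    # Illustrative: batch 1, 32 heads, 4 drafted tokens, head_dim 128.
    q = torch.randn(1, 32, 4, 128)                        # fp32 query
    k = torch.randn(1, 32, 4, 128, dtype=torch.bfloat16)  # bf16 KV cache
    v = torch.randn(1, 32, 4, 128, dtype=torch.bfloat16)

    # F.scaled_dot_product_attention(q, k, v)  # RuntimeError: expects same dtype
    out = F.scaled_dot_product_attention(q.to(k.dtype), k, v, is_causal=True)
    print(out.shape, out.dtype)  # torch.Size([1, 32, 4, 128]) torch.bfloat16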