Baichuan 7b fp16 sdp and qwen2 pvc sdp (#10435)
* add baichuan sdp
* update
* baichuan2
* fix
* fix style
* revert 13b
* revert
parent 5ab52ef5b5
commit 399843faf0

3 changed files with 77 additions and 33 deletions
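The three changed files patch the attention forward paths of Baichuan 7B, Baichuan2 7B, and Qwen2 (per the commit title and message). In each one the change is the same three-way dispatch at inference time (not training, no gradient tracking): run PyTorch's fused scaled_dot_product_attention in float16 when use_flash_attention allows it, otherwise run the ESIMD SDP kernel from linear_fp16_esimd when use_esimd_sdp allows it (Intel PVC GPUs), and otherwise keep the original matmul-plus-softmax path. Below is a minimal, standalone sketch of that dispatch, not the BigDL code itself: can_use_flash_sdp and can_use_esimd_sdp are hypothetical stand-ins for the bigdl.llm.transformers.models.utils helpers, and the ESIMD branch is emulated with the same fused SDPA call because linear_fp16_esimd only exists in the Intel GPU build.

import math

import torch
import torch.nn.functional as F


def attention_core(query_states, key_states, value_states, attention_mask,
                   training, can_use_flash_sdp, can_use_esimd_sdp):
    # query/key/value_states: (batch, num_heads, seq_len, head_dim)
    if not training and not query_states.requires_grad and can_use_flash_sdp:
        # fp16 fused SDP path (mirrors the use_flash_attention branch in the diff).
        attn_output = F.scaled_dot_product_attention(query_states.to(torch.float16),
                                                     key_states.to(torch.float16),
                                                     value_states.to(torch.float16),
                                                     is_causal=True)
        attn_weights = None
    elif not training and not query_states.requires_grad and can_use_esimd_sdp:
        # The diff calls linear_fp16_esimd.sdp_forward(q, k, v) here (Intel PVC only);
        # this sketch stands in for it with the same fused SDPA call.
        attn_output = F.scaled_dot_product_attention(query_states, key_states,
                                                     value_states, is_causal=True)
        attn_weights = None
    else:
        # Original matmul + softmax fallback, kept under `else:` in the diff.
        head_dim = query_states.size(-1)
        attn_weights = torch.matmul(query_states,
                                    key_states.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1,
                                 dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
    # The diff likewise casts the result back to the layer input dtype on return.
    return attn_output.to(query_states.dtype), attn_weights

Calling attention_core(q, k, v, None, training=False, can_use_flash_sdp=True, can_use_esimd_sdp=False) on (batch, num_heads, seq_len, head_dim) tensors exercises the fp16 SDP branch; with both flags False it follows the fallback path.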
				
			
Changed file 1 of 3:

@@ -24,8 +24,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
@@ -267,7 +269,24 @@ def baichuan_attention_forward_7b_origin(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
+                                                     is_causal=True)
+        attn_weights = None
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(False,
@@ -280,7 +299,8 @@ def baichuan_attention_forward_7b_origin(
                                   f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, "
                                   f"but is {attention_mask.size()}")
             attn_weights = attn_weights + attention_mask
-        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+            attn_weights = torch.max(attn_weights,
+                                     torch.tensor(torch.finfo(attn_weights.dtype).min))
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
@@ -300,7 +320,7 @@ def baichuan_attention_forward_7b_origin(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def baichuan_attention_forward_13b(
@@ -502,4 +522,4 @@ def baichuan_attention_forward_13b_origin(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
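Note that every return site above changes return attn_output, ... to return attn_output.to(hidden_states.dtype), ...: the flash branch computes in float16, so without the cast a bfloat16 or float32 model would receive a float16 tensor back from the attention layer. A small illustration of that dtype round trip (the shapes and dtypes here are arbitrary examples, not taken from the commit):

import torch
import torch.nn.functional as F

# Example-only shapes/dtypes: (batch=1, heads=8, seq=4, head_dim=64) in bfloat16.
hidden_states = torch.randn(1, 8, 4, 64, dtype=torch.bfloat16)
q = k = v = hidden_states.to(torch.float16)        # the fp16 SDP branch casts like this
attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert attn_output.dtype == torch.float16          # SDP output stays in fp16
attn_output = attn_output.to(hidden_states.dtype)  # the cast added at the return sites
assert attn_output.dtype == torch.bfloat16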
Changed file 2 of 3:

@@ -28,6 +28,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv
     restore_fp8_kv_cache, use_quantize_kv_cache
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
@@ -270,6 +271,22 @@ def baichuan_attention_forward_7b_origin(
         attn_output = xops.memory_efficient_attention(
             query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
         )
+    else:
+        if not self.training and not hidden_states.requires_grad and \
+                use_flash_attention(query_states, key_states, attention_mask):
+            attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16),
+                                                         key_states.to(dtype=torch.float16),
+                                                         value_states.to(dtype=torch.float16),
+                                                         is_causal=True)
+            attn_weights = None
+        elif not self.training and not hidden_states.requires_grad and \
+                use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+            import linear_fp16_esimd
+            attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                        key_states,
+                                                        value_states)
+            attn_output = attn_output.view(query_states.shape)
+            attn_weights = None
         else:
             if attention_mask is not None:
                 if attention_mask.dtype == torch.bool:
@@ -289,7 +306,7 @@ def baichuan_attention_forward_7b_origin(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def baichuan_attention_forward_13b(
Changed file 3 of 3:

@@ -348,6 +348,13 @@ def qwen2_attention_forward_origin(
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
     if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
+                                                     is_causal=True)
+        attn_weights = None
+    elif not self.training and not hidden_states.requires_grad and \
             use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
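These branches are only reachable at inference (not self.training, no gradient tracking) and target fp16 models on Intel GPUs, as the PR title indicates (PVC for the ESIMD path). Roughly, a run that would exercise them looks like the sketch below; the model id and the loader flags (load_in_low_bit="fp16", trust_remote_code=True) follow the usual bigdl-llm XPU examples and are assumptions, not part of this diff.

# Hedged usage sketch (assumes an Intel GPU, intel_extension_for_pytorch, and the
# bigdl-llm loader flags shown; none of this is taken from the commit itself).
import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "baichuan-inc/Baichuan2-7B-Chat"  # assumed model id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="fp16",  # assumed flag value
                                             trust_remote_code=True)
model = model.to("xpu")

with torch.inference_mode():  # no training, no grad: the SDP branches are eligible
    inputs = tokenizer("What does fused SDP speed up?", return_tensors="pt").to("xpu")
    output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))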