add sdp support for stablelm 3b (#11473)
parent cf8eb7b128
commit 39bcb33a67

2 changed files with 3 additions and 3 deletions
@@ -93,7 +93,7 @@ def stablelm_model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 96, 128]
+    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
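Why 80 joins the whitelist: StableLM 3B checkpoints use a head dimension of 80 (hidden size 2560 split over 32 attention heads), which the previous [64, 96, 128] check rejected. A minimal sketch of that arithmetic, assuming the stabilityai/stablelm-3b-4e1t config; the model id and values are illustrative assumptions, not part of this diff:

from transformers import AutoConfig

# Assumed checkpoint; the commit itself does not name a model id.
cfg = AutoConfig.from_pretrained("stabilityai/stablelm-3b-4e1t")
head_dim = cfg.hidden_size // cfg.num_attention_heads
print(cfg.hidden_size, cfg.num_attention_heads, head_dim)  # expected: 2560 32 80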
@@ -329,7 +329,7 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
         and query_states.dtype in [torch.float, torch.half]     # fp32/fp16
-        and head_dim in [64, 96, 128]
+        and head_dim in [64, 80, 96, 128]
         and q_len != kv_len     # next token
         and q_len <= 32         # lookup
     )
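For context, use_sdp() is a pure predicate over the shapes, device, and dtype of the current attention call. A minimal sketch of how a caller might use it to choose between a fused SDP fast path and plain attention during decode; the caller, its tensor layout, and the fallback math are assumptions for illustration (ipex-llm's actual fused kernel is not shown in this diff), and use_sdp is assumed importable from the module patched above:

import torch

def decode_attention_sketch(query_states, key_states, value_states, attention_mask=None):
    # Hypothetical caller: layout assumed to be (batch, num_heads, seq_len, head_dim).
    q_len, head_dim = query_states.shape[-2], query_states.shape[-1]
    kv_len = key_states.shape[-2]
    if use_sdp(q_len, kv_len, head_dim, query_states):
        # Fast path: a fused scaled-dot-product kernel would run here; PyTorch's
        # built-in SDPA stands in for it in this sketch.
        return torch.nn.functional.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask)
    # Fallback: explicit attention math.
    attn = query_states @ key_states.transpose(-1, -2) / head_dim ** 0.5
    if attention_mask is not None:
        attn = attn + attention_mask
    return torch.softmax(attn, dim=-1) @ value_states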
@@ -347,7 +347,7 @@ def use_sdp_fp8(q_len, kv_len, query_states):
 def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len     # first token
-        and head_dim in [64, 96, 128]           # for now
+        and head_dim in [64, 80, 96, 128]           # for now
         and query_states.device.type == "xpu"   # GPU
         and query_states.dtype in [torch.float, torch.half]     # fp32/fp16
         and not query_states.requires_grad and not training     # not training
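Taken together, use_sdp_causal() covers the first (prefill) step where q_len == kv_len, while use_sdp() covers subsequent decode steps where q_len != kv_len and stays small. A minimal sketch of how the two gates classify the phases for the newly whitelisted head_dim of 80, assuming an XPU (Intel GPU) device is available and both helpers are in scope; shapes and values are illustrative only:

import torch

# Illustrative shapes: batch 1, 32 heads, head_dim 80 (added by this commit).
prefill_q = torch.randn(1, 32, 17, 80, device="xpu", dtype=torch.half)  # q_len == kv_len == 17
decode_q = torch.randn(1, 32, 1, 80, device="xpu", dtype=torch.half)    # q_len 1 vs kv_len 18

print(use_sdp_causal(17, 17, 80, prefill_q, training=False))  # True: first-token (prefill) path
print(use_sdp(1, 18, 80, decode_q))                           # True: next-token (decode) path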