fix stablelm logits diff (#10636)
* fix logits diff * Small fixes --------- Co-authored-by: Yuwen Hu <yuwen.hu@intel.com>
This commit is contained in:
		
							parent
							
								
									97c626d76f
								
							
						
					
					
						commit
						3a9ab8f1ae
					
				
					 3 changed files with 12 additions and 8 deletions
				
			
		| 
						 | 
				
			
			@ -633,6 +633,7 @@ def _optimize_pre(model):
 | 
			
		|||
                del module.c_attn
 | 
			
		||||
        model.apply(split_qkv_proj_func)
 | 
			
		||||
    if model.config.model_type == "stablelm":
 | 
			
		||||
        # For stablelm-zephyr-3b
 | 
			
		||||
        from ipex_llm.transformers.models.stablelm import merge_qkv
 | 
			
		||||
        model.apply(merge_qkv)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1341,6 +1342,7 @@ def _optimize_post(model, lightweight_bmm=False):
 | 
			
		|||
                        module.BertEncoder,
 | 
			
		||||
                        encoder_forward)
 | 
			
		||||
    elif model.config.model_type == 'stablelm':
 | 
			
		||||
        # For stablelm-zephyr-3b
 | 
			
		||||
        modeling_module_name = model.__class__.__module__
 | 
			
		||||
        module = importlib.import_module(modeling_module_name)
 | 
			
		||||
        from ipex_llm.transformers.models.stablelm import stablelm_attention_forward
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -48,7 +48,7 @@ from transformers.models.stablelm.modeling_stablelm import StableLmAttention
 | 
			
		|||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 | 
			
		||||
    apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
    apply_rotary_pos_emb_cache_freq_xpu
 | 
			
		||||
from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 | 
			
		||||
from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
 | 
			
		||||
| 
						 | 
				
			
			@ -124,13 +124,15 @@ def stablelm_attention_forward(
 | 
			
		|||
        key_states[..., : self.rotary_emb.dim],
 | 
			
		||||
        key_states[..., self.rotary_emb.dim:],
 | 
			
		||||
    )
 | 
			
		||||
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
    if use_fuse_rope:
 | 
			
		||||
        query_rot, key_rot = apply_rotary_pos_emb_no_cache_xpu(query_rot,
 | 
			
		||||
                                                               key_rot,
 | 
			
		||||
                                                               position_ids,
 | 
			
		||||
                                                               "stablelm")
 | 
			
		||||
        query_rot, key_rot = apply_rotary_pos_emb_cache_freq_xpu(query_rot,
 | 
			
		||||
                                                                 key_rot,
 | 
			
		||||
                                                                 sin,
 | 
			
		||||
                                                                 cos,
 | 
			
		||||
                                                                 "stablelm",
 | 
			
		||||
                                                                 position_ids)
 | 
			
		||||
    else:
 | 
			
		||||
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 | 
			
		||||
        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
 | 
			
		||||
        query_rot, key_rot = apply_rotary_pos_emb(query_rot,
 | 
			
		||||
                                                  key_rot,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -208,7 +208,7 @@ def apply_rotary_pos_emb_no_cache_xpu(q, k, position_ids, model_family, rope_the
 | 
			
		|||
    q_embed = torch.empty(q.shape, dtype=q.dtype, device=q.device)
 | 
			
		||||
    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
 | 
			
		||||
    if model_family in ["llama", "baichuan", "internlm", "aquila", "gpt_neox", "mistral",
 | 
			
		||||
                        "mixtral", "stablelm"]:
 | 
			
		||||
                        "mixtral"]:
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_and_k(q, k, position_ids,
 | 
			
		||||
                                                        q_embed, k_embed, rope_theta)
 | 
			
		||||
        return q_embed, k_embed
 | 
			
		||||
| 
						 | 
				
			
			@ -226,7 +226,7 @@ def apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos, model_family, position_i
 | 
			
		|||
    k_embed = torch.empty(k.shape, dtype=k.dtype, device=k.device)
 | 
			
		||||
    if model_family in ["qwen", "mixtral"]:
 | 
			
		||||
        linear_q4_0.apply_rotary_embedding_half_q_and_k_cache_freq(q, k, sin, cos, q_embed, k_embed)
 | 
			
		||||
    elif model_family in ["qwen2", "yuan"]:
 | 
			
		||||
    elif model_family in ["qwen2", "yuan", "stablelm"]:
 | 
			
		||||
        cos = cos.to(q.dtype)
 | 
			
		||||
        sin = sin.to(q.dtype)
 | 
			
		||||
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue