fix chatglm2/3-32k/128k fp16 (#11311)
This commit is contained in:
		
							parent
							
								
									1b0c4c8cb8
								
							
						
					
					
						commit
						7f65836cb9
					
				
					 2 changed files with 9 additions and 43 deletions
				
			
		| 
						 | 
				
			
			@ -98,14 +98,15 @@ def chatglm2_model_forward(
 | 
			
		|||
                                        dtype=torch.int64, device=inputs_embeds.device)
 | 
			
		||||
        position_ids = position_ids.repeat(batch_size, 1)
 | 
			
		||||
 | 
			
		||||
    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
 | 
			
		||||
    if not getattr(self.rotary_pos_emb, "cached", False):
 | 
			
		||||
        rot_dim = self.rotary_pos_emb.dim
 | 
			
		||||
        base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
 | 
			
		||||
        inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
 | 
			
		||||
                                                device=inputs_embeds.device,
 | 
			
		||||
                                                dtype=inputs_embeds.dtype) / rot_dim))
 | 
			
		||||
                                                dtype=torch.float,
 | 
			
		||||
                                                device=inputs_embeds.device) / rot_dim))
 | 
			
		||||
        inv_freq = inv_freq.to(inputs_embeds.dtype)
 | 
			
		||||
        self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
 | 
			
		||||
        self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype
 | 
			
		||||
        self.rotary_pos_emb.cached = True
 | 
			
		||||
 | 
			
		||||
    # `full_attention_mask` is not None only when
 | 
			
		||||
    #  `past_key_values` is not None and `seq_length` > 1
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,8 +18,7 @@
 | 
			
		|||
#
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from typing import Optional, Tuple, Union
 | 
			
		||||
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 | 
			
		||||
| 
						 | 
				
			
			@ -27,11 +26,6 @@ from ipex_llm.transformers.models.chatglm2 import repeat_kv
 | 
			
		|||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
			
		||||
KV_CACHE_ALLOC_MIN_LENGTH = 512
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatglm4_model_forward(
 | 
			
		||||
    self,
 | 
			
		||||
| 
						 | 
				
			
			@ -45,34 +39,6 @@ def chatglm4_model_forward(
 | 
			
		|||
    output_hidden_states: Optional[bool] = None,
 | 
			
		||||
    return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, BaseModelOutputWithPast]:
 | 
			
		||||
    from ipex_llm.transformers.kv import DynamicFp8Cache
 | 
			
		||||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    return chatglm4_model_forward_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
        position_ids=position_ids,
 | 
			
		||||
        attention_mask=attention_mask,
 | 
			
		||||
        full_attention_mask=full_attention_mask,
 | 
			
		||||
        past_key_values=past_key_values,
 | 
			
		||||
        inputs_embeds=inputs_embeds,
 | 
			
		||||
        use_cache=use_cache,
 | 
			
		||||
        output_hidden_states=output_hidden_states,
 | 
			
		||||
        return_dict=return_dict,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def chatglm4_model_forward_internal(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids,
 | 
			
		||||
        position_ids: Optional[torch.Tensor] = None,
 | 
			
		||||
        attention_mask: Optional[torch.BoolTensor] = None,
 | 
			
		||||
        full_attention_mask: Optional[torch.BoolTensor] = None,
 | 
			
		||||
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
 | 
			
		||||
        inputs_embeds: Optional[torch.Tensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
):
 | 
			
		||||
    output_hidden_states = (
 | 
			
		||||
        output_hidden_states if output_hidden_states is not None else
 | 
			
		||||
        self.config.output_hidden_states
 | 
			
		||||
| 
						 | 
				
			
			@ -104,16 +70,15 @@ def chatglm4_model_forward_internal(
 | 
			
		|||
                                        dtype=torch.int64, device=inputs_embeds.device)
 | 
			
		||||
        position_ids = position_ids.repeat(batch_size, 1)
 | 
			
		||||
 | 
			
		||||
    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
 | 
			
		||||
    if not getattr(self.rotary_pos_emb, "cached", False):
 | 
			
		||||
        rot_dim = self.rotary_pos_emb.dim
 | 
			
		||||
        base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
 | 
			
		||||
        # We should generate float inv_freq to avoid overflow, as base is too large.
 | 
			
		||||
        inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
 | 
			
		||||
                                                dtype=torch.float,
 | 
			
		||||
                                                device=inputs_embeds.device) / rot_dim))
 | 
			
		||||
        self.rotary_pos_emb.register_buffer("inv_freq",
 | 
			
		||||
                                            inv_freq.to(inputs_embeds.dtype),
 | 
			
		||||
                                            persistent=False)
 | 
			
		||||
        inv_freq = inv_freq.to(inputs_embeds.dtype)
 | 
			
		||||
        self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
 | 
			
		||||
        self.rotary_pos_emb.cached = True
 | 
			
		||||
 | 
			
		||||
    # `full_attention_mask` is not None only when
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue