fix chatglm2/3-32k/128k fp16 (#11311)
This commit is contained in:
parent
1b0c4c8cb8
commit
7f65836cb9
2 changed files with 9 additions and 43 deletions
|
|
@ -98,14 +98,15 @@ def chatglm2_model_forward(
|
|||
dtype=torch.int64, device=inputs_embeds.device)
|
||||
position_ids = position_ids.repeat(batch_size, 1)
|
||||
|
||||
if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
|
||||
if not getattr(self.rotary_pos_emb, "cached", False):
|
||||
rot_dim = self.rotary_pos_emb.dim
|
||||
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
|
||||
device=inputs_embeds.device,
|
||||
dtype=inputs_embeds.dtype) / rot_dim))
|
||||
dtype=torch.float,
|
||||
device=inputs_embeds.device) / rot_dim))
|
||||
inv_freq = inv_freq.to(inputs_embeds.dtype)
|
||||
self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype
|
||||
self.rotary_pos_emb.cached = True
|
||||
|
||||
# `full_attention_mask` is not None only when
|
||||
# `past_key_values` is not None and `seq_length` > 1
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@
|
|||
#
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Tuple, Union
|
||||
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
|
||||
|
|
@ -27,11 +26,6 @@ from ipex_llm.transformers.models.chatglm2 import repeat_kv
|
|||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
import math
|
||||
|
||||
import os
|
||||
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||
KV_CACHE_ALLOC_MIN_LENGTH = 512
|
||||
|
||||
|
||||
def chatglm4_model_forward(
|
||||
self,
|
||||
|
|
@ -45,34 +39,6 @@ def chatglm4_model_forward(
|
|||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
from ipex_llm.transformers.kv import DynamicFp8Cache
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return chatglm4_model_forward_internal(
|
||||
self=self,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
full_attention_mask=full_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
def chatglm4_model_forward_internal(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
full_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else
|
||||
self.config.output_hidden_states
|
||||
|
|
@ -104,16 +70,15 @@ def chatglm4_model_forward_internal(
|
|||
dtype=torch.int64, device=inputs_embeds.device)
|
||||
position_ids = position_ids.repeat(batch_size, 1)
|
||||
|
||||
if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
|
||||
if not getattr(self.rotary_pos_emb, "cached", False):
|
||||
rot_dim = self.rotary_pos_emb.dim
|
||||
base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
|
||||
# We should generate float inv_freq to avoid overflow, as base is too large.
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
|
||||
dtype=torch.float,
|
||||
device=inputs_embeds.device) / rot_dim))
|
||||
self.rotary_pos_emb.register_buffer("inv_freq",
|
||||
inv_freq.to(inputs_embeds.dtype),
|
||||
persistent=False)
|
||||
inv_freq = inv_freq.to(inputs_embeds.dtype)
|
||||
self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.rotary_pos_emb.cached = True
|
||||
|
||||
# `full_attention_mask` is not None only when
|
||||
|
|
|
|||
Loading…
Reference in a new issue