fix chatglm2/3-32k/128k fp16 (#11311)
parent 1b0c4c8cb8
commit 7f65836cb9

2 changed files with 9 additions and 43 deletions

@@ -98,14 +98,15 @@ def chatglm2_model_forward(
                                     dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
 
-    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
+    if not getattr(self.rotary_pos_emb, "cached", False):
         rot_dim = self.rotary_pos_emb.dim
         base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
         inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
-                                                device=inputs_embeds.device,
-                                                dtype=inputs_embeds.dtype) / rot_dim))
+                                                dtype=torch.float,
+                                                device=inputs_embeds.device) / rot_dim))
+        inv_freq = inv_freq.to(inputs_embeds.dtype)
         self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype
+        self.rotary_pos_emb.cached = True
 
     # `full_attention_mask` is not None only when
     # `past_key_values` is not None and `seq_length` > 1
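The functional change in this hunk is the dtype of the inv_freq computation: the old code evaluated base ** (...) directly in the model dtype, while the new code does the math in float32 and only casts the result to inputs_embeds.dtype afterwards. Below is a minimal, standalone sketch of why that matters for the 32k/128k variants; rope_ratio=50 and rot_dim=64 are assumed, illustrative values, not taken from this diff.

import torch

# Standalone illustration of the overflow this commit fixes. The long-context
# ChatGLM variants use a large rope_ratio, so base = 10000 * rope_ratio ends up
# far above the fp16 maximum of ~65504.
rope_ratio = 50          # assumed for illustration
rot_dim = 64             # assumed rotary dimension, also illustrative
base = 10000 * rope_ratio

# Old behaviour (sketch): the whole expression is evaluated in the model dtype.
# The intermediate base ** x overflows in fp16 (fp16 tops out at ~65504), so the
# smallest inverse frequencies collapse to 0 and the RoPE angles for those
# dimensions are lost.
idx_fp16 = torch.arange(0, rot_dim, 2, dtype=torch.float16)
inv_freq_fp16 = 1.0 / (base ** (idx_fp16 / rot_dim))
print(inv_freq_fp16[-3:])   # the tail degenerates to zeros

# New behaviour: do the math in float32 first, cast to the model dtype at the end.
idx_fp32 = torch.arange(0, rot_dim, 2, dtype=torch.float)
inv_freq = (1.0 / (base ** (idx_fp32 / rot_dim))).to(torch.float16)
print(inv_freq[-3:])        # small but non-zero values survive the cast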
@@ -18,8 +18,7 @@
 #
 
 import torch
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
-import torch.nn.functional as F
+from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
@@ -27,11 +26,6 @@ from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from transformers.modeling_outputs import BaseModelOutputWithPast
 import math
 
-import os
-
-KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
-KV_CACHE_ALLOC_MIN_LENGTH = 512
-
 
 def chatglm4_model_forward(
         self,
@@ -45,34 +39,6 @@ def chatglm4_model_forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return chatglm4_model_forward_internal(
-        self=self,
-        input_ids=input_ids,
-        position_ids=position_ids,
-        attention_mask=attention_mask,
-        full_attention_mask=full_attention_mask,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-
-def chatglm4_model_forward_internal(
-        self,
-        input_ids,
-        position_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.BoolTensor] = None,
-        full_attention_mask: Optional[torch.BoolTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-):
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else
         self.config.output_hidden_states
@@ -104,16 +70,15 @@ def chatglm4_model_forward_internal(
                                     dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
 
-    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
+    if not getattr(self.rotary_pos_emb, "cached", False):
         rot_dim = self.rotary_pos_emb.dim
         base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
         # We should generate float inv_freq to avoid overflow, as base is too large.
         inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
                                                 dtype=torch.float,
                                                 device=inputs_embeds.device) / rot_dim))
-        self.rotary_pos_emb.register_buffer("inv_freq",
-                                            inv_freq.to(inputs_embeds.dtype),
-                                            persistent=False)
+        inv_freq = inv_freq.to(inputs_embeds.dtype)
+        self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
         self.rotary_pos_emb.cached = True
 
     # `full_attention_mask` is not None only when
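Both files now share the same caching pattern: inv_freq is computed once in float32, cast to the model dtype, registered on the rotary-embedding module as a non-persistent buffer, and guarded by a `cached` flag so later forward passes skip the recomputation. A self-contained sketch of that pattern follows; `ToyRotaryEmbedding` and `ensure_inv_freq` are stand-in names for illustration, not ChatGLM's or ipex-llm's real classes.

import torch
from torch import nn


class ToyRotaryEmbedding(nn.Module):
    # Stand-in for the model's rotary-embedding module; only the attributes the
    # caching helper reads (dim, rope_ratio) are modelled here.
    def __init__(self, dim: int, rope_ratio: int = 1):
        super().__init__()
        self.dim = dim
        self.rope_ratio = rope_ratio


def ensure_inv_freq(rotary_pos_emb: nn.Module, dtype: torch.dtype, device: torch.device):
    # Compute inv_freq once, in float32 to dodge fp16 overflow, then cache it on
    # the module as a non-persistent buffer plus a `cached` flag (same pattern as
    # the hunks above).
    if not getattr(rotary_pos_emb, "cached", False):
        rot_dim = rotary_pos_emb.dim
        base = 10000 * getattr(rotary_pos_emb, "rope_ratio", 1)
        inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
                                                dtype=torch.float,
                                                device=device) / rot_dim))
        rotary_pos_emb.register_buffer("inv_freq", inv_freq.to(dtype), persistent=False)
        rotary_pos_emb.cached = True
    return rotary_pos_emb.inv_freq


rope = ToyRotaryEmbedding(dim=64, rope_ratio=50)                  # illustrative values
inv_freq = ensure_inv_freq(rope, torch.float16, torch.device("cpu"))
assert inv_freq.dtype == torch.float16
assert "inv_freq" not in rope.state_dict()                        # persistent=False keeps it out

Because the buffer is registered with persistent=False, it never appears in state_dict(), so saved checkpoints are unaffected by whatever precision the model happened to be loaded in.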