fix chatglm2/3-32k/128k fp16 (#11311)

Yishuo Wang 2024-06-14 09:58:07 +08:00 committed by GitHub
parent 1b0c4c8cb8
commit 7f65836cb9
2 changed files with 9 additions and 43 deletions

@@ -98,14 +98,15 @@ def chatglm2_model_forward(
                                      dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
-    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
+    if not getattr(self.rotary_pos_emb, "cached", False):
         rot_dim = self.rotary_pos_emb.dim
         base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
         inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
-                                                 device=inputs_embeds.device,
-                                                 dtype=inputs_embeds.dtype) / rot_dim))
+                                                 dtype=torch.float,
+                                                 device=inputs_embeds.device) / rot_dim))
+        inv_freq = inv_freq.to(inputs_embeds.dtype)
         self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype
+        self.rotary_pos_emb.cached = True
     # `full_attention_mask` is not None only when
     # `past_key_values` is not None and `seq_length` > 1
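Why the chatglm2 hunk above matters: on the 32k/128k variants `rope_ratio` is large, so `base = 10000 * rope_ratio` far exceeds fp16's maximum of 65504. Building the arange with `dtype=inputs_embeds.dtype` therefore ran the whole expression in half precision and `inv_freq` degenerated; the patch computes in float32 and casts only the final result. A minimal standalone sketch (not the patched code), with a made-up `rope_ratio = 500` standing in for a long-context config value:

```python
import torch

rot_dim = 128
rope_ratio = 500           # placeholder; the real ratio comes from the model config
base = 10000 * rope_ratio  # 5_000_000 -- far above fp16's max of 65504

# Old behaviour: arange in the activation dtype keeps everything in fp16,
# so the oversized base/powers overflow and most frequencies collapse to 0.
steps_half = torch.arange(0, rot_dim, 2, dtype=torch.float16)
inv_freq_old = 1.0 / (base ** (steps_half / rot_dim))

# Patched behaviour: compute in float32, cast only the small final values.
steps_float = torch.arange(0, rot_dim, 2, dtype=torch.float)
inv_freq_new = (1.0 / (base ** (steps_float / rot_dim))).to(torch.float16)

print(inv_freq_old[-4:])  # degenerate entries (zeros) from the fp16 overflow
print(inv_freq_new[-4:])  # small but non-zero frequencies
```

Casting back to `inputs_embeds.dtype` at the end is safe because every `inv_freq` value lies in (0, 1].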


@@ -18,8 +18,7 @@
 #
 import torch
-from typing import Optional, Tuple, Union, List, Callable, Dict, Any
-import torch.nn.functional as F
+from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
@@ -27,11 +26,6 @@ from ipex_llm.transformers.models.chatglm2 import repeat_kv
 from transformers.modeling_outputs import BaseModelOutputWithPast
-import math
-import os
-
-KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
-KV_CACHE_ALLOC_MIN_LENGTH = 512
 
 
 def chatglm4_model_forward(
         self,
@@ -45,34 +39,6 @@ def chatglm4_model_forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return chatglm4_model_forward_internal(
-        self=self,
-        input_ids=input_ids,
-        position_ids=position_ids,
-        attention_mask=attention_mask,
-        full_attention_mask=full_attention_mask,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-
-def chatglm4_model_forward_internal(
-        self,
-        input_ids,
-        position_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.BoolTensor] = None,
-        full_attention_mask: Optional[torch.BoolTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]=None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-):
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else
         self.config.output_hidden_states
@@ -104,16 +70,15 @@ def chatglm4_model_forward_internal(
                                      dtype=torch.int64, device=inputs_embeds.device)
         position_ids = position_ids.repeat(batch_size, 1)
-    if getattr(self.rotary_pos_emb, "cached_dtype", None) != inputs_embeds.dtype:
+    if not getattr(self.rotary_pos_emb, "cached", False):
         rot_dim = self.rotary_pos_emb.dim
         base = 10000 * getattr(self.rotary_pos_emb, "rope_ratio", 1)
         # We should generate float inv_freq to avoid overflow, as base is too large.
         inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2,
                                                  dtype=torch.float,
                                                  device=inputs_embeds.device) / rot_dim))
-        self.rotary_pos_emb.register_buffer("inv_freq",
-                                            inv_freq.to(inputs_embeds.dtype),
-                                            persistent=False)
-        self.rotary_pos_emb.cached_dtype = inputs_embeds.dtype
+        inv_freq = inv_freq.to(inputs_embeds.dtype)
+        self.rotary_pos_emb.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.rotary_pos_emb.cached = True
     # `full_attention_mask` is not None only when
     # `past_key_values` is not None and `seq_length` > 1
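On the `cached` flag used in both files: the patch replaces the per-dtype guard (`cached_dtype != inputs_embeds.dtype`) with a one-shot boolean, so `inv_freq` is built exactly once and reused on every later forward call. A rough standalone sketch of that pattern, with `RotaryEmbeddingStub` and `ensure_inv_freq` as hypothetical stand-ins for the real rotary module and the patched forward:

```python
import torch


class RotaryEmbeddingStub(torch.nn.Module):
    """Hypothetical stand-in for the model's rotary embedding module."""
    def __init__(self, dim: int, rope_ratio: float = 1.0):
        super().__init__()
        self.dim = dim
        self.rope_ratio = rope_ratio


def ensure_inv_freq(rotary_pos_emb: torch.nn.Module, dtype: torch.dtype,
                    device: torch.device) -> None:
    # Mirrors the patched forwards: a boolean flag guards a one-time build of
    # inv_freq instead of re-checking a cached dtype on every call.
    if not getattr(rotary_pos_emb, "cached", False):
        rot_dim = rotary_pos_emb.dim
        base = 10000 * getattr(rotary_pos_emb, "rope_ratio", 1)
        # Build in float32 to avoid fp16 overflow, then cast only the result.
        inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, dtype=torch.float,
                                                device=device) / rot_dim))
        rotary_pos_emb.register_buffer("inv_freq", inv_freq.to(dtype), persistent=False)
        rotary_pos_emb.cached = True


rope = RotaryEmbeddingStub(dim=128, rope_ratio=500)
ensure_inv_freq(rope, torch.float16, torch.device("cpu"))
print(rope.inv_freq.dtype, rope.inv_freq.shape)  # torch.float16 torch.Size([64])
```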