Use merge_qkv to replace fused_qkv for llama2 (#11727)

* update 4.38

* support new versions

* update

* fix style

* fix style

* update rope

* temp test sdpa

* fix style

* fix cpu ut
Ruonan Wang 2024-08-07 13:04:01 +03:00 committed by GitHub
parent d2abc9711b
commit 00a5574c8a
4 changed files with 180 additions and 93 deletions
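
In short, _optimize_pre now merges llama2's separate q/k/v projections into a single qkv_proj via merge_qkv, and each attention forward learns to split the fused output back into query/key/value states. A minimal sketch of that split with made-up shapes (Llama-2-7B-style dimensions; illustrative code, not taken from this commit):

import torch

bsz, q_len, head_dim = 2, 8, 128
num_heads, num_key_value_heads = 32, 32          # Llama-2-7B layout (no GQA)
hidden_size = num_heads * head_dim               # 4096

# one fused projection instead of three separate q/k/v linears
qkv_proj = torch.nn.Linear(
    hidden_size, (num_heads + 2 * num_key_value_heads) * head_dim, bias=False)

hidden_states = torch.randn(bsz, q_len, hidden_size)
qkv = qkv_proj(hidden_states)                                   # (2, 8, 12288)
qkv = qkv.view(bsz, q_len, num_heads + 2 * num_key_value_heads, head_dim)
query_states, key_states, value_states = qkv.split(
    [num_heads, num_key_value_heads, num_key_value_heads], dim=2)
query_states = query_states.transpose(1, 2)                     # (2, 32, 8, 128)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)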


@@ -744,6 +744,9 @@ def _optimize_pre(model, qtype=None):
if model.config.model_type == "gemma2":
from ipex_llm.transformers.models.gemma2 import merge_qkv
model.apply(merge_qkv)
if model.config.model_type == "llama":
from ipex_llm.transformers.models.llama import merge_qkv
model.apply(merge_qkv)
return model
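
The merge itself happens in merge_qkv_base (imported from ipex_llm.transformers.models.common further down in llama.py). A rough sketch of the idea, assuming plain unquantized nn.Linear projections without bias; the real helper also has to handle bias and low-bit weights, so treat this as a hypothetical approximation only:

import torch

def merge_qkv_sketch(module: torch.nn.Module, attention_class):
    # Fuse separate q/k/v projections into one qkv_proj so a single matmul
    # produces all three states.
    if isinstance(module, attention_class) and hasattr(module, "q_proj"):
        q, k, v = module.q_proj, module.k_proj, module.v_proj
        qkv_proj = torch.nn.Linear(
            q.in_features,
            q.out_features + k.out_features + v.out_features,
            bias=False)
        qkv_proj.weight.data = torch.cat(
            [q.weight.data, k.weight.data, v.weight.data], dim=0)
        module.qkv_proj = qkv_proj
        del module.q_proj, module.k_proj, module.v_proj

model.apply(merge_qkv) visits every submodule, so only attention layers of the matching class are rewritten and the rest of the model is left untouched.
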
@@ -989,6 +992,10 @@ def _optimize_post(model, lightweight_bmm=False):
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_41)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaSdpaAttention,
llama_attention_forward_4_41)
else:
from ipex_llm.transformers.models.llama import llama_model_forward_4_38
convert_forward(
@@ -999,6 +1006,10 @@ def _optimize_post(model, lightweight_bmm=False):
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaSdpaAttention,
llama_attention_forward_4_38)
else:
from ipex_llm.transformers.models.llama import llama_model_forward_4_36
convert_forward(
@@ -1009,6 +1020,10 @@ def _optimize_post(model, lightweight_bmm=False):
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaSdpaAttention,
llama_attention_forward_4_38)
else:
# transformers version between 4.31.0 - 4.35.2
convert_forward(


@@ -49,12 +49,13 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
use_sdp_causal
from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
from ipex_llm.transformers.models.utils import use_decoding_fast_path
from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
from ipex_llm.ggml.quantize import ggml_tensor_qtype
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
try:
from transformers.cache_utils import Cache, DynamicCache
@@ -66,6 +67,10 @@ from transformers import logging
logger = logging.get_logger(__name__)
def merge_qkv(module: torch.nn.Module):
return merge_qkv_base(module, LlamaAttention)
def llama_decoding_fast_path_qtype_check(proj):
# IQ2_XXS only can be used in Llama-like model
qtype = getattr(proj, "qtype", None)
@@ -406,6 +411,9 @@ def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype):
def should_use_xetla_mm_qkv(self, device):
if not hasattr(self, "q_proj"):
# TODO: how to support xetla_mm_qkv for merged_qkv
return False
full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
@@ -428,7 +436,8 @@ def llama_attention_forward_4_31(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
self.num_key_value_groups):
forward_function = llama_attention_forward_4_31_quantized
else:
forward_function = llama_attention_forward_4_31_original
@@ -466,7 +475,7 @@ def llama_attention_forward_4_31_quantized(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
use_fuse_rope,
enough_kv_room,
bsz * q_len,
@@ -500,9 +509,16 @@ def llama_attention_forward_4_31_quantized(
self.head_dim,
self.rotary_emb.base,)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=2)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -516,12 +532,9 @@ def llama_attention_forward_4_31_quantized(
kv_seq_len += past_key_value[0].shape[-2]
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
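
Throughout these hunks the python-side apply_rotary_pos_emb_no_cache_xpu helper is replaced by the fused xe_addons.rotary_half_inplaced kernel. For reference, a rough pure-PyTorch stand-in for the rotate-half RoPE it is expected to compute; the function and argument names here are hypothetical and this is not the kernel's actual implementation:

import torch

def rotary_half_reference(inv_freq, position_ids, q, k):
    # Apply rotate_half-style RoPE in place to q and k of shape
    # (bsz, num_heads, seq_len, head_dim).  Assumes every row of
    # position_ids is identical (the usual single-sequence case).
    freqs = position_ids[0].float()[:, None] * inv_freq[None, :]  # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                       # (seq_len, head_dim)
    cos, sin = emb.cos()[None, None], emb.sin()[None, None]

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    q.copy_(q * cos + rotate_half(q) * sin)
    k.copy_(k * cos + rotate_half(k) * sin)
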
@@ -604,7 +617,7 @@ def llama_attention_forward_4_31_original(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
use_fuse_rope,
enough_kv_room,
bsz * q_len,
@@ -654,7 +667,7 @@ def llama_attention_forward_4_31_original(
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
# only use mm_qkv_out on pvc for llama-7b
if not hasattr(self, "qkv_proj_weight"):
@@ -692,9 +705,19 @@ def llama_attention_forward_4_31_original(
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
value_states = qkv_states[:, :, q_out_len + k_out_len:]
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
self.head_dim)
query_states, key_states, value_states = \
qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads],
dim=2)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -708,12 +731,9 @@ def llama_attention_forward_4_31_original(
kv_seq_len += past_key_value[0].shape[-2]
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -839,7 +859,7 @@ def llama_attention_selective_batching_forward_4_31(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
use_fuse_rope,
enough_kv_room,
bsz * q_len,
@@ -886,9 +906,18 @@ def llama_attention_selective_batching_forward_4_31(
if self.config.pretraining_tp > 1:
invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
self.head_dim)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads],
dim=2)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -902,12 +931,9 @@ def llama_attention_selective_batching_forward_4_31(
kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1030,7 +1056,8 @@ def llama_attention_forward_4_41(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
self.num_key_value_groups):
forward_function = llama_attention_forward_4_41_quantized
else:
forward_function = llama_attention_forward_4_41_original
@@ -1069,7 +1096,7 @@ def llama_attention_forward_4_41_quantized(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
use_fuse_rope,
enough_kv_room,
bsz * q_len,
@@ -1098,9 +1125,16 @@ def llama_attention_forward_4_41_quantized(
self.head_dim,
self.rotary_emb.base,)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=2)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -1122,12 +1156,9 @@ def llama_attention_forward_4_41_quantized(
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0
@@ -1301,7 +1332,7 @@ def llama_attention_forward_4_41_original(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
use_fuse_rope,
enough_kv_room,
bsz * q_len,
@@ -1360,7 +1391,7 @@ def llama_attention_forward_4_41_original(
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
# only use mm_qkv_out on pvc for llama-7b
if not hasattr(self, "qkv_proj_weight"):
@@ -1399,9 +1430,20 @@ def llama_attention_forward_4_41_original(
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
value_states = qkv_states[:, :, q_out_len + k_out_len:]
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len,
self.num_heads + 2 * self.num_key_value_heads,
self.head_dim)
query_states, key_states, value_states = \
qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads],
dim=2)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -1421,12 +1463,9 @@ def llama_attention_forward_4_41_original(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0
@@ -1582,7 +1621,8 @@ def llama_attention_forward_4_38(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
self.num_key_value_groups):
forward_function = llama_attention_forward_4_38_quantized
else:
forward_function = llama_attention_forward_4_38_original
@@ -1621,7 +1661,7 @@ def llama_attention_forward_4_38_quantized(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
use_fuse_rope,
enough_kv_room,
bsz * q_len,
@@ -1650,9 +1690,16 @@ def llama_attention_forward_4_38_quantized(
self.head_dim,
self.rotary_emb.base,)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=2)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -1674,12 +1721,9 @@ def llama_attention_forward_4_38_quantized(
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0
@@ -1853,7 +1897,7 @@ def llama_attention_forward_4_38_original(
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
no_tp = not self.config.pretraining_tp > 1
decoding_fast_path = use_decoding_fast_path(self.q_proj,
decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
use_fuse_rope,
enough_kv_room,
bsz * q_len,
@@ -1911,7 +1955,7 @@ def llama_attention_forward_4_38_original(
for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
# only use mm_qkv_out on pvc for llama-7b
if not hasattr(self, "qkv_proj_weight"):
@@ -1950,9 +1994,20 @@ def llama_attention_forward_4_38_original(
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
value_states = qkv_states[:, :, q_out_len + k_out_len:]
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len,
self.num_heads + 2 * self.num_key_value_heads,
self.head_dim)
query_states, key_states, value_states = \
qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads],
dim=2)
query_states = query_states.view(bsz, q_len,
self.num_heads, self.head_dim).transpose(1, 2)
@@ -1972,12 +2027,9 @@ def llama_attention_forward_4_38_original(
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if use_fuse_rope:
rope_theta = self.rotary_emb.base
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
key_states,
position_ids,
"llama",
rope_theta=rope_theta)
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
query_states, key_states)
else:
if cache_position is not None:
# for transformers 4.38.0
@@ -2413,9 +2465,16 @@ def llama_attention_fast_forward(
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
if hasattr(self, "q_proj"):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
else:
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
query_states, key_states, value_states = qkv.split([self.num_heads,
self.num_key_value_heads,
self.num_key_value_heads], dim=2)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,

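The fused projection width always equals (num_heads + 2 * num_key_value_heads) * head_dim, so the same view/split covers both full multi-head attention and GQA checkpoints. A small sanity check with Llama-2-70B-style numbers (illustrative values, not taken from the diff):

num_heads, num_key_value_heads, head_dim = 64, 8, 128    # Llama-2-70B-style GQA
head_slots = num_heads + 2 * num_key_value_heads          # 80 head slots per token after the view
split_sizes = [num_heads, num_key_value_heads, num_key_value_heads]   # [64, 8, 8]
assert sum(split_sizes) == head_slots
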

@@ -377,6 +377,8 @@ def use_decoding_fast_path(proj,
enough_kv_room,
bs,
qtype_check=decoding_fast_path_qtype_check):
if proj is None:
return False
device = get_xpu_device_type(proj.weight)
if not qtype_check(proj):
return False
@@ -419,6 +421,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
def fp16_fusion_check(proj, x, training):
# only use fp16 fusion on PVC inference
if proj is None:
return False
if not hasattr(proj, "qtype"):
return False
if proj.qtype != ggml_tensor_qtype["fp16"]:
@@ -491,3 +495,11 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
)
else:
return x.device.type == 'xpu' and use_compress_kv == "1"
def get_q_proj_or_qkv_proj(self):
if hasattr(self, "q_proj"):
proj = self.q_proj
elif hasattr(self, "qkv_proj"):
proj = self.qkv_proj
return proj
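
get_q_proj_or_qkv_proj returns whichever projection the attention module currently carries, so callers such as use_quantize_kv_cache and the decoding-fast-path check behave the same before and after merge_qkv. A tiny usage sketch with hypothetical stand-in classes (the real callers pass LlamaAttention modules):

from ipex_llm.transformers.models.utils import get_q_proj_or_qkv_proj

class _SplitAttn:          # module before merge_qkv
    q_proj = object()

class _MergedAttn:         # module after merge_qkv
    qkv_proj = object()

assert get_q_proj_or_qkv_proj(_SplitAttn()) is _SplitAttn.q_proj
assert get_q_proj_or_qkv_proj(_MergedAttn()) is _MergedAttn.qkv_proj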


@@ -152,20 +152,21 @@ def test_optimize_model(Model, Tokenizer, model_path, prompt):
tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
input_ids = tokenizer.encode(prompt, return_tensors="pt")
model = Model.from_pretrained(model_path,
load_in_4bit=True,
optimize_model=False,
trust_remote_code=True)
logits_base_model = (model(input_ids)).logits
with torch.inference_mode():
model = Model.from_pretrained(model_path,
load_in_4bit=True,
optimize_model=False,
trust_remote_code=True)
logits_base_model = (model(input_ids)).logits
model = Model.from_pretrained(model_path,
load_in_4bit=True,
optimize_model=True,
trust_remote_code=True)
logits_optimized_model = (model(input_ids)).logits
diff = abs(logits_base_model - logits_optimized_model).flatten()
model = Model.from_pretrained(model_path,
load_in_4bit=True,
optimize_model=True,
trust_remote_code=True)
logits_optimized_model = (model(input_ids)).logits
diff = abs(logits_base_model - logits_optimized_model).flatten()
assert any(diff) is False
assert any(diff) is False
if __name__ == '__main__':
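
The CPU unit-test change only wraps the model loads and forward passes in torch.inference_mode(); a standalone toy illustration of the effect (not code from the test):

import torch

with torch.inference_mode():
    # Tensors produced here skip autograd bookkeeping entirely, which keeps
    # the test's two full-model forward passes lighter on memory.
    out = torch.nn.Linear(4, 4)(torch.randn(1, 4))
    assert out.requires_grad is False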