diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 16abd567..9981aea3 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -744,6 +744,9 @@ def _optimize_pre(model, qtype=None):
     if model.config.model_type == "gemma2":
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
+    if model.config.model_type == "llama":
+        from ipex_llm.transformers.models.llama import merge_qkv
+        model.apply(merge_qkv)

     return model

@@ -989,6 +992,10 @@ def _optimize_post(model, lightweight_bmm=False):
                         model,
                         transformers.models.llama.modeling_llama.LlamaAttention,
                         llama_attention_forward_4_41)
+                    convert_forward(
+                        model,
+                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                        llama_attention_forward_4_41)
                 else:
                     from ipex_llm.transformers.models.llama import llama_model_forward_4_38
                     convert_forward(
@@ -999,6 +1006,10 @@ def _optimize_post(model, lightweight_bmm=False):
                         model,
                         transformers.models.llama.modeling_llama.LlamaAttention,
                         llama_attention_forward_4_38)
+                    convert_forward(
+                        model,
+                        transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                        llama_attention_forward_4_38)
             else:
                 from ipex_llm.transformers.models.llama import llama_model_forward_4_36
                 convert_forward(
@@ -1009,6 +1020,10 @@ def _optimize_post(model, lightweight_bmm=False):
                     model,
                     transformers.models.llama.modeling_llama.LlamaAttention,
                     llama_attention_forward_4_38)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaSdpaAttention,
+                    llama_attention_forward_4_38)
         else:
             # transformers version between 4.31.0 - 4.35.2
             convert_forward(
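Note on the hunk above: `_optimize_pre` now applies `merge_qkv` to Llama models, so the attention modules rewritten below may expose a single fused `qkv_proj` instead of separate `q_proj`/`k_proj`/`v_proj`. The following standalone sketch (toy sizes and variable names are mine, not ipex-llm code) illustrates the fused layout and the `view`/`split` arithmetic that the reworked forwards rely on:

```python
# Standalone sketch: q/k/v weights concatenated row-wise into one linear, then
# the fused output split back per head group, as in the attention forwards below.
import torch
import torch.nn as nn

bsz, q_len, hidden_size = 2, 8, 64            # toy sizes, for illustration only
num_heads, num_kv_heads, head_dim = 4, 2, 16

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

# Fused projection: stack the three weight matrices along the output dimension.
qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(bsz, q_len, hidden_size)
qkv = qkv_proj(x).view(bsz, q_len, num_heads + 2 * num_kv_heads, head_dim)
q, k, v = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=2)

# The split results match the separate projections (up to float error).
assert torch.allclose(q, q_proj(x).view(bsz, q_len, num_heads, head_dim), atol=1e-6)
assert torch.allclose(k, k_proj(x).view(bsz, q_len, num_kv_heads, head_dim), atol=1e-6)
assert torch.allclose(v, v_proj(x).view(bsz, q_len, num_kv_heads, head_dim), atol=1e-6)
```

The split along dim 2 works because the fused weight is the row-wise concatenation of the three projection weights: the first `num_heads` head-sized chunks of the output are queries, followed by the key and value head groups.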
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 805a129e..e1b2d5f1 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -49,12 +49,13 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
     use_sdp_causal
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
-from ipex_llm.transformers.models.utils import use_decoding_fast_path
+from ipex_llm.transformers.models.utils import use_decoding_fast_path, get_q_proj_or_qkv_proj
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaModel
+from transformers.models.llama.modeling_llama import LlamaModel, LlamaAttention
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base

 try:
     from transformers.cache_utils import Cache, DynamicCache
@@ -66,6 +67,10 @@ from transformers import logging

 logger = logging.get_logger(__name__)

+def merge_qkv(module: torch.nn.Module):
+    return merge_qkv_base(module, LlamaAttention)
+
+
 def llama_decoding_fast_path_qtype_check(proj):
     # IQ2_XXS only can be used in Llama-like model
     qtype = getattr(proj, "qtype", None)
@@ -406,6 +411,9 @@ def fuse_qkv_weight_xetla(q_proj, k_proj, v_proj, qtype):


 def should_use_xetla_mm_qkv(self, device):
+    if not hasattr(self, "q_proj"):
+        # TODO: how to support xetla_mm_qkv for merged_qkv
+        return False
     full_attn = self.q_proj.out_len == self.k_proj.out_len == self.v_proj.out_len
     supported_qtype = self.q_proj.qtype == SYM_INT4 and full_attn
     supported_qtype = supported_qtype or self.q_proj.qtype == FP8E5
@@ -428,7 +436,8 @@ def llama_attention_forward_4_31(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_31_quantized
     else:
         forward_function = llama_attention_forward_4_31_original
@@ -466,7 +475,7 @@ def llama_attention_forward_4_31_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -500,9 +509,16 @@ def llama_attention_forward_4_31_quantized(
                                                                        self.head_dim,
                                                                        self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)

     query_states = query_states.view(bsz, q_len, self.num_heads,
                                      self.head_dim).transpose(1, 2)
@@ -516,12 +532,9 @@ def llama_attention_forward_4_31_quantized(
         kv_seq_len += past_key_value[0].shape[-2]

     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
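The hunk above (and its counterparts in the 4.36/4.38/4.41 paths further down) replaces `apply_rotary_pos_emb_no_cache_xpu` with the fused `xe_addons.rotary_half_inplaced` kernel, which consumes `inv_freq` directly instead of re-deriving it from `rope_theta`. For reference, here is a plain-PyTorch sketch of the standard rotate-half RoPE that this fusion stands in for; it is an assumption that the kernel matches this formulation numerically, and none of the code below is ipex-llm's:

```python
# Reference-only sketch of rotate-half RoPE; toy sizes, standard transformers-style math.
import torch

def rotate_half(x):
    # Split the head dim in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, inv_freq, position_ids):
    # inv_freq: (head_dim // 2,), position_ids: (bsz, seq_len)
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]   # (bsz, seq, hd/2)
    emb = torch.cat((freqs, freqs), dim=-1)                              # (bsz, seq, hd)
    cos = emb.cos()[:, None, :, :]                                       # broadcast over heads
    sin = emb.sin()[:, None, :, :]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

bsz, heads, seq, head_dim = 1, 4, 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
position_ids = torch.arange(seq)[None, :]
q = torch.randn(bsz, heads, seq, head_dim)
k = torch.randn(bsz, heads, seq, head_dim)
q_rot, k_rot = apply_rope(q, k, inv_freq, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```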
@@ -604,7 +617,7 @@ def llama_attention_forward_4_31_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -654,7 +667,7 @@ def llama_attention_forward_4_31_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -692,9 +705,19 @@ def llama_attention_forward_4_31_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)

         query_states = query_states.view(bsz, q_len, self.num_heads,
                                          self.head_dim).transpose(1, 2)
@@ -708,12 +731,9 @@ def llama_attention_forward_4_31_original(
         kv_seq_len += past_key_value[0].shape[-2]

     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -839,7 +859,7 @@ def llama_attention_selective_batching_forward_4_31(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -886,9 +906,18 @@ def llama_attention_selective_batching_forward_4_31(
         if self.config.pretraining_tp > 1:
             invalidInputError(False, f"vLLM: config.pretraining_tp > 1 not supported yet")
         else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
+            if hasattr(self, "q_proj"):
+                query_states = self.q_proj(hidden_states)
+                key_states = self.k_proj(hidden_states)
+                value_states = self.v_proj(hidden_states)
+            else:
+                qkv = self.qkv_proj(hidden_states)
+                qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads,
+                               self.head_dim)
+                query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                    self.num_key_value_heads,
+                                                                    self.num_key_value_heads],
+                                                                    dim=2)

         query_states = query_states.view(bsz, q_len, self.num_heads,
                                          self.head_dim).transpose(1, 2)
@@ -902,12 +931,9 @@ def llama_attention_selective_batching_forward_4_31(
             kv_seq_len += max(kv_pair[0].shape[-2] for kv_pair in past_key_value)

         if use_fuse_rope:
-            rope_theta = self.rotary_emb.base
-            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                         key_states,
-                                                                         position_ids,
-                                                                         "llama",
-                                                                         rope_theta=rope_theta)
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
@@ -1030,7 +1056,8 @@ def llama_attention_forward_4_41(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_41_quantized
     else:
         forward_function = llama_attention_forward_4_41_original
@@ -1069,7 +1096,7 @@ def llama_attention_forward_4_41_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1098,9 +1125,16 @@ def llama_attention_forward_4_41_quantized(
                                                                        self.head_dim,
                                                                        self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)

     query_states = query_states.view(bsz, q_len, self.num_heads,
                                      self.head_dim).transpose(1, 2)
@@ -1122,12 +1156,9 @@ def llama_attention_forward_4_41_quantized(
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1301,7 +1332,7 @@ def llama_attention_forward_4_41_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1360,7 +1391,7 @@ def llama_attention_forward_4_41_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -1399,9 +1430,20 @@ def llama_attention_forward_4_41_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len,
+                                   self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)

         query_states = query_states.view(bsz, q_len, self.num_heads,
                                          self.head_dim).transpose(1, 2)
@@ -1421,12 +1463,9 @@ def llama_attention_forward_4_41_original(
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         if use_fuse_rope:
-            rope_theta = self.rotary_emb.base
-            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                         key_states,
-                                                                         position_ids,
-                                                                         "llama",
-                                                                         rope_theta=rope_theta)
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
         else:
             if cache_position is not None:
                 # for transformers 4.38.0
@@ -1582,7 +1621,8 @@ def llama_attention_forward_4_38(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
-    if use_quantize_kv_cache(self.q_proj, hidden_states, self.num_key_value_groups):
+    if use_quantize_kv_cache(get_q_proj_or_qkv_proj(self), hidden_states,
+                             self.num_key_value_groups):
         forward_function = llama_attention_forward_4_38_quantized
     else:
         forward_function = llama_attention_forward_4_38_original
@@ -1621,7 +1661,7 @@ def llama_attention_forward_4_38_quantized(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1650,9 +1690,16 @@ def llama_attention_forward_4_38_quantized(
                                                                        self.head_dim,
                                                                        self.rotary_emb.base,)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)

     query_states = query_states.view(bsz, q_len, self.num_heads,
                                      self.head_dim).transpose(1, 2)
@@ -1674,12 +1721,9 @@ def llama_attention_forward_4_38_quantized(
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     if use_fuse_rope:
-        rope_theta = self.rotary_emb.base
-        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                     key_states,
-                                                                     position_ids,
-                                                                     "llama",
-                                                                     rope_theta=rope_theta)
+        import xe_addons
+        xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                       query_states, key_states)
     else:
         if cache_position is not None:
             # for transformers 4.38.0
@@ -1853,7 +1897,7 @@ def llama_attention_forward_4_38_original(
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+    decoding_fast_path = use_decoding_fast_path(getattr(self, "q_proj", None),
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len,
@@ -1911,7 +1955,7 @@ def llama_attention_forward_4_38_original(
                             for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
         else:
-            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+            if fp16_fusion_check(getattr(self, "q_proj", None), hidden_states, self.training) and \
                     hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
                 # only use mm_qkv_out on pvc for llama-7b
                 if not hasattr(self, "qkv_proj_weight"):
@@ -1950,9 +1994,20 @@ def llama_attention_forward_4_38_original(
                 key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
                 value_states = qkv_states[:, :, q_out_len + k_out_len:]
             else:
-                query_states = self.q_proj(hidden_states)
-                key_states = self.k_proj(hidden_states)
-                value_states = self.v_proj(hidden_states)
+                if hasattr(self, "q_proj"):
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+                else:
+                    qkv = self.qkv_proj(hidden_states)
+                    qkv = qkv.view(bsz, q_len,
+                                   self.num_heads + 2 * self.num_key_value_heads,
+                                   self.head_dim)
+                    query_states, key_states, value_states = \
+                        qkv.split([self.num_heads,
+                                   self.num_key_value_heads,
+                                   self.num_key_value_heads],
+                                  dim=2)

         query_states = query_states.view(bsz, q_len, self.num_heads,
                                          self.head_dim).transpose(1, 2)
@@ -1972,12 +2027,9 @@ def llama_attention_forward_4_38_original(
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         if use_fuse_rope:
-            rope_theta = self.rotary_emb.base
-            query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
-                                                                         key_states,
-                                                                         position_ids,
-                                                                         "llama",
-                                                                         rope_theta=rope_theta)
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
         else:
             if cache_position is not None:
                 # for transformers 4.38.0
@@ -2413,9 +2465,16 @@ def llama_attention_fast_forward(
         value_states = torch.cat(value_states, dim=-1)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if hasattr(self, "q_proj"):
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+        else:
+            qkv = self.qkv_proj(hidden_states)
+            qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
+            query_states, key_states, value_states = qkv.split([self.num_heads,
+                                                                self.num_key_value_heads,
+                                                                self.num_key_value_heads], dim=2)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 0b344c69..14375dd6 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -377,6 +377,8 @@ def use_decoding_fast_path(proj,
                            enough_kv_room,
                            bs,
                            qtype_check=decoding_fast_path_qtype_check):
+    if proj is None:
+        return False
     device = get_xpu_device_type(proj.weight)
     if not qtype_check(proj):
         return False
@@ -419,6 +421,8 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):

 def fp16_fusion_check(proj, x, training):
     # only use fp16 fusion on PVC inference
+    if proj is None:
+        return False
     if not hasattr(proj, "qtype"):
         return False
     if proj.qtype != ggml_tensor_qtype["fp16"]:
         return False
@@ -491,3 +495,11 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
         )
     else:
         return x.device.type == 'xpu' and use_compress_kv == "1"
+
+
+def get_q_proj_or_qkv_proj(self):
+    if hasattr(self, "q_proj"):
+        proj = self.q_proj
+    elif hasattr(self, "qkv_proj"):
+        proj = self.qkv_proj
+    return proj
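Because a merged module no longer has `q_proj`, the helpers above now tolerate `proj is None` and `get_q_proj_or_qkv_proj` picks whichever projection exists. A tiny standalone sketch (dummy object, not a real ipex-llm layer) of how the callers compose with these guards:

```python
# Dummy merged-QKV attention module: q_proj is absent, qkv_proj exists.
from types import SimpleNamespace

def get_q_proj_or_qkv_proj(self):
    # Mirrors the helper added in utils.py above.
    if hasattr(self, "q_proj"):
        proj = self.q_proj
    elif hasattr(self, "qkv_proj"):
        proj = self.qkv_proj
    return proj

fused_attn = SimpleNamespace(qkv_proj=object())

# The kv-cache check now receives the fused projection instead of q_proj ...
assert get_q_proj_or_qkv_proj(fused_attn) is fused_attn.qkv_proj
# ... while the fast-path checks receive getattr(self, "q_proj", None), i.e. None,
# and with the early returns above simply report False instead of raising.
assert getattr(fused_attn, "q_proj", None) is None
```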
diff --git a/python/llm/test/inference/test_transformers_api.py b/python/llm/test/inference/test_transformers_api.py
index f16773c6..f04f4257 100644
--- a/python/llm/test/inference/test_transformers_api.py
+++ b/python/llm/test/inference/test_transformers_api.py
@@ -152,20 +152,21 @@ def test_optimize_model(Model, Tokenizer, model_path, prompt):
     tokenizer = Tokenizer.from_pretrained(model_path, trust_remote_code=True)
     input_ids = tokenizer.encode(prompt, return_tensors="pt")

-    model = Model.from_pretrained(model_path,
-                                  load_in_4bit=True,
-                                  optimize_model=False,
-                                  trust_remote_code=True)
-    logits_base_model = (model(input_ids)).logits
+    with torch.inference_mode():
+        model = Model.from_pretrained(model_path,
+                                      load_in_4bit=True,
+                                      optimize_model=False,
+                                      trust_remote_code=True)
+        logits_base_model = (model(input_ids)).logits

-    model = Model.from_pretrained(model_path,
-                                  load_in_4bit=True,
-                                  optimize_model=True,
-                                  trust_remote_code=True)
-    logits_optimized_model = (model(input_ids)).logits
-    diff = abs(logits_base_model - logits_optimized_model).flatten()
+        model = Model.from_pretrained(model_path,
+                                      load_in_4bit=True,
+                                      optimize_model=True,
+                                      trust_remote_code=True)
+        logits_optimized_model = (model(input_ids)).logits
+        diff = abs(logits_base_model - logits_optimized_model).flatten()

-    assert any(diff) is False
+        assert any(diff) is False


 if __name__ == '__main__':
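For completeness, a hedged usage sketch (not part of the patch) mirroring the API the test above exercises; the model path is a placeholder, and `AutoModelForCausalLM` is assumed to come from `ipex_llm.transformers` as in ipex-llm's examples. With `optimize_model=True`, `_optimize_pre` applies `merge_qkv` and generation runs through the merged `qkv_proj` code paths added in this diff:

```python
# Hedged usage sketch: load a Llama checkpoint with ipex-llm optimizations enabled
# and generate under inference_mode, as the updated test does.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "path/to/llama-model"          # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("Once upon a time", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```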