diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
index ad5451f7..29e7968f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
@@ -128,7 +128,8 @@ def baichuan_attention_forward_7b_quantized(
             bsz, self.num_heads, kv_seq_len, self.head_dim,
             device=device, new_layout=True
         )
-        key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
+        key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
+                                                       value_states, new_layout=True)
         past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
@@ -185,7 +186,7 @@ def baichuan_attention_forward_7b_quantized(
     if not output_attentions:
         attn_weights = None

-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value


 def baichuan_attention_forward_7b_origin(
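
The diff makes two fixes to the quantized 7B attention path: new key/value states are now appended through append_fp8_kv_cache (matching the FP8 cache initialized just above, instead of the plain append_kv_cache), and the attention output is cast back to hidden_states.dtype before returning, since the quantized compute path can leave it in a different dtype. Below is a minimal sketch of that pattern, not BigDL's implementation: the helper name append_fp8_kv_cache_sketch, its kv_len parameter, and the per-slice copy are illustrative assumptions (the real helper, per the diff, takes a new_layout flag instead). It assumes a PyTorch build (>= 2.1) that provides torch.float8_e4m3fn.

    import torch

    FP8 = torch.float8_e4m3fn  # requires PyTorch >= 2.1


    def append_fp8_kv_cache_sketch(k_cache, v_cache, kv_len,
                                   key_states, value_states):
        # Cast the incoming fp16/fp32 states down to FP8 and write them
        # into the preallocated cache along the sequence dimension of a
        # [bsz, heads, seq, head_dim] layout.
        n = key_states.size(2)
        k_cache[:, :, kv_len:kv_len + n, :] = key_states.to(FP8)
        v_cache[:, :, kv_len:kv_len + n, :] = value_states.to(FP8)
        # Return views covering only the filled portion of the cache.
        return (k_cache[:, :, :kv_len + n, :],
                v_cache[:, :, :kv_len + n, :])


    bsz, heads, max_len, head_dim = 1, 4, 32, 16
    # Preallocated FP8 cache; only the written slice is ever read back.
    k_cache = torch.empty(bsz, heads, max_len, head_dim, dtype=FP8)
    v_cache = torch.empty(bsz, heads, max_len, head_dim, dtype=FP8)
    key = torch.randn(bsz, heads, 8, head_dim, dtype=torch.float16)
    value = torch.randn(bsz, heads, 8, head_dim, dtype=torch.float16)

    k, v = append_fp8_kv_cache_sketch(k_cache, v_cache, 0, key, value)

    # FP8 tensors are dequantized to a higher-precision dtype before the
    # matmul, so the attention output lands in float32; casting it back
    # mirrors the diff's `attn_output.to(hidden_states.dtype)`.
    attn = torch.matmul(key.to(torch.float32),
                        k.to(torch.float32).transpose(2, 3))
    print(attn.to(key.dtype).dtype)  # torch.float16

The final cast is the reason for the second hunk: without it, the quantized path would return an attn_output whose dtype differed from the hidden states fed into the rest of the model.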