From b9d42808928e9c3dca227f6f92d2268d2586f190 Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Fri, 22 Mar 2024 10:00:08 +0800
Subject: [PATCH] LLM: fix baichuan7b quantize kv abnormal output. (#10504)

* fix abnormal output.

* fix style.

* fix style.
---
 python/llm/src/bigdl/llm/transformers/models/baichuan.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
index ad5451f7..29e7968f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py
@@ -128,7 +128,8 @@ def baichuan_attention_forward_7b_quantized(
             bsz, self.num_heads, kv_seq_len, self.head_dim,
             device=device, new_layout=True
         )
-        key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
+        key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
+                                                       value_states, new_layout=True)
         past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
@@ -185,7 +186,7 @@ def baichuan_attention_forward_7b_quantized(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def baichuan_attention_forward_7b_origin(
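
Note for reviewers: the two hunks address the same bug from both ends. The cache
here is fp8-quantized, so new key/value states must go through the quantizing
append path (append_fp8_kv_cache) rather than the plain append_kv_cache, and the
attention output, computed in a wider dtype, must be cast back to the model's
activation dtype before it is returned. Below is a minimal, self-contained
sketch of that pattern. It is not BigDL's actual implementation: the helpers
append_fp8 and attention_with_fp8_cache and the shapes are hypothetical, and
torch.float8_e5m2 requires PyTorch >= 2.1.

    # Minimal sketch, NOT BigDL's API: append_fp8 and attention_with_fp8_cache
    # are hypothetical helpers illustrating the pattern this patch restores.
    import math
    import torch

    def append_fp8(k_cache, v_cache, key_states, value_states, pos):
        # Quantize incoming fp16 K/V to fp8 before writing into the cache.
        # Writing unquantized states through a non-fp8 append is the kind
        # of dtype mismatch that produces abnormal output.
        seq = key_states.shape[2]
        k_cache[:, :, pos:pos + seq] = key_states.to(torch.float8_e5m2)
        v_cache[:, :, pos:pos + seq] = value_states.to(torch.float8_e5m2)
        return pos + seq

    def attention_with_fp8_cache(query, k_cache, v_cache, kv_len, out_dtype):
        # Dequantize to fp32 for the matmuls: fp8 tensors are storage-only
        # here and have no general matmul kernels.
        k = k_cache[:, :, :kv_len].float()
        v = v_cache[:, :, :kv_len].float()
        scores = query.float() @ k.transpose(-1, -2) / math.sqrt(query.shape[-1])
        attn = torch.softmax(scores, dim=-1) @ v
        # Cast back, mirroring `return attn_output.to(hidden_states.dtype)`.
        return attn.to(out_dtype)

    bsz, heads, head_dim, max_len = 1, 8, 64, 128
    k_cache = torch.empty(bsz, heads, max_len, head_dim, dtype=torch.float8_e5m2)
    v_cache = torch.empty_like(k_cache)

    kv = torch.randn(bsz, heads, 16, head_dim, dtype=torch.float16)
    kv_len = append_fp8(k_cache, v_cache, kv, kv, pos=0)

    q = torch.randn(bsz, heads, 1, head_dim, dtype=torch.float16)
    out = attention_with_fp8_cache(q, k_cache, v_cache, kv_len, torch.float16)
    assert out.dtype == torch.float16  # output matches the activation dtype

The design trade-off is the usual one for quantized KV caches: fp8 storage
halves cache memory, while scores and softmax are accumulated in a wider dtype
and cast back only at the boundary that callers see.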