LLM: fix baichuan7b quantize kv abnormal output. (#10504)

* fix abnormal output.

* fix style.

* fix style.
Cengguang Zhang 2024-03-22 10:00:08 +08:00 committed by GitHub
parent f0f317b6cf
commit b9d4280892


@@ -128,7 +128,8 @@ def baichuan_attention_forward_7b_quantized(
             bsz, self.num_heads, kv_seq_len, self.head_dim,
             device=device, new_layout=True
         )
-        key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
+        key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
+                                                       value_states, new_layout=True)
         past_key_value = (key_states, value_states)
     else:
         k_cache, v_cache = past_key_value
@@ -185,7 +186,7 @@ def baichuan_attention_forward_7b_quantized(
 
     if not output_attentions:
         attn_weights = None
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def baichuan_attention_forward_7b_origin(
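
Note on the fix: the first hunk swaps the plain append_kv_cache call for append_fp8_kv_cache with new_layout=True, so the quantized KV path appends key/value states with the FP8-aware helper, and the second hunk casts the returned attention output back to hidden_states.dtype. Below is a minimal, self-contained PyTorch sketch of the dtype mismatch that the cast guards against; the bf16/fp16 combination is an assumed example for illustration, not the code in this commit.

import torch

# Illustrative sketch only (assumed dtypes, not this repository's implementation):
# if the FP8 KV cache is dequantized to fp16 for the attention matmul while the
# model otherwise runs in bf16, attn_output comes back in a different dtype than
# hidden_states.
hidden_states = torch.randn(1, 4, 8, dtype=torch.bfloat16)  # model's working dtype
attn_output = torch.randn(1, 4, 8, dtype=torch.float16)     # low-precision matmul result

# Without the cast, PyTorch promotes the mixed-dtype residual add to float32,
# so downstream layers see an unexpected activation dtype.
mixed = hidden_states + attn_output
print(mixed.dtype)   # torch.float32

# With the cast (what attn_output.to(hidden_states.dtype) achieves at the
# return), the output stays in the model's dtype.
fixed = hidden_states + attn_output.to(hidden_states.dtype)
print(fixed.dtype)   # torch.bfloat16

Casting at the return keeps the attention block's output dtype consistent with the rest of the decoder layer regardless of which precision the KV-cache matmul ran in.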