LLM: fix baichuan7b quantize kv abnormal output. (#10504)
* fix abnormal output.
* fix style.
* fix style.
parent f0f317b6cf
commit b9d4280892

1 changed file with 3 additions and 2 deletions
@@ -128,7 +128,8 @@ def baichuan_attention_forward_7b_quantized(
                 bsz, self.num_heads, kv_seq_len, self.head_dim,
                 device=device, new_layout=True
             )
-            key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
+            key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
+                                                           value_states, new_layout=True)
             past_key_value = (key_states, value_states)
         else:
             k_cache, v_cache = past_key_value
@@ -185,7 +186,7 @@ def baichuan_attention_forward_7b_quantized(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def baichuan_attention_forward_7b_origin(
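In short: the first hunk switches the quantized path from append_kv_cache to append_fp8_kv_cache(..., new_layout=True), so new key/value states are written into the fp8 cache in its expected format rather than as plain tensors; the second hunk casts the attention output back to hidden_states.dtype before returning. Below is a minimal, self-contained sketch of why the output cast matters; the tensors and dtypes are illustrative only, not taken from the repository code.

# Standalone sketch (hypothetical tensors, not the repository source): attention
# computed through a quantized-KV path can come back in a different dtype than
# the layer's hidden states, which downstream layers do not expect.
import torch

hidden_states = torch.randn(1, 4, 8, dtype=torch.bfloat16)  # model working dtype
attn_output = torch.randn(1, 4, 8, dtype=torch.float16)     # dtype from the quantized-KV path

# Before the fix: returning attn_output directly leaks the mismatched dtype.
# After the fix: cast back to the hidden states' dtype before returning.
attn_output = attn_output.to(hidden_states.dtype)
assert attn_output.dtype == hidden_states.dtype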