LLM: fix baichuan7b quantize kv abnormal output. (#10504)
* fix abnormal output. * fix style. * fix style.
This commit is contained in:
parent
f0f317b6cf
commit
b9d4280892
1 changed file with 3 additions and 2 deletions
|
|
@@ -128,7 +128,8 @@ def baichuan_attention_forward_7b_quantized(
|
||||||
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
bsz, self.num_heads, kv_seq_len, self.head_dim,
|
||||||
device=device, new_layout=True
|
device=device, new_layout=True
|
||||||
)
|
)
|
||||||
key_states, value_states = append_kv_cache(k_cache, v_cache, key_states, value_states)
|
key_states, value_states = append_fp8_kv_cache(k_cache, v_cache, key_states,
|
||||||
|
value_states, new_layout=True)
|
||||||
past_key_value = (key_states, value_states)
|
past_key_value = (key_states, value_states)
|
||||||
else:
|
else:
|
||||||
k_cache, v_cache = past_key_value
|
k_cache, v_cache = past_key_value
|
||||||
|
|
@@ -185,7 +186,7 @@ def baichuan_attention_forward_7b_quantized(
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
def baichuan_attention_forward_7b_origin(
|
def baichuan_attention_forward_7b_origin(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue