LLM : Fix llama draft_model dtype error (#10005)
* fix llama draft_model dtype error
* update
parent aae1870096
commit 98ea3459e5
1 changed file with 1 addition and 1 deletion
@@ -729,7 +729,7 @@ def llama_attention_forward_4_36(
 
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
-    attn_weights = torch.matmul(query,
+    attn_weights = torch.matmul(query.to(key.dtype),
                                 key.transpose(2, 3)) / math.sqrt(head_dim)
 
     attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
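
For context, a minimal standalone sketch of the failure mode this one-line change addresses. The shapes and dtypes below are illustrative assumptions, not taken from the repo: in speculative decoding, the draft model's query states can arrive in a different dtype than the target model's KV cache, and torch.matmul requires both operands to share a dtype, so casting query to key.dtype before the matmul avoids the runtime error.

import math
import torch

# Illustrative shapes/dtypes (assumptions): a float32 query, e.g. produced
# by a draft model, against a float16 key cache from the target model.
bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 32, 4, 128, 64
query = torch.randn(bsz, num_heads, q_len, head_dim, dtype=torch.float32)
key = torch.randn(bsz, num_heads, kv_seq_len, head_dim, dtype=torch.float16)

# Before the fix: a mixed-dtype matmul raises a RuntimeError.
try:
    attn_weights = torch.matmul(query,
                                key.transpose(2, 3)) / math.sqrt(head_dim)
except RuntimeError as e:
    print("dtype mismatch:", e)

# After the fix: align the query with the key's dtype first.
attn_weights = torch.matmul(query.to(key.dtype),
                            key.transpose(2, 3)) / math.sqrt(head_dim)
print(attn_weights.dtype)  # torch.float16

Casting the query to the key's dtype (rather than upcasting the cache) presumably keeps the attention matmul in the KV cache's typically lower precision, avoiding a conversion of the much larger cached key tensor.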