LLM : Fix llama draft_model dtype error (#10005)

* fix llama draft_model dtype error

* updat
This commit is contained in:
Wang, Jian4 2024-01-26 10:59:48 +08:00 committed by GitHub
parent aae1870096
commit 98ea3459e5

View file

@ -729,7 +729,7 @@ def llama_attention_forward_4_36(
def native_sdp(query, key, value, attention_mask,
bsz, q_len, kv_seq_len, head_dim, num_heads):
attn_weights = torch.matmul(query,
attn_weights = torch.matmul(query.to(key.dtype),
key.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)