LLM: fix qwen efficiency issue in perf-test.
This commit is contained in:
parent 8ed89557e5
commit 4d22add4af
1 changed file with 20 additions and 16 deletions
@@ -195,22 +195,26 @@ def qwen_attention_forward(
                               None,
                               Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
 
-        if not self.use_cache_quantization and SUPPORT_TORCH2:
-            if attention_mask is not None:
-                attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
-                if causal_mask is not None:
-                    attention_mask = attention_mask.masked_fill(~causal_mask,
-                                                                torch.finfo(query.dtype).min)
-            else:
-                attention_mask = causal_mask
-            attn_output = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask
-            ).transpose(1, 2)
-            attn_weight = None
-        else:
-            attn_output, attn_weight = self._attn(
-                query, key, value, causal_mask, attention_mask, head_mask
-            )
+        # Remove for efficiency issue on Arc, maybe add later.
+        # if not self.use_cache_quantization and SUPPORT_TORCH2:
+        #     if attention_mask is not None:
+        #         attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+        #         if causal_mask is not None:
+        #             attention_mask = attention_mask.masked_fill(~causal_mask,
+        #                                                         torch.finfo(query.dtype).min)
+        #     else:
+        #         attention_mask = causal_mask
+        #     attn_output = F.scaled_dot_product_attention(
+        #         query, key, value, attn_mask=attention_mask
+        #     ).transpose(1, 2)
+        #     attn_weight = None
+        # else:
+        #     attn_output, attn_weight = self._attn(
+        #         query, key, value, causal_mask, attention_mask, head_mask
+        #     )
+        attn_output, attn_weight = self._attn(
+            query, key, value, causal_mask, attention_mask, head_mask
+        )
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
         )
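In short, the change comments out the torch >= 2.0 F.scaled_dot_product_attention fast path and always routes through the eager self._attn path, because the SDPA branch caused an efficiency issue on Arc in the perf test. As a rough, self-contained sketch (not the repository's code), the snippet below checks that a generic eager attention computation, similar in spirit to what an eager _attn does, matches F.scaled_dot_product_attention for a simple causal float mask; the tensor shapes, the additive-mask convention, and the eager_attention helper are assumptions made only for this illustration.

import torch
import torch.nn.functional as F

def eager_attention(query, key, value, attn_mask=None):
    # query/key/value: (batch, heads, seq_len, head_dim); attn_mask is an
    # additive float mask broadcastable to (batch, heads, seq_len, seq_len).
    scale = query.size(-1) ** -0.5
    attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value), attn_weight

torch.manual_seed(0)
q = torch.randn(1, 2, 5, 16)
k = torch.randn(1, 2, 5, 16)
v = torch.randn(1, 2, 5, 16)
# Causal mask: 0 on visible positions, a very negative value on future positions,
# mirroring the torch.finfo(query.dtype).min convention used in the diff above.
mask = torch.triu(torch.full((5, 5), torch.finfo(q.dtype).min), diagonal=1)

sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
eager_out, _ = eager_attention(q, k, v, attn_mask=mask)
print(torch.allclose(sdpa_out, eager_out, atol=1e-5))  # expected: True

The trade-off is purely about kernel performance on the target hardware, not numerics: both branches compute the same masked attention, so falling back to the eager path only changes speed (and keeps attn_weight available), which is why the SDPA block is left commented rather than deleted.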
|