LLM: fix qwen efficiency issue in perf-test.

Cengguang Zhang 2023-12-18 18:32:54 +08:00 committed by GitHub
parent 8ed89557e5
commit 4d22add4af

@@ -195,22 +195,26 @@ def qwen_attention_forward(
                       None, None,
                       Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED))
-    if not self.use_cache_quantization and SUPPORT_TORCH2:
-        if attention_mask is not None:
-            attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
-            if causal_mask is not None:
-                attention_mask = attention_mask.masked_fill(~causal_mask,
-                                                            torch.finfo(query.dtype).min)
-        else:
-            attention_mask = causal_mask
-        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask
-        ).transpose(1, 2)
-        attn_weight = None
-    else:
-        attn_output, attn_weight = self._attn(
-            query, key, value, causal_mask, attention_mask, head_mask
-        )
+    # Remove for efficiency issue on Arc, maybe add later.
+    # if not self.use_cache_quantization and SUPPORT_TORCH2:
+    #     if attention_mask is not None:
+    #         attention_mask = attention_mask.expand(-1, -1, query.size(2), -1)
+    #         if causal_mask is not None:
+    #             attention_mask = attention_mask.masked_fill(~causal_mask,
+    #                                                         torch.finfo(query.dtype).min)
+    #     else:
+    #         attention_mask = causal_mask
+    #     attn_output = F.scaled_dot_product_attention(
+    #         query, key, value, attn_mask=attention_mask
+    #     ).transpose(1, 2)
+    #     attn_weight = None
+    # else:
+    #     attn_output, attn_weight = self._attn(
+    #         query, key, value, causal_mask, attention_mask, head_mask
+    #     )
+    attn_output, attn_weight = self._attn(
+        query, key, value, causal_mask, attention_mask, head_mask
+    )
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
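
For context, the hunk disables the fused `F.scaled_dot_product_attention` path and always takes the eager `self._attn` fallback, because the fused kernel was slow on Arc in perf-test. The two paths compute the same attention result, so the change trades only speed, not correctness. Below is a minimal sketch of that equivalence; it is not code from this repo: `manual_attn` is a hypothetical stand-in for an eager `_attn`-style implementation, and tensors are assumed to be in (batch, heads, seq, head_dim) layout.

import torch
import torch.nn.functional as F

def manual_attn(query, key, value, attn_mask=None):
    # Eager attention: softmax(Q @ K^T / sqrt(d) + mask) @ V,
    # analogous to what an _attn-style fallback computes.
    scale = query.size(-1) ** -0.5
    attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attn_mask is not None:
        attn_weight = attn_weight + attn_mask  # additive mask convention
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value), attn_weight

q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
out_manual, _ = manual_attn(q, k, v)
# Fused kernel, available in PyTorch >= 2.0 (the SUPPORT_TORCH2 case above).
out_sdpa = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out_manual, out_sdpa, atol=1e-5)

The fused kernel avoids materializing the full attention-weight matrix, which is usually faster; falling back to the manual path only makes sense when, as here, the fused kernel underperforms on a particular backend.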