[LLM] Fix transformer qwen size mismatch and rename causal_mask (#9655)

* Fix size mismatch caused by context_layer (a shape sketch follows after the diff)
* Change registered_causal_mask to causal_mask (see the sketch below)
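A minimal standalone sketch of the renamed mask construction shown in the diff below, with illustrative sizes (in the real code, key_size comes from the key tensor and the mask is built on key.device inside qwen_attention_forward):

import torch

# Illustrative lengths; in qwen_attention_forward these come from query/key.
key_size = 8
query_len = 8

if query_len == key_size:
    # Lower-triangular boolean mask, shaped (1, 1, key_size, key_size) so it
    # broadcasts over the batch and head dimensions.
    causal_mask = torch.tril(
        torch.ones((key_size, key_size), dtype=torch.bool)
    ).view(1, 1, key_size, key_size)
else:
    # Query and key lengths differ (e.g. decoding with a KV cache),
    # so no square causal mask is built and causal_mask stays None.
    causal_mask = None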
Ziteng Zhang 2023-12-12 20:57:40 +08:00 committed by GitHub
parent 2fe38b4b9b
commit 8931f2eb62

@@ -64,7 +64,6 @@ def qwen_attention_forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
registered_causal_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -171,20 +170,19 @@ def qwen_attention_forward(
q, k, v = query, key, value
context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
# b s h d -> b s (h d)
context_layer = context_layer.flatten(2, 3).contiguous()
else:
if query.size(1) == key_size:
registered_causal_mask = torch.tril(
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
).view(1, 1, key_size, key_size)
else:
causal_mask = None
query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization:
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
if (
registered_causal_mask is None
causal_mask is None
and self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
@@ -192,7 +190,7 @@ def qwen_attention_forward(
):
invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
attn_output, attn_weight = self._attn(
query, key, value, registered_causal_mask, attention_mask, head_mask
query, key, value, causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
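As a shape check for the size-mismatch part of the fix: per the "b s h d -> b s (h d)" comment in the diff, the flash-attention path produces context_layer laid out as (batch, seq, heads, head_dim), and flatten(2, 3) collapses the last two dimensions into one, presumably the layout the downstream projection expects. A minimal sketch with illustrative sizes:

import torch

# Illustrative sizes; the real values come from the model configuration.
b, s, h, d = 2, 16, 32, 128

context_layer = torch.randn(b, s, h, d)                    # b s h d
context_layer = context_layer.flatten(2, 3).contiguous()   # b s (h d)

assert context_layer.shape == (b, s, h * d)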