[LLM] Fix transformer qwen size mismatch and rename causal_mask (#9655)
* Fix size mismatch caused by context_layer * Change registered_causal_mask to causal_mask
This commit is contained in:
parent
2fe38b4b9b
commit
8931f2eb62
1 changed file with 5 additions and 7 deletions
|
|
@ -64,7 +64,6 @@ def qwen_attention_forward(
|
|||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
|
||||
registered_causal_mask: Optional[torch.Tensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
|
|
@ -171,20 +170,19 @@ def qwen_attention_forward(
|
|||
q, k, v = query, key, value
|
||||
context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
|
||||
|
||||
# b s h d -> b s (h d)
|
||||
context_layer = context_layer.flatten(2, 3).contiguous()
|
||||
|
||||
else:
|
||||
if query.size(1) == key_size:
|
||||
registered_causal_mask = torch.tril(
|
||||
causal_mask = torch.tril(
|
||||
torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
|
||||
).view(1, 1, key_size, key_size)
|
||||
else:
|
||||
causal_mask = None
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
if not self.use_cache_quantization:
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
if (
|
||||
registered_causal_mask is None
|
||||
causal_mask is None
|
||||
and self.use_flash_attn
|
||||
and flash_attn_unpadded_func is not None
|
||||
and not self.is_fp32
|
||||
|
|
@ -192,7 +190,7 @@ def qwen_attention_forward(
|
|||
):
|
||||
invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
|
||||
attn_output, attn_weight = self._attn(
|
||||
query, key, value, registered_causal_mask, attention_mask, head_mask
|
||||
query, key, value, causal_mask, attention_mask, head_mask
|
||||
)
|
||||
context_layer = self._merge_heads(
|
||||
attn_output, self.num_heads, self.head_dim
|
||||
|
|
|
|||
Loading…
Reference in a new issue