diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 81feb4aa..549d137d 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -64,7 +64,6 @@ def qwen_attention_forward(
     self,
     hidden_states: Optional[Tuple[torch.FloatTensor]],
     rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
-    registered_causal_mask: Optional[torch.Tensor] = None,
     layer_past: Optional[Tuple[torch.Tensor]] = None,
     attention_mask: Optional[torch.FloatTensor] = None,
     head_mask: Optional[torch.FloatTensor] = None,
@@ -171,20 +170,19 @@ def qwen_attention_forward(
         q, k, v = query, key, value
         context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)

-        # b s h d -> b s (h d)
-        context_layer = context_layer.flatten(2, 3).contiguous()
-
     else:
         if query.size(1) == key_size:
-            registered_causal_mask = torch.tril(
+            causal_mask = torch.tril(
                 torch.ones((key_size, key_size), dtype=torch.bool, device=key.device)
             ).view(1, 1, key_size, key_size)
+        else:
+            causal_mask = None
         query = query.permute(0, 2, 1, 3)
         if not self.use_cache_quantization:
             key = key.permute(0, 2, 1, 3)
             value = value.permute(0, 2, 1, 3)
         if (
-            registered_causal_mask is None
+            causal_mask is None
             and self.use_flash_attn
             and flash_attn_unpadded_func is not None
             and not self.is_fp32
@@ -192,7 +190,7 @@ def qwen_attention_forward(
         ):
             invalidInputError(False, _ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
         attn_output, attn_weight = self._attn(
-            query, key, value, registered_causal_mask, attention_mask, head_mask
+            query, key, value, causal_mask, attention_mask, head_mask
         )
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
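
For reference, below is a minimal standalone sketch (not part of the patch; the helper name build_causal_mask is hypothetical) of the mask construction the patch now performs inline instead of taking a registered_causal_mask argument: torch.tril yields a lower-triangular boolean matrix, reshaped to (1, 1, key_size, key_size) so it broadcasts over the batch and head dimensions, and the mask is left as None when the query length differs from key_size (e.g. single-token decoding against a KV cache).

import torch

def build_causal_mask(query_len: int, key_size: int, device=None):
    # Hypothetical helper mirroring the inline logic added by this patch:
    # build the lower-triangular mask only when the full sequence is
    # attended at once; otherwise leave it as None.
    if query_len == key_size:
        return torch.tril(
            torch.ones((key_size, key_size), dtype=torch.bool, device=device)
        ).view(1, 1, key_size, key_size)
    return None

# Example: a 4-token prompt yields a (1, 1, 4, 4) lower-triangular bool mask;
# a single decode step against 4 cached keys yields None.
print(build_causal_mask(4, 4).shape)   # torch.Size([1, 1, 4, 4])
print(build_causal_mask(1, 4))         # None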