[LLM] Fix Qwen registered_causal_mask is None (#9513)
* Add registered_causal_mask initialization, based on upstream commit 2abd8e5777.
This commit is contained in:
parent
11fa5a8a0e
commit
0f0c6bb631
1 changed files with 3 additions and 0 deletions
|
|
@@ -174,6 +174,9 @@ def qwen_attention_forward(
|
||||||
context_layer = context_layer.flatten(2, 3).contiguous()
|
context_layer = context_layer.flatten(2, 3).contiguous()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
registered_causal_mask = torch.tril(
|
||||||
|
torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device)
|
||||||
|
).view(1, 1, key.size(1), key.size(1))
|
||||||
query = query.permute(0, 2, 1, 3)
|
query = query.permute(0, 2, 1, 3)
|
||||||
if not self.use_cache_quantization:
|
if not self.use_cache_quantization:
|
||||||
key = key.permute(0, 2, 1, 3)
|
key = key.permute(0, 2, 1, 3)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue