add glm_sdpa back to fix chatglm-6b (#11313)

parent 7f65836cb9
commit 91965b5d05

1 changed file with 44 additions and 1 deletion
@@ -23,7 +23,7 @@ import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.chatglm2 import glm_sdpa
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 
 
 def rotate_half(x):
@@ -39,6 +39,49 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k
 
+
+def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
+    if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
+        context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
+                                                       key,
+                                                       value,
+                                                       attention_mask,
+                                                       is_causal=is_causal).to(key.dtype)
+    else:
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
+                                    device=query.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        elif is_causal:
+            L, S = query.size(-2), key.size(-2)
+            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(key.dtype)
+        else:
+            attn_bias = None
+        if use_sdp(query.shape[2], key.shape[2],
+                   query.shape[-1], query):
+            import xe_addons
+            attn_output = xe_addons.sdp(query, key, value, attn_bias)
+            context_layer = attn_output.view(query.shape)
+        else:
+            head_dim = query.size(-1)
+            attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim),
+                                key.transpose(2, 3))
+            if attn_bias is not None:
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value.dtype)
+            context_layer = torch.matmul(attn, value)
+    return context_layer
+
+
 import os
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
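For context on what the re-added fallback computes, here is a small self-contained sketch that mirrors the pure-PyTorch path of glm_sdpa: the causal-bias construction from the `elif is_causal` branch and the scaled matmul/softmax of the final `else` branch, checked against F.scaled_dot_product_attention. The tensor shapes are illustrative assumptions, not values from chatglm-6b, and this sketch is not part of the commit.

import math
import torch
import torch.nn.functional as F

# Illustrative shapes only (batch, heads, seq_len, head_dim are assumptions).
batch, heads, seq_len, head_dim = 1, 2, 4, 8
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# Causal bias built the same way as the `elif is_causal` branch above.
L, S = query.size(-2), key.size(-2)
attn_bias = torch.zeros(L, S, dtype=query.dtype)
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

# Manual fallback: scaled QK^T, add bias, softmax in fp32, then weight V.
attn = torch.matmul(query / math.sqrt(head_dim), key.transpose(2, 3))
attn = attn + attn_bias
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)
manual = torch.matmul(attn, value)

# Should match PyTorch's built-in SDPA with causal masking.
reference = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(torch.allclose(manual, reference, atol=1e-5))  # True

On XPU the committed code instead routes through xe_addons.sdp when use_sdp allows it; the sketch above only covers the pure-PyTorch fallback.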