add glm_sdpa back to fix chatglm-6b (#11313)

This commit is contained in:
Yishuo Wang 2024-06-14 10:31:43 +08:00 committed by GitHub
parent 7f65836cb9
commit 91965b5d05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -23,7 +23,7 @@ import torch.utils.checkpoint
import torch.nn.functional as F
from typing import Optional, Tuple
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.chatglm2 import glm_sdpa
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
def rotate_half(x):
@ -39,6 +39,49 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
return q, k
def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
key,
value,
attention_mask,
is_causal=is_causal).to(key.dtype)
else:
# attention_mask is not None only when past_key_value is not None and q_len > 1
if attention_mask is not None:
attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
device=query.device)
attention_mask = ~attention_mask
if attention_mask.dtype == torch.bool:
attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
else:
attn_bias += attention_mask
elif is_causal:
L, S = query.size(-2), key.size(-2)
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(key.dtype)
else:
attn_bias = None
if use_sdp(query.shape[2], key.shape[2],
query.shape[-1], query):
import xe_addons
attn_output = xe_addons.sdp(query, key, value, attn_bias)
context_layer = attn_output.view(query.shape)
else:
head_dim = query.size(-1)
attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim),
key.transpose(2, 3))
if attn_bias is not None:
attn += attn_bias
attn = F.softmax(attn, dim=-1,
dtype=torch.float32).to(value.dtype)
context_layer = torch.matmul(attn, value)
return context_layer
import os
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))