diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm.py b/python/llm/src/ipex_llm/transformers/models/chatglm.py
index 51a2026a..77e2ae44 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm.py
@@ -23,7 +23,7 @@ import torch.utils.checkpoint
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.chatglm2 import glm_sdpa
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 
 
 def rotate_half(x):
@@ -39,6 +39,49 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k
 
+
+def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
+    if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
+        context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
+                                                       key,
+                                                       value,
+                                                       attention_mask,
+                                                       is_causal=is_causal).to(key.dtype)
+    else:
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
+                                    device=query.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        elif is_causal:
+            L, S = query.size(-2), key.size(-2)
+            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias = attn_bias.to(key.dtype)
+        else:
+            attn_bias = None
+        if use_sdp(query.shape[2], key.shape[2],
+                   query.shape[-1], query):
+            import xe_addons
+            attn_output = xe_addons.sdp(query, key, value, attn_bias)
+            context_layer = attn_output.view(query.shape)
+        else:
+            head_dim = query.size(-1)
+            attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim),
+                                key.transpose(2, 3))
+            if attn_bias is not None:
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value.dtype)
+            context_layer = torch.matmul(attn, value)
+    return context_layer
+
+
 import os
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
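
Reviewer note, not part of the patch: the final else-branch of glm_sdpa is plain scaled dot-product attention written out as scale + matmul + softmax + matmul. A minimal standalone sketch, assuming float32 tensors in the (batch, heads, seq_len, head_dim) layout glm_sdpa expects and no attention mask, showing that this fallback agrees with F.scaled_dot_product_attention:

    import math
    import torch
    import torch.nn.functional as F

    # Hypothetical shapes for illustration: (batch, heads, seq_len, head_dim).
    query = torch.randn(1, 2, 5, 16)
    key = torch.randn(1, 2, 7, 16)
    value = torch.randn(1, 2, 7, 16)

    # Manual path mirroring the fallback branch: scale Q, Q @ K^T, softmax, @ V.
    head_dim = query.size(-1)
    attn = torch.matmul(query / math.sqrt(head_dim), key.transpose(2, 3))
    attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)
    manual = torch.matmul(attn, value)

    # Fused path taken on CPU or when use_flash_attention returns True.
    reference = F.scaled_dot_product_attention(query, key, value)
    assert torch.allclose(manual, reference, atol=1e-5)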