refactor chatglm4 (#11301)

* glm4

* remove useless code

* stype

* add rope_ratio

* update

* fix fp16

* fix style
This commit is contained in:
Xin Qiu 2024-06-13 18:06:04 +08:00 committed by GitHub
parent 5e25766855
commit f1410d6823
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -20,13 +20,12 @@
import torch
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
import torch.nn.functional as F
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, apply_ipex_rotate_every_two
from ipex_llm.transformers.models.utils import use_sdp
from ipex_llm.transformers.models.chatglm2 import should_split_qkv_tensor, glm_sdpa
from ipex_llm.transformers.models.chatglm2 import split_tensor_along_last_dim
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.chatglm2 import repeat_kv
from transformers.modeling_outputs import BaseModelOutputWithPast
import math
import os
@ -97,6 +96,18 @@ def chatglm4_model_forward_internal(
past_key_values,
padding_mask=attention_mask)
# ipex-llm changes begin
# 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`
# 2. generate `causal_mask` and replace `full_attention_mask` with it
if position_ids is None:
if past_key_values is None:
position_ids = torch.arange(seq_length, dtype=torch.int64, device=inputs_embeds.device)
else:
kv_length = past_key_values[0][0].size(2)
position_ids = torch.arange(kv_length, kv_length + seq_length,
dtype=torch.int64, device=inputs_embeds.device)
position_ids = position_ids.repeat(batch_size, 1)
use_fuse_rope = input_ids.device.type == "xpu"
use_fuse_rope = use_fuse_rope and not self.training
@ -117,11 +128,31 @@ def chatglm4_model_forward_internal(
sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
rotary_pos_emb = (cos, sin)
# Run encoder.
# `full_attention_mask` is not None only when
# `past_key_values` is not None and `seq_length` > 1
if full_attention_mask is not None:
causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
mask_value = torch.finfo(inputs_embeds.dtype).min
causal_mask.masked_fill_(full_attention_mask, mask_value)
elif self.training or (inputs_embeds.device.type != "xpu" and past_key_values is None):
full_attention_mask = self.get_masks(input_ids,
past_key_values,
padding_mask=attention_mask)
causal_mask = torch.zeros([batch_size, 1, seq_length, full_attention_mask.size(-1)],
dtype=inputs_embeds.dtype, device=inputs_embeds.device)
mask_value = torch.finfo(inputs_embeds.dtype).min
causal_mask.masked_fill_(full_attention_mask, mask_value)
else:
causal_mask = None
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
inputs_embeds, causal_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
)
# ipex-llm changes end
if presents is not None and type(presents) is torch.Tensor:
presents = presents.split(1, dim=0)
presents = list(presents)
@ -141,7 +172,6 @@ def chatglm4_model_forward_internal(
)
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [b, np, sq, hn]
b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
@ -165,167 +195,110 @@ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Ten
def chatglm4_attention_forward(
self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
):
# hidden_states: [sq, b, h]
# hidden_states: [b, sq, h]
bsz, q_len, _ = hidden_states.size()
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
# =====================
# Query, Key, and Value
# =====================
# past_key_value: [bsz, n_kv_head, seq_len, head_dim]
past_key_value = None if kv_cache is None else (kv_cache[0],
kv_cache[1])
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
device = hidden_states.device
mixed_x_layer = self.query_key_value(hidden_states)
n_head = self.num_attention_heads_per_partition
n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
head_dim = self.hidden_size_per_attention_head
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
],
dim=-1,
)
query_layer = query_layer.view(
query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
)
key_layer = key_layer.view(
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
self.hidden_size_per_attention_head)
)
value_layer = value_layer.view(
value_layer.size()[:-1]
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
)
else:
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
qkv = self.query_key_value(hidden_states)
# [bs, q_len, np * 3 * hn] -> [bsz, n_head, seq_len, head_dim]
qkv = qkv.view(bsz, q_len, n_head + 2 * n_kv_head, head_dim)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
query_states, key_states, value_states = qkv.split([n_head,
n_kv_head,
n_kv_head], dim=2)
# [b, sq, np, hn] -> [b, np, sq, hn]
query_layer, key_layer, value_layer = [k.transpose(1, 2)
for k in [query_layer, key_layer, value_layer]]
kv_seq_len = key_states.shape[1]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[2]
# apply relative positional encoding (rotary embedding)
if isinstance(rotary_pos_emb, tuple) and len(rotary_pos_emb) == 2:
# use_fuse_rope, see chatglm4_model_forward
cos, sin = rotary_pos_emb
rot_dim = cos.shape[-1]
query_layer = query_layer.transpose(1, 2)
key_layer = key_layer.transpose(1, 2)
query_layer_cur = query_layer[..., :rot_dim]
key_layer_cur = key_layer[..., :rot_dim]
query_layer_cur = query_states[..., :rot_dim]
key_layer_cur = key_states[..., :rot_dim]
# ipex_llm's apply_rotary_embedding can change the origin storage,
# so query_layer will get the result directly.
torch.ops.torch_ipex.apply_rotary_embedding(query_layer_cur, sin, cos, query_layer_cur)
torch.ops.torch_ipex.apply_rotary_embedding(key_layer_cur, sin, cos, key_layer_cur)
query_layer = query_layer.transpose(1, 2)
key_layer = key_layer.transpose(1, 2)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
elif rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = apply_rotary_pos_emb(query_states, rotary_pos_emb)
key_states = apply_rotary_pos_emb(key_states, rotary_pos_emb)
cur_length, batch_size = query_layer.shape[2], query_layer.shape[0]
# adjust key and value for inference
if kv_cache is not None and use_cache:
cache_k, cache_v = kv_cache
past_length = cache_k.size(2)
if cache_k.stride()[1] < (past_length + cur_length) * cache_k.size(3):
max_cache_length = past_length + cur_length + KV_CACHE_ALLOC_BLOCK_LENGTH
new_cache_k, new_cache_v = extend_kv_cache(batch_size,
key_layer.size(1),
self.hidden_size_per_attention_head,
past_length,
max_cache_length,
dtype=query_layer.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
cache_k = new_cache_k
cache_v = new_cache_v
key_layer, value_layer = append_kv_cache(cache_k, cache_v, key_layer, value_layer)
# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, hidden_states)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
)
if use_cache:
if kv_cache is None:
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
if past_key_value is None:
past_key_value = torch.cat((key_states.unsqueeze(0).unsqueeze(0),
value_states.unsqueeze(0).unsqueeze(0)), dim=1)
else:
kv_cache = (key_layer, value_layer)
past_key_value = (key_states, value_states)
else:
kv_cache = None
past_key_value = None
if self.multi_query_attention:
key_layer = key_layer.unsqueeze(2)
key_layer = key_layer.expand(
-1, -1,
self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
-1, -1
)
key_layer = key_layer.contiguous().view(
key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
)
value_layer = value_layer.unsqueeze(2)
value_layer = value_layer.expand(
-1, -1,
self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
-1, -1
)
value_layer = value_layer.contiguous().view(
value_layer.size()[:1] +
(self.num_attention_heads_per_partition,) + value_layer.size()[3:]
)
# ==================================
# core attention computation
# ==================================
context_layer = core_attn_forward(query_layer, key_layer, value_layer, attention_mask)
# =================
# Output. [sq, b, h]
# =================
output = self.dense(context_layer)
return output, kv_cache
def core_attn_forward(query_layer, key_layer, value_layer, attention_mask):
L, S = query_layer.shape[2], key_layer.shape[2]
batch_size, n_head, seq_len, head_dim = query_layer.shape
if attention_mask is None and L == S:
if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
# split second dim to block size = 8
block_size = 8
query_layer = query_layer.to(key_layer.dtype)
query_split = torch.split(query_layer, block_size, dim=1)
key_split = torch.split(key_layer, block_size, dim=1)
value_split = torch.split(value_layer, block_size, dim=1)
results = []
for q, k, v in zip(query_split, key_split, value_split):
result = glm_sdpa(q, k, v, is_causal=True)
results.append(result)
context_layer = torch.cat(results, dim=1)
# IPEX-LLM OPT: sdp
attn_weights = None
if use_sdp(q_len, kv_seq_len, head_dim, query_states):
import xe_addons
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
else:
context_layer = glm_sdpa(query_layer,
key_layer,
value_layer,
is_causal=True)
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
import xe_addons
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states, value_states,
attention_mask)
elif query_states.device.type == "cpu":
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, n_head // n_kv_head)
value_states = repeat_kv(value_states, n_head // n_kv_head)
if q_len == kv_seq_len:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, is_causal=True
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask
)
else:
context_layer = glm_sdpa(query_layer,
key_layer,
value_layer,
attention_mask)
context_layer = context_layer.transpose(1, 2).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.reshape(*new_context_layer_shape)
if use_quantize_kv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, n_head // n_kv_head)
value_states = repeat_kv(value_states, n_head // n_kv_head)
attn_weights = torch.matmul(query_states / math.sqrt(head_dim),
key_states.transpose(2, 3)).to(value_states.dtype)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(value_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
return context_layer
# context_layer's shape: [bsz, n_head, seq_len, head_dim] -> [seq_len, bsz, n_head * head_dim]
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, n_head * head_dim)
output = self.dense(attn_output)
return output, past_key_value