Fix llama attention optimization for XPU (#8855)

* Fix llama attention optimization for XPU
* Fix chatglm2
* Fix typo

parent 7b566bf686
commit 3b4f4e1c3d
2 changed files with 21 additions and 11 deletions
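
The common thread across both changed files is that every torch.empty call which pre-allocates KV-cache or attention scratch memory now receives the device (and, for the llama first-token path, the dtype) of the incoming activations, so on XPU these buffers are no longer created on the default CPU device. A minimal sketch of the pattern; the helper name and signature below are illustrative, not part of the patched code:

import torch

def alloc_kv_cache(batch_size, num_heads, max_length, head_dim, hidden_states):
    # Illustrative helper (not from the patch): allocate the key/value cache
    # with the same dtype and device as the activations, so XPU tensors are
    # never mixed with CPU-allocated cache buffers.
    device = hidden_states.device
    key_cache = torch.empty(batch_size, num_heads, max_length, head_dim,
                            dtype=hidden_states.dtype, device=device)
    value_cache = torch.empty(batch_size, num_heads, max_length, head_dim,
                              dtype=hidden_states.dtype, device=device)
    return key_cache, value_cache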
@@ -86,6 +86,7 @@ def chatglm2_attention_forward_8eb45c(
     # =====================

     # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+    device = hidden_states.device
     mixed_x_layer = self.query_key_value(hidden_states)

     if self.multi_query_attention:
@@ -151,11 +152,13 @@ def chatglm2_attention_forward_8eb45c(
                 self.kv_cache = (torch.empty(batch_size,
                                              self.num_attention_heads_per_partition,
                                              self.max_cache_length,
-                                             self.hidden_size_per_attention_head,),
+                                             self.hidden_size_per_attention_head,
+                                             device=device),
                                  torch.empty(batch_size,
                                              self.num_attention_heads_per_partition,
                                              self.max_cache_length,
-                                             self.hidden_size_per_attention_head,))
+                                             self.hidden_size_per_attention_head,
+                                             device=device))
                 self.kv_cache[0][:, :, :past_length, :] = cache_k
                 self.kv_cache[1][:, :, :past_length, :] = cache_v
                 self.kv_cache[0][:, :, past_length:past_length + cur_length, :] = key_layer
@@ -168,9 +171,11 @@ def chatglm2_attention_forward_8eb45c(
             self.max_cache_length = max(KV_CACHE_ALLOC_MIN_LENGTH, cur_length) \
                 + KV_CACHE_ALLOC_BLOCK_LENGTH
             self.kv_cache = (torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                         self.max_cache_length, self.hidden_size_per_attention_head,),
+                                         self.max_cache_length, self.hidden_size_per_attention_head,
+                                         device=device),
                              torch.empty(batch_size, self.num_attention_heads_per_partition,
-                                         self.max_cache_length, self.hidden_size_per_attention_head,))
+                                         self.max_cache_length, self.hidden_size_per_attention_head,
+                                         device=device))
             self.kv_cache[0][:, :, :cur_length, :] = key_layer
             self.kv_cache[1][:, :, :cur_length, :] = value_layer

@@ -196,7 +201,7 @@ def chatglm2_attention_forward_8eb45c(

 def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
     pytorch_major_version = int(torch.__version__.split('.')[0])
-    if query_layer.size(0) > 1 and pytorch_major_version >= 2:
+    if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:

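The condition change in core_attn_forward_8eb45c is what actually turns the optimization on for XPU: previously the PyTorch 2 scaled_dot_product_attention path was taken only when the query held more than one token, so single-token decoding fell back to the manual matmul path even on XPU. Roughly, the dispatch now behaves like the sketch below (the helper name is illustrative):

import torch

def use_sdp_path(query_layer):
    # Take torch.nn.functional.scaled_dot_product_attention on PyTorch >= 2
    # when running on XPU, or when the query carries more than one token.
    # query_layer is laid out as [sq, b, np, hn] here, so size(0) is the
    # query sequence length.
    pytorch_major_version = int(torch.__version__.split('.')[0])
    return pytorch_major_version >= 2 and (
        query_layer.device.type == 'xpu' or query_layer.size(0) > 1)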
@@ -222,7 +227,6 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                                                                              is_causal=True)
         else:
             if attention_mask is not None:
-                attention_mask = ~attention_mask
                 attention_mask = attention_mask.masked_fill(~attention_mask, -float('inf'), )
             if torch.is_autocast_cpu_enabled():
                 query_layer = query_layer.to(torch.get_autocast_cpu_dtype())
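The "fix chatglm2" part of the commit message is the deleted line above: the mask was inverted once with ~attention_mask and then inverted again inside masked_fill(~attention_mask, ...), so the fill operated on an already-flipped mask; dropping the first inversion lets masked_fill see the mask in its original convention. As a generic illustration of the intended conversion (the example tensor and the True-means-attend convention are made up for the sketch, not taken from the patched module):

import torch

# Turn a boolean mask where True marks positions that may be attended
# into an additive float mask: blocked positions become -inf, allowed
# positions stay 0.
allowed = torch.tensor([[True, True, False, False]])      # made-up example
additive = torch.zeros(allowed.shape, dtype=torch.float32)
additive = additive.masked_fill(~allowed, float('-inf'))
print(additive)  # tensor([[0., 0., -inf, -inf]])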
@@ -258,6 +262,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
         matmul_result = torch.empty(
             output_size[0] * output_size[1],
             output_size[2], output_size[3], dtype=query_layer.dtype,
+            device=query_layer.device
         )

         # Raw attention scores. [b * np, sq, sk]
@@ -313,6 +318,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
         context_layer = torch.empty(
             output_size[0] * output_size[1],
             output_size[2], value_layer.size(-1), dtype=value_layer.dtype,
+            device=value_layer.device,
         )
         torch.bmm(attention_probs, value_layer, out=context_layer)
         # change view [b, np, sq, hn]

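The last two hunks above apply the same device handling to the fallback (pre-PyTorch-2) attention path: the scratch buffers that the batched matmuls later write into through out= (see torch.bmm(attention_probs, value_layer, out=context_layer) in the context) are now allocated on the query/value device instead of the default device. A small sketch of that pattern, with made-up shapes:

import torch

# Pre-allocate the matmul output with the inputs' dtype and device, then
# let torch.bmm write into it via out= instead of allocating a new tensor.
attention_probs = torch.rand(4, 7, 10)   # [b * np, sq, sk], example shapes
value_layer = torch.rand(4, 10, 16)      # [b * np, sk, hn]
context_layer = torch.empty(4, 7, 16,
                            dtype=value_layer.dtype,
                            device=value_layer.device)
torch.bmm(attention_probs, value_layer, out=context_layer)   # [b * np, sq, hn]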
@@ -129,11 +129,13 @@ def llama_attention_forward_4_31(
         # value_states = torch.cat([past_key_value[1], value_states], dim=2)
         if kv_seq_len > self.max_cache_length:
             new_cache_key = torch.empty(bsz, self.num_heads,
-                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
+                                        kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
+                                        device=device)
             new_cache_key[:, :, :kv_seq_len-1, :] = self.kv_cache[0][:, :, :kv_seq_len-1, :]

             new_cache_value = torch.empty(bsz, self.num_heads,
-                                          kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim)
+                                          kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH, self.head_dim,
+                                          device=device)
             new_cache_value[:, :, :kv_seq_len-1, :] = self.kv_cache[1][:, :, :kv_seq_len-1, :]
             self.kv_cache = (new_cache_key, new_cache_value)
             self.max_cache_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
@@ -146,8 +148,10 @@ def llama_attention_forward_4_31(
         # first token case
         self.max_cache_length = max(min(self.max_position_embeddings, 2 * kv_seq_len),
                                     kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH)
-        self.kv_cache = (torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim),
-                         torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim))
+        self.kv_cache = (torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
+                                     dtype=key_states.dtype, device=device),
+                         torch.empty(bsz, self.num_heads, self.max_cache_length, self.head_dim,
+                                     dtype=key_states.dtype, device=device))
         self.kv_cache[0][:, :, :kv_seq_len, :] = key_states
         self.kv_cache[1][:, :, :kv_seq_len, :] = value_states

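For llama_attention_forward_4_31 the same treatment covers both the first-token allocation and cache growth: when kv_seq_len outgrows max_cache_length, the replacement buffers are now created on the cache's own device before the old entries are copied over. A simplified sketch of that resize step (the helper, its arguments, and the block-length constant are illustrative, not the patched forward itself):

import torch

KV_CACHE_ALLOC_BLOCK_LENGTH = 256  # assumed block size for the sketch

def grow_kv_cache(kv_cache, kv_seq_len):
    # Illustrative resize: allocate larger buffers on the same device/dtype
    # as the existing cache, copy the old entries, and return the new pair.
    old_key, old_value = kv_cache
    bsz, num_heads, _, head_dim = old_key.shape
    new_length = kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH
    new_key = torch.empty(bsz, num_heads, new_length, head_dim,
                          dtype=old_key.dtype, device=old_key.device)
    new_value = torch.empty(bsz, num_heads, new_length, head_dim,
                            dtype=old_value.dtype, device=old_value.device)
    new_key[:, :, :kv_seq_len - 1, :] = old_key[:, :, :kv_seq_len - 1, :]
    new_value[:, :, :kv_seq_len - 1, :] = old_value[:, :, :kv_seq_len - 1, :]
    return (new_key, new_value), new_length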