LLM: add long-context support for Qwen1.5-7B/Baichuan2-7B/Mistral-7B. (#10937)

* LLM: add split tensor support for baichuan2-7b and qwen1.5-7b.

* fix style.

* fix style.

* fix style.

* add support for mistral and fix condition threshold.

* fix style.

* fix comments.
Cengguang Zhang 2024-05-10 16:40:15 +08:00 committed by GitHub
parent f9615f12d1
commit cfed76b2ed
3 changed files with 376 additions and 101 deletions
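In essence, the patch routes long fp16 prompts through a block-wise scaled-dot-product attention so that the full (bsz, num_heads, q_len, kv_seq_len) attention-weight tensor is never materialized at once. The snippet below is a minimal standalone sketch of that idea rather than the patched module code; the block size of 8 heads mirrors the value hard-coded in the diff, and the shapes in the check are made up. Because softmax is applied independently per head, the blocked result matches plain attention.

import math
import torch

def sdp_split_heads(query, key, value, attention_mask=None, block_size=8):
    # query/key/value: (bsz, num_heads, seq_len, head_dim); process block_size heads at a time
    scaling = 1 / math.sqrt(query.size(-1))
    outputs = []
    for q, k, v in zip(query.split(block_size, dim=1),
                       key.transpose(-2, -1).split(block_size, dim=1),
                       value.split(block_size, dim=1)):
        weights = torch.matmul(q * scaling, k)   # (bsz, block_size, q_len, kv_seq_len)
        if attention_mask is not None:
            weights = weights + attention_mask
        weights = torch.softmax(weights, dim=-1)
        outputs.append(torch.matmul(weights, v))
    return torch.cat(outputs, dim=1)

q, k, v = (torch.randn(1, 32, 64, 128) for _ in range(3))
full = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(128), dim=-1) @ v
assert torch.allclose(sdp_split_heads(q, k, v), full, atol=1e-5)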


@@ -49,6 +49,21 @@ import os
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
    if not output_attentions:
        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
        elif query_states.dtype == torch.float16 and \
                query_states.shape[2] >= 5400:
            # split tensor for memory block limitation
            # support fp16 and set input length threshold at 5400 for now
            return True
        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
            # attn_weight size larger than memory block limitation 4GB
            return True
    return False
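Note that the helper above (and its counterparts in the Mistral and Qwen2 files below) consults the IPEX_LLM_SPLIT_QKV environment variable before falling back to the fp16/sequence-length heuristics, so the split path can be forced on or off without code changes. A minimal usage sketch, assuming a model is then loaded through ipex-llm as usual (the model id is only an example):

import os

# "1" forces the split-QKV attention path, "0" disables it; leaving it unset
# keeps the dtype/length heuristics in charge. The patched attention modules
# read the variable at call time via os.environ.get.
os.environ["IPEX_LLM_SPLIT_QKV"] = "1"

# from ipex_llm.transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat", load_in_4bit=True)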
def baichuan_13b_rms_norm_forward(self, hidden_states):
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
        import linear_q4_0
@@ -159,6 +174,11 @@ def baichuan_attention_forward_7b_quantized(
    if query_states.size(2) != 1 or device.type != 'xpu':
        key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                         query_states.dtype)
        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                   q_len, kv_seq_len, output_attentions):
            attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key_states,
                                                                    value_states, attention_mask)
        else:
            attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
            if attention_mask is not None:
@@ -287,9 +307,16 @@ def baichuan_attention_forward_7b_origin(
    if attention_mask is not None:
        if attention_mask.dtype == torch.bool:
            attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
    if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                               q_len, kv_seq_len, output_attentions):
        attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
                                                                key_states,
                                                                value_states,
                                                                attention_mask)
    else:
        scaling_factor = 1 / math.sqrt(query_states.size(-1))
        attn_output = torch.matmul(query_states * scaling_factor,
                                   key_states.transpose(-2, -1))
        if attention_mask is not None:
            attn_output += attention_mask
        attn_output = torch.softmax(attn_output, -1)
@@ -622,3 +649,21 @@ def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past)
        : self.n_head, :seq_length_with_past, :seq_length_with_past
    ]
    return mask
def native_sdp_split_qkv_tensor(query, key, value, attention_mask):
    # compute attention block by block along the head dimension (dim=1) so the
    # full attention-weight tensor is never materialized at once
    block_size = 8
    query_split = torch.split(query, block_size, dim=1)
    key_split = torch.split(key.transpose(-2, -1), block_size, dim=1)
    value_split = torch.split(value, block_size, dim=1)
    attn_outputs = []
    scaling_factor = 1 / math.sqrt(query.size(-1))
    for q, k, v in zip(query_split, key_split, value_split):
        attn_output_split = torch.matmul(q * scaling_factor, k)
        if attention_mask is not None:
            attn_output_split += attention_mask
        attn_output_split = torch.softmax(attn_output_split, -1)
        attn_output_split = torch.matmul(attn_output_split, v)
        attn_outputs.append(attn_output_split)
    attn_output = torch.cat(attn_outputs, dim=1)
    return attn_output, None


@@ -89,6 +89,21 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
    return use_fuse_rope
def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
    if not output_attentions:
        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
        elif query_states.dtype == torch.float16 and \
                query_states.shape[2] >= 6300:
            # split tensor for memory block limitation
            # support fp16 and set input length threshold at 6300 for now
            return True
        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
            # attn_weight size larger than memory block limitation 4GB
            return True
    return False
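The final fallback estimates attn_weights as element_size * bsz * num_heads * q_len * kv_seq_len bytes and switches to the split path once that reaches the 4 GiB memory-block limit. A quick back-of-the-envelope check with hypothetical shapes (single batch, 32 heads, fp16) shows where that generic threshold lands:

# size in bytes of attn_weights = (bsz, num_heads, q_len, kv_seq_len)
element_size, bsz, num_heads = 2, 1, 32   # fp16, hypothetical 32-head 7B-class model
seq_len = 8192                            # q_len == kv_seq_len for a full prompt
attn_bytes = element_size * bsz * num_heads * seq_len * seq_len
print(attn_bytes / 1024**3)               # 4.0 -> right at the 4 GiB split threshold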
def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
                                 num_heads, head_dim, hidden_size, attention_mask):
    attn_weights = torch.matmul(
@@ -112,9 +127,14 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
        attn_weights = attn_weights + attention_mask
    if kv_seq_len >= 2048 or bsz >= 64:
        # for memory considerations, do not upcast attention to fp32
        # for long sequences or large batches
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    else:
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                             dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states.to(query_states.dtype))
    if attn_output.size() != (bsz, num_heads, q_len, head_dim):
@@ -130,6 +150,45 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
    return attn_output, attn_weights
def compute_attn_outputs_weights_split_tensor(query_states, key_states, value_states,
                                              bsz, q_len, kv_seq_len, num_heads, head_dim,
                                              hidden_size, attention_mask):
    block_size = 8
    query_split = torch.split(query_states.to(key_states.dtype), block_size, dim=1)
    key_split = torch.split(key_states.transpose(2, 3), block_size, dim=1)
    value_split = torch.split(value_states.to(query_states.dtype), block_size, dim=1)
    attn_outputs = []
    for q, k, v in zip(query_split, key_split, value_split):
        attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
        block_actual_size = attn_weights_split.size(1)
        attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
        if attn_weights_split.size() != attn_weights_split_size:
            invalidInputError(False,
                              f"Splitted attention weights should be of size "
                              f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
        if attention_mask is not None:
            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
            if attention_mask.size() != attn_mask_size:
                invalidInputError(False,
                                  f"Attention mask should be of size {attn_mask_size}, "
                                  f"but is {attention_mask.size()}")
            attn_weights_split = attn_weights_split + attention_mask
        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
        attn_outputs.append(torch.matmul(attn_weights_split, v))
    attn_output = torch.cat(attn_outputs, dim=1)
    if attn_output.size() != (bsz, num_heads, q_len, head_dim):
        invalidInputError(
            False,
            f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)},"
            f" but is {attn_output.size()}"
        )
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, hidden_size)
    return attn_output, None
def mistral_model_forward_4_36(
    self,
    input_ids: torch.LongTensor = None,
@@ -272,6 +331,34 @@ def mistral_attention_forward_quantized(
                                      dtype=attention_dtype)
    kv_seq_len = key_states.shape[-2]
    if past_key_value is None:
        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                   q_len, kv_seq_len, output_attentions):
            block_size = 8
            query_split = torch.split(query_states.to(key_states.dtype), block_size, dim=1)
            key_split = torch.split(key_states.transpose(2, 3), block_size, dim=1)
            value_split = torch.split(value_states.to(query_states.dtype), block_size, dim=1)
            attn_outputs = []
            for q, k, v in zip(query_split, key_split, value_split):
                attn_weights_split = torch.matmul(q, k) / math.sqrt(self.head_dim)
                block_actual_size = attn_weights_split.size(1)
                attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
                if attn_weights_split.size() != attn_weights_split_size:
                    invalidInputError(False,
                                      f"Splitted attention weights should be of size "
                                      f"{attn_weights_split_size}, "
                                      f"but is {attn_weights_split.size()}")
                if attention_mask is not None:
                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
                    if attention_mask.size() != attn_mask_size:
                        invalidInputError(False,
                                          f"Attention mask should be of size {attn_mask_size}, "
                                          f"but is {attention_mask.size()}")
                    attn_weights_split = attn_weights_split + attention_mask
                attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
                attn_outputs.append(torch.matmul(attn_weights_split, v))
            attn_output = torch.cat(attn_outputs, dim=1)
        else:
            attn_weights = torch.matmul(query_states.to(key_states.dtype),
                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
@@ -518,12 +605,29 @@ def mistral_attention_forward_original(
                                                                     dtype=attention_dtype)
    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                         dtype=attention_dtype)
    if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                               q_len, kv_seq_len, output_attentions):
        attn_output, attn_weights = compute_attn_outputs_weights_split_tensor(query_states,
                                                                              key_states,
                                                                              value_states,
                                                                              bsz,
                                                                              q_len,
                                                                              kv_seq_len,
                                                                              self.num_heads,
                                                                              self.head_dim,
                                                                              self.hidden_size,
                                                                              attention_mask)
    else:
        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                 key_states,
                                                                 value_states,
                                                                 bsz,
                                                                 q_len,
                                                                 kv_seq_len,
                                                                 self.num_heads,
                                                                 self.head_dim,
                                                                 self.hidden_size,
                                                                 attention_mask)

    attn_output = self.o_proj(attn_output)
@@ -653,6 +757,34 @@ def mistral_attention_forward_4_36_quantized(
                                      dtype=attention_dtype)
    kv_seq_len = key_states.shape[-2]
    if len(past_key_value.key_cache) <= self.layer_idx:
        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                   q_len, kv_seq_len, output_attentions):
            block_size = 8
            query_split = torch.split(query_states.to(key_states.dtype), block_size, dim=1)
            key_split = torch.split(key_states.transpose(2, 3), block_size, dim=1)
            value_split = torch.split(value_states.to(query_states.dtype), block_size, dim=1)
            attn_outputs = []
            for q, k, v in zip(query_split, key_split, value_split):
                attn_weights_split = torch.matmul(q, k) / math.sqrt(self.head_dim)
                block_actual_size = attn_weights_split.size(1)
                attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
                if attn_weights_split.size() != attn_weights_split_size:
                    invalidInputError(False,
                                      f"Splitted attention weights should be of size "
                                      f"{attn_weights_split_size}, "
                                      f"but is {attn_weights_split.size()}")
                if attention_mask is not None:
                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
                    if attention_mask.size() != attn_mask_size:
                        invalidInputError(False,
                                          f"Attention mask should be of size {attn_mask_size}, "
                                          f"but is {attention_mask.size()}")
                    attn_weights_split = attn_weights_split + attention_mask
                attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
                attn_outputs.append(torch.matmul(attn_weights_split, v))
            attn_output = torch.cat(attn_outputs, dim=1)
        else:
            attn_weights = torch.matmul(query_states.to(key_states.dtype),
                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
@@ -673,6 +805,11 @@ def mistral_attention_forward_4_36_quantized(
                )
                attn_weights = attn_weights + attention_mask
            if kv_seq_len >= 2048 or bsz >= 64:
                # for memory considerations, do not upcast attention to fp32
                # for long sequences or large batches
                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
            else:
                # upcast attention to fp32
                attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                     dtype=torch.float32).to(query_states.dtype)
@@ -909,10 +1046,25 @@ def mistral_attention_forward_4_36_original(
                                                                     dtype=attention_dtype)
    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                         dtype=attention_dtype)
    if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                               q_len, kv_seq_len, output_attentions):
        attn_output, attn_weights = compute_attn_outputs_weights_split_tensor(query_states,
                                                                              key_states,
                                                                              value_states,
                                                                              bsz,
                                                                              q_len,
                                                                              kv_seq_len,
                                                                              self.num_heads,
                                                                              self.head_dim,
                                                                              self.hidden_size,
                                                                              attention_mask)
    else:
        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                 key_states,
                                                                 value_states,
                                                                 bsz,
                                                                 q_len,
                                                                 kv_seq_len,
                                                                 self.num_heads,
                                                                 self.head_dim,
                                                                 self.hidden_size,


@@ -74,6 +74,21 @@ import os
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
    if not output_attentions:
        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
        elif query_states.dtype == torch.float16 and \
                query_states.shape[2] >= 5000:
            # split tensor for memory block limitation
            # support fp16 and set input length threshold at 5000 for now
            return True
        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
            # attn_weight size larger than memory block limitation 4GB
            return True
    return False
def should_use_fuse_rope(self, query_states, position_ids):
    use_fuse_rope = query_states.device.type == "xpu"
    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
@@ -370,6 +385,15 @@ def qwen2_attention_forward_quantized(
        key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
        key = repeat_kv(key, self.num_key_value_groups)
        value = repeat_kv(value, self.num_key_value_groups)
        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                   q_len, kv_seq_len, output_attentions):
            attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key,
                                                                    value, attention_mask,
                                                                    bsz, q_len, kv_seq_len,
                                                                    self.head_dim, self.num_heads,
                                                                    self.attention_dropout,
                                                                    self.training)
        else:
            attn_weights = torch.matmul(query_states, key.transpose(2, 3))
            attn_weights = attn_weights / math.sqrt(self.head_dim)
@@ -380,11 +404,17 @@ def qwen2_attention_forward_quantized(
            if attention_mask is not None:
                invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
                                  (f"Attention mask should be of size "
                                   f"{(bsz, 1, q_len, kv_seq_len)},"
                                   f" but is {attention_mask.size()}"))
                attn_weights = attn_weights + attention_mask
            if kv_seq_len >= 2048 or bsz >= 64:
                # for memory considerations, do not upcast attention to fp32
                # for long sequences or large batches
                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
            else:
                # upcast attention to fp32
                attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                     dtype=torch.float32).to(query_states.dtype)
@@ -542,6 +572,15 @@ def qwen2_attention_forward_origin(
                                                        value_states)
            attn_output = attn_output.view(query_states.shape)
            attn_weights = None
        else:
            if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                       q_len, kv_seq_len, output_attentions):
                attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key_states,
                                                                        value_states, attention_mask,
                                                                        bsz, q_len, kv_seq_len,
                                                                        self.head_dim, self.num_heads,
                                                                        self.attention_dropout,
                                                                        self.training)
            else:
                attn_weights = torch.matmul(query_states,
                                            key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
@@ -553,14 +592,20 @@ def qwen2_attention_forward_origin(
                if attention_mask is not None:
                    invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
                                      (f"Attention mask should be of size "
                                       f"{(bsz, 1, q_len, kv_seq_len)},"
                                       f" but is {attention_mask.size()}"))
                    attn_weights = attn_weights + attention_mask
                if kv_seq_len >= 2048 or bsz >= 64:
                    # for memory considerations, do not upcast attention to fp32
                    # for long sequences or large batches
                    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
                else:
                    # upcast attention to fp32
                    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                         dtype=torch.float32).to(query_states.dtype)
                attn_weights = nn.functional.dropout(attn_weights,
                                                     p=self.attention_dropout,
                                                     training=self.training)
@@ -725,3 +770,36 @@ def qwen2_sdpa_attention_forward(
    attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value
def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                                bsz, q_len, kv_seq_len, head_dim, num_heads,
                                attention_dropout, training):
    block_size = 8
    query_split = torch.split(query, block_size, dim=1)
    key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
    value_split = torch.split(value, block_size, dim=1)
    attn_outputs = []
    for q, k, v in zip(query_split, key_split, value_split):
        attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
        block_actual_size = attn_weights_split.size(1)
        attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
        if attn_weights_split.size() != attn_weights_split_size:
            invalidInputError(False,
                              f"Splitted attention weights should be of size "
                              f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
        if attention_mask is not None:
            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
            if attention_mask.size() != attn_mask_size:
                invalidInputError(False,
                                  f"Attention mask should be of size {attn_mask_size}, "
                                  f"but is {attention_mask.size()}")
            attn_weights_split = attn_weights_split + attention_mask
        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
        attn_weights_split = nn.functional.dropout(attn_weights_split,
                                                   p=attention_dropout,
                                                   training=training)
        attn_outputs.append(torch.matmul(attn_weights_split, v))
    attn_output = torch.cat(attn_outputs, dim=1)
    return attn_output, None