LLM: add long-context support for Qwen1.5-7B/Baichuan2-7B/Mistral-7B. (#10937)
* LLM: add split tensor support for baichuan2-7b and qwen1.5-7b. * fix style. * fix style. * fix style. * add support for mistral and fix condition threshold. * fix style. * fix comments.
This commit is contained in:
parent
f9615f12d1
commit
cfed76b2ed
3 changed files with 376 additions and 101 deletions
|
|
@ -49,6 +49,21 @@ import os
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||||
|
|
||||||
|
|
||||||
|
def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
|
||||||
|
if not output_attentions:
|
||||||
|
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
|
||||||
|
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
|
||||||
|
elif query_states.dtype == torch.float16 and \
|
||||||
|
query_states.shape[2] >= 5400:
|
||||||
|
# split tensor for memory block limitation
|
||||||
|
# support fp16 and set input length threshold at 5400 for now
|
||||||
|
return True
|
||||||
|
elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
|
||||||
|
# attn_weight size larger than memory block limitation 4GB
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def baichuan_13b_rms_norm_forward(self, hidden_states):
|
def baichuan_13b_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
|
|
@ -159,6 +174,11 @@ def baichuan_attention_forward_7b_quantized(
|
||||||
if query_states.size(2) != 1 or device.type != 'xpu':
|
if query_states.size(2) != 1 or device.type != 'xpu':
|
||||||
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
|
||||||
query_states.dtype)
|
query_states.dtype)
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key_states,
|
||||||
|
value_states, attention_mask)
|
||||||
|
else:
|
||||||
attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
|
attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
|
|
@ -287,9 +307,16 @@ def baichuan_attention_forward_7b_origin(
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
if attention_mask.dtype == torch.bool:
|
if attention_mask.dtype == torch.bool:
|
||||||
attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
scaling_factor = 1 / math.sqrt(query_states.size(-1))
|
scaling_factor = 1 / math.sqrt(query_states.size(-1))
|
||||||
attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
|
attn_output = torch.matmul(query_states * scaling_factor,
|
||||||
|
key_states.transpose(-2, -1))
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attn_output += attention_mask
|
attn_output += attention_mask
|
||||||
attn_output = torch.softmax(attn_output, -1)
|
attn_output = torch.softmax(attn_output, -1)
|
||||||
|
|
@ -622,3 +649,21 @@ def baichuan_13b_get_alibi_mask(self, tensor, seq_length_with_past):
|
||||||
: self.n_head, :seq_length_with_past, :seq_length_with_past
|
: self.n_head, :seq_length_with_past, :seq_length_with_past
|
||||||
]
|
]
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def native_sdp_split_qkv_tensor(query, key, value, attention_mask):
|
||||||
|
block_size = 8
|
||||||
|
query_split = torch.split(query, block_size, dim=1)
|
||||||
|
key_split = torch.split(key.transpose(-2, -1), block_size, dim=1)
|
||||||
|
value_split = torch.split(value, block_size, dim=1)
|
||||||
|
attn_outputs = []
|
||||||
|
scaling_factor = 1 / math.sqrt(query.size(-1))
|
||||||
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
|
attn_output_split = torch.matmul(q * scaling_factor, k)
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_output_split += attention_mask
|
||||||
|
attn_output_split = torch.softmax(attn_output_split, -1)
|
||||||
|
attn_output_split = torch.matmul(attn_output_split, v)
|
||||||
|
attn_outputs.append(attn_output_split)
|
||||||
|
attn_output = torch.cat(attn_outputs, dim=1)
|
||||||
|
return attn_output, None
|
||||||
|
|
|
||||||
|
|
@ -89,6 +89,21 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
|
||||||
return use_fuse_rope
|
return use_fuse_rope
|
||||||
|
|
||||||
|
|
||||||
|
def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
|
||||||
|
if not output_attentions:
|
||||||
|
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
|
||||||
|
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
|
||||||
|
elif query_states.dtype == torch.float16 and \
|
||||||
|
query_states.shape[2] >= 6300:
|
||||||
|
# split tensor for memory block limitation
|
||||||
|
# support fp16 and set input length threshold at 6300 for now
|
||||||
|
return True
|
||||||
|
elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
|
||||||
|
# attn_weight size larger than memory block limitation 4GB
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
|
def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
|
||||||
num_heads, head_dim, hidden_size, attention_mask):
|
num_heads, head_dim, hidden_size, attention_mask):
|
||||||
attn_weights = torch.matmul(
|
attn_weights = torch.matmul(
|
||||||
|
|
@ -112,9 +127,14 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
|
||||||
|
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
if kv_seq_len >= 2048 or bsz >= 64:
|
||||||
|
# for memory considerations, do not upcast attention to fp32
|
||||||
|
# for long sequences or large batches
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
else:
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = nn.functional.\
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
attn_output = torch.matmul(attn_weights, value_states.to(query_states.dtype))
|
attn_output = torch.matmul(attn_weights, value_states.to(query_states.dtype))
|
||||||
|
|
||||||
if attn_output.size() != (bsz, num_heads, q_len, head_dim):
|
if attn_output.size() != (bsz, num_heads, q_len, head_dim):
|
||||||
|
|
@ -130,6 +150,45 @@ def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
def compute_attn_outputs_weights_split_tensor(query_states, key_states, value_states,
|
||||||
|
bsz, q_len, kv_seq_len, num_heads, head_dim,
|
||||||
|
hidden_size, attention_mask):
|
||||||
|
block_size = 8
|
||||||
|
query_split = torch.split(query_states.to(key_states.dtype), block_size, dim=1)
|
||||||
|
key_split = torch.split(key_states.transpose(2, 3), block_size, dim=1)
|
||||||
|
value_split = torch.split(value_states.to(query_states.dtype), block_size, dim=1)
|
||||||
|
attn_outputs = []
|
||||||
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
|
attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
|
||||||
|
block_actual_size = attn_weights_split.size(1)
|
||||||
|
attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
|
||||||
|
if attn_weights_split.size() != attn_weights_split_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Splitted attention weights should be of size "
|
||||||
|
f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
|
||||||
|
if attention_mask.size() != attn_mask_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Attention mask should be of size {attn_mask_size}, "
|
||||||
|
f"but is {attention_mask.size()}")
|
||||||
|
attn_weights_split = attn_weights_split + attention_mask
|
||||||
|
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
|
||||||
|
attn_outputs.append(torch.matmul(attn_weights_split, v))
|
||||||
|
attn_output = torch.cat(attn_outputs, dim=1)
|
||||||
|
if attn_output.size() != (bsz, num_heads, q_len, head_dim):
|
||||||
|
invalidInputError(
|
||||||
|
False,
|
||||||
|
f"`attn_output` should be of size {(bsz, num_heads, q_len, head_dim)},"
|
||||||
|
f" but is {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, hidden_size)
|
||||||
|
return attn_output, None
|
||||||
|
|
||||||
|
|
||||||
def mistral_model_forward_4_36(
|
def mistral_model_forward_4_36(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|
@ -272,6 +331,34 @@ def mistral_attention_forward_quantized(
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is None:
|
if past_key_value is None:
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
block_size = 8
|
||||||
|
query_split = torch.split(query_states.to(key_states.dtype), block_size, dim=1)
|
||||||
|
key_split = torch.split(key_states.transpose(2, 3), block_size, dim=1)
|
||||||
|
value_split = torch.split(value_states.to(query_states.dtype), block_size, dim=1)
|
||||||
|
attn_outputs = []
|
||||||
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
|
attn_weights_split = torch.matmul(q, k) / math.sqrt(self.head_dim)
|
||||||
|
block_actual_size = attn_weights_split.size(1)
|
||||||
|
attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
|
||||||
|
if attn_weights_split.size() != attn_weights_split_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Splitted attention weights should be of size "
|
||||||
|
f"{attn_weights_split_size}, "
|
||||||
|
f"but is {attn_weights_split.size()}")
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
|
||||||
|
if attention_mask.size() != attn_mask_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Attention mask should be of size {attn_mask_size}, "
|
||||||
|
f"but is {attention_mask.size()}")
|
||||||
|
attn_weights_split = attn_weights_split + attention_mask
|
||||||
|
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
|
||||||
|
attn_outputs.append(torch.matmul(attn_weights_split, v))
|
||||||
|
attn_output = torch.cat(attn_outputs, dim=1)
|
||||||
|
else:
|
||||||
attn_weights = torch.matmul(query_states.to(key_states.dtype),
|
attn_weights = torch.matmul(query_states.to(key_states.dtype),
|
||||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
|
@ -518,12 +605,29 @@ def mistral_attention_forward_original(
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
attn_output, attn_weights = compute_attn_outputs_weights_split_tensor(query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
bsz,
|
||||||
|
q_len,
|
||||||
|
kv_seq_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
attn_output, attn_weights = compute_attn_outputs_weights(query_states,
|
attn_output, attn_weights = compute_attn_outputs_weights(query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
bsz, q_len, kv_seq_len,
|
bsz,
|
||||||
self.num_heads, self.head_dim,
|
q_len,
|
||||||
self.hidden_size, attention_mask)
|
kv_seq_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
attention_mask)
|
||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
|
@ -653,6 +757,34 @@ def mistral_attention_forward_4_36_quantized(
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if len(past_key_value.key_cache) <= self.layer_idx:
|
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
block_size = 8
|
||||||
|
query_split = torch.split(query_states.to(key_states.dtype), block_size, dim=1)
|
||||||
|
key_split = torch.split(key_states.transpose(2, 3), block_size, dim=1)
|
||||||
|
value_split = torch.split(value_states.to(query_states.dtype), block_size, dim=1)
|
||||||
|
attn_outputs = []
|
||||||
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
|
attn_weights_split = torch.matmul(q, k) / math.sqrt(self.head_dim)
|
||||||
|
block_actual_size = attn_weights_split.size(1)
|
||||||
|
attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
|
||||||
|
if attn_weights_split.size() != attn_weights_split_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Splitted attention weights should be of size "
|
||||||
|
f"{attn_weights_split_size}, "
|
||||||
|
f"but is {attn_weights_split.size()}")
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
|
||||||
|
if attention_mask.size() != attn_mask_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Attention mask should be of size {attn_mask_size}, "
|
||||||
|
f"but is {attention_mask.size()}")
|
||||||
|
attn_weights_split = attn_weights_split + attention_mask
|
||||||
|
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
|
||||||
|
attn_outputs.append(torch.matmul(attn_weights_split, v))
|
||||||
|
attn_output = torch.cat(attn_outputs, dim=1)
|
||||||
|
else:
|
||||||
attn_weights = torch.matmul(query_states.to(key_states.dtype),
|
attn_weights = torch.matmul(query_states.to(key_states.dtype),
|
||||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
|
@ -673,6 +805,11 @@ def mistral_attention_forward_4_36_quantized(
|
||||||
)
|
)
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
if kv_seq_len >= 2048 or bsz >= 64:
|
||||||
|
# for memory considerations, do not upcast attention to fp32
|
||||||
|
# for long sequences or large batches
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
else:
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
dtype=torch.float32).to(query_states.dtype)
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
|
|
@ -909,10 +1046,25 @@ def mistral_attention_forward_4_36_original(
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
dtype=attention_dtype)
|
dtype=attention_dtype)
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
attn_output, attn_weights = compute_attn_outputs_weights_split_tensor(query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
bsz,
|
||||||
|
q_len,
|
||||||
|
kv_seq_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
attn_output, attn_weights = compute_attn_outputs_weights(query_states,
|
attn_output, attn_weights = compute_attn_outputs_weights(query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
bsz, q_len, kv_seq_len,
|
bsz,
|
||||||
|
q_len,
|
||||||
|
kv_seq_len,
|
||||||
self.num_heads,
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,21 @@ import os
|
||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||||
|
|
||||||
|
|
||||||
|
def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
|
||||||
|
if not output_attentions:
|
||||||
|
if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
|
||||||
|
return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
|
||||||
|
elif query_states.dtype == torch.float16 and \
|
||||||
|
query_states.shape[2] >= 5000:
|
||||||
|
# split tensor for memory block limitation
|
||||||
|
# support fp16 and set input length threshold at 5000 for now
|
||||||
|
return True
|
||||||
|
elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
|
||||||
|
# attn_weight size larger than memory block limitation 4GB
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def should_use_fuse_rope(self, query_states, position_ids):
|
def should_use_fuse_rope(self, query_states, position_ids):
|
||||||
use_fuse_rope = query_states.device.type == "xpu"
|
use_fuse_rope = query_states.device.type == "xpu"
|
||||||
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
|
use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
|
||||||
|
|
@ -370,6 +385,15 @@ def qwen2_attention_forward_quantized(
|
||||||
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
|
key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
|
||||||
key = repeat_kv(key, self.num_key_value_groups)
|
key = repeat_kv(key, self.num_key_value_groups)
|
||||||
value = repeat_kv(value, self.num_key_value_groups)
|
value = repeat_kv(value, self.num_key_value_groups)
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key,
|
||||||
|
value, attention_mask,
|
||||||
|
bsz, q_len, kv_seq_len,
|
||||||
|
self.head_dim, self.num_heads,
|
||||||
|
self.attention_dropout,
|
||||||
|
self.training)
|
||||||
|
else:
|
||||||
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
|
attn_weights = torch.matmul(query_states, key.transpose(2, 3))
|
||||||
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
attn_weights = attn_weights / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
|
@ -380,11 +404,17 @@ def qwen2_attention_forward_quantized(
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||||
(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
|
(f"Attention mask should be of size "
|
||||||
|
f"{(bsz, 1, q_len, kv_seq_len)},"
|
||||||
f" but is {attention_mask.size()}"))
|
f" but is {attention_mask.size()}"))
|
||||||
|
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
if kv_seq_len >= 2048 or bsz >= 64:
|
||||||
|
# for memory considerations, do not upcast attention to fp32
|
||||||
|
# for long sequences or large batches
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
else:
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
dtype=torch.float32).to(query_states.dtype)
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
|
|
@ -542,6 +572,15 @@ def qwen2_attention_forward_origin(
|
||||||
value_states)
|
value_states)
|
||||||
attn_output = attn_output.view(query_states.shape)
|
attn_output = attn_output.view(query_states.shape)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
else:
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states, key_states,
|
||||||
|
value_states, attention_mask,
|
||||||
|
bsz, q_len, kv_seq_len,
|
||||||
|
self.head_dim, self.num_heads,
|
||||||
|
self.attention_dropout,
|
||||||
|
self.training)
|
||||||
else:
|
else:
|
||||||
attn_weights = torch.matmul(query_states,
|
attn_weights = torch.matmul(query_states,
|
||||||
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
@ -553,14 +592,20 @@ def qwen2_attention_forward_origin(
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
|
||||||
(f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
|
(f"Attention mask should be of size "
|
||||||
|
f"{(bsz, 1, q_len, kv_seq_len)},"
|
||||||
f" but is {attention_mask.size()}"))
|
f" but is {attention_mask.size()}"))
|
||||||
|
|
||||||
attn_weights = attn_weights + attention_mask
|
attn_weights = attn_weights + attention_mask
|
||||||
|
|
||||||
|
if kv_seq_len >= 2048 or bsz >= 64:
|
||||||
|
# for memory considerations, do not upcast attention to fp32
|
||||||
|
# for long sequences or large batches
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
else:
|
||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
attn_weights = \
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1,
|
||||||
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
dtype=torch.float32).to(query_states.dtype)
|
||||||
attn_weights = nn.functional.dropout(attn_weights,
|
attn_weights = nn.functional.dropout(attn_weights,
|
||||||
p=self.attention_dropout,
|
p=self.attention_dropout,
|
||||||
training=self.training)
|
training=self.training)
|
||||||
|
|
@ -725,3 +770,36 @@ def qwen2_sdpa_attention_forward(
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
|
||||||
|
bsz, q_len, kv_seq_len, head_dim, num_heads,
|
||||||
|
attention_dropout, training):
|
||||||
|
block_size = 8
|
||||||
|
query_split = torch.split(query, block_size, dim=1)
|
||||||
|
key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
|
||||||
|
value_split = torch.split(value, block_size, dim=1)
|
||||||
|
attn_outputs = []
|
||||||
|
for q, k, v in zip(query_split, key_split, value_split):
|
||||||
|
attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
|
||||||
|
block_actual_size = attn_weights_split.size(1)
|
||||||
|
attn_weights_split_size = (bsz, block_actual_size, q_len, kv_seq_len)
|
||||||
|
if attn_weights_split.size() != attn_weights_split_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Splitted attention weights should be of size "
|
||||||
|
f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
|
||||||
|
if attention_mask.size() != attn_mask_size:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Attention mask should be of size {attn_mask_size}, "
|
||||||
|
f"but is {attention_mask.size()}")
|
||||||
|
attn_weights_split = attn_weights_split + attention_mask
|
||||||
|
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
|
||||||
|
attn_weights_split = nn.functional.dropout(attn_weights_split,
|
||||||
|
p=attention_dropout,
|
||||||
|
training=training)
|
||||||
|
attn_outputs.append(torch.matmul(attn_weights_split, v))
|
||||||
|
attn_output = torch.cat(attn_outputs, dim=1)
|
||||||
|
return attn_output, None
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue