LLM: support llama split tensor for long context in transformers>=4.36. (#10844)
* LLM: support llama split tensor for long context in transformers>=4.36.
* fix dtype.
* fix style.
* fix style.
* fix style.
* fix style.
* fix dtype.
* fix style.
parent  bce99a5b00
commit  763413b7e1
2 changed files with 34 additions and 28 deletions
@@ -258,8 +258,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
         query_split = torch.split(query_layer, block_size, dim=1)
         key_split = torch.split(key, block_size, dim=1)
         value_split = torch.split(value, block_size, dim=1)
-        context_layer = torch.empty(batch_size, n_head,
-                                    seq_len, head_dim).to(query_layer.device)
+        context_layer = torch.empty(batch_size, n_head, seq_len,
+                                    head_dim, dtype=key.dtype).to(query_layer.device)
         idx = 0
         for q, k, v in zip(query_split, key_split, value_split):
             if attention_mask is None:
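The chatglm2 hunk above preallocates the context buffer directly in the key dtype and then fills it block by block. Below is a minimal, self-contained sketch of that blocked-attention pattern, assuming a (batch, n_head, seq_len, head_dim) layout, a plain causal SDPA call, and an arbitrary block_size; none of these details beyond what the visible lines show are taken from the patched file.

import torch
import torch.nn.functional as F

def blocked_sdpa(query, key, value, block_size=8):
    # query/key/value: (batch, n_head, seq_len, head_dim); split along the head dim (dim=1)
    batch_size, n_head, seq_len, head_dim = query.shape
    query_split = torch.split(query, block_size, dim=1)
    key_split = torch.split(key, block_size, dim=1)
    value_split = torch.split(value, block_size, dim=1)
    # allocate the output once, already in the key dtype (the point of this hunk)
    context_layer = torch.empty(batch_size, n_head, seq_len, head_dim,
                                dtype=key.dtype, device=query.device)
    idx = 0
    for q, k, v in zip(query_split, key_split, value_split):
        block = q.size(1)
        # one attention call per block of heads keeps the peak score matrix small
        context_layer[:, idx:idx + block] = F.scaled_dot_product_attention(
            q, k, v, is_causal=True).to(key.dtype)
        idx += block
    return context_layer

# example: fp16 tensors, 32 heads processed 8 at a time
q = torch.randn(1, 32, 1024, 128, dtype=torch.float16)
out = blocked_sdpa(q, q.clone(), q.clone())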
@@ -543,7 +543,7 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
        value_split = torch.split(value_layer, block_size, dim=1)
        batch_size, n_head, seq_len, head_dim = query_layer.shape
        context_layer = torch.empty(batch_size, n_head, seq_len,
-                                   head_dim).to(query_layer.device).to(key_layer.dtype)
+                                   head_dim, dtype=key_layer.dtype).to(query_layer.device)
        idx = 0
        for q, k, v in zip(query_split, key_split, value_split):
            result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
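Both dtype hunks replace an allocate-then-cast (a default fp32 `torch.empty` followed by `.to(...)`, or no dtype cast at all) with a single allocation that carries the dtype up front, so the buffer is never materialised in the wrong dtype. A small illustration of the difference; the shapes are arbitrary, and the `device=` keyword here is only for brevity (the patch itself keeps the trailing `.to(query_layer.device)`).

import torch

key = torch.randn(1, 32, 1024, 128, dtype=torch.float16)

# before: allocated in default fp32, then cast -- two buffers and an extra copy
ctx_old = torch.empty(1, 32, 1024, 128).to(key.device).to(key.dtype)

# after: allocated once, already in the key dtype and on the right device
ctx_new = torch.empty(1, 32, 1024, 128, dtype=key.dtype, device=key.device)

assert ctx_old.dtype == ctx_new.dtype == torch.float16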
@@ -1028,35 +1028,41 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size "
-                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-
-        if kv_seq_len >= 2048 or bsz >= 64:
-            # for memory considerations, do not upcast attention to fp32
-            # for long sequences or large batches
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        else:
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                                 dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, repeated_value_states)
+        if should_split_qkv_tensor(query_states, output_attentions):
+            attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
+                                                         repeated_value_states, attention_mask,
+                                                         bsz, q_len, kv_seq_len, self.head_dim,
+                                                         self.num_heads)
+        else:
+            attn_weights = torch.matmul(query_states, repeated_key_states
+                                        .transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention weights should be of size "
+                    f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    invalidInputError(
+                        False,
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                        f" but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            if kv_seq_len >= 2048 or bsz >= 64:
+                # for memory considerations, do not upcast attention to fp32
+                # for long sequences or large batches
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            else:
+                # upcast attention to fp32
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                     dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, repeated_value_states)
     if use_cache:
         cache_kwargs = None
         key_states, value_states = past_key_value.update(key_states, value_states,
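The llama hunk routes the quantized-KV attention path through `native_sdp_split_qkv_tensor` whenever `should_split_qkv_tensor(query_states, output_attentions)` is true, and otherwise keeps the original matmul + softmax path. The predicate's body is not part of this diff; the sketch below only reuses the signature from the call site, and the gating logic (the output_attentions check and the length threshold) is an illustrative assumption, not the library's actual criteria.

import torch

def should_split_qkv_tensor(query_states: torch.Tensor, output_attentions: bool) -> bool:
    # Splitting only applies when the caller does not need the full attention
    # probability matrix returned (output_attentions=False).
    if output_attentions:
        return False
    # Heuristic: only pay the per-block overhead for long prompts, where the
    # (q_len x kv_seq_len) score matrix dominates peak memory.
    q_len = query_states.size(2)   # query_states: (bsz, num_heads, q_len, head_dim)
    return q_len >= 4096           # illustrative threshold, not taken from the library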
@@ -1438,7 +1444,7 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
        attn_weights_split = torch.matmul(attn_weights_split, v)
        attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
        idx = idx + block_actual_size
-    return attn_output, None
+    return attn_output.to(key.dtype), None


 def llama_model_selective_batching_forward_4_31(
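Only the tail of `native_sdp_split_qkv_tensor` appears in this diff: the per-block write-back and the new `.to(key.dtype)` on the return. The sketch below fills in the rest as an assumption for illustration; the signature matches the call site in the hunk above and the last four loop lines match the visible diff, but the block size, allocation, and mask handling are a hedged reconstruction, not the library's exact code.

import math
import torch
import torch.nn as nn

def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                                bsz, q_len, kv_seq_len, head_dim, num_heads,
                                block_size=8):
    # Split along the head dimension and run plain matmul/softmax attention per block,
    # so only a (block, q_len, kv_seq_len) score matrix is live at any time.
    query_split = torch.split(query.to(key.dtype), block_size, dim=1)
    key_split = torch.split(key.transpose(2, 3), block_size, dim=1)
    value_split = torch.split(value, block_size, dim=1)
    # default-dtype buffer; the final cast back to the key dtype is the change this hunk adds
    attn_output = torch.empty(bsz, num_heads, q_len, head_dim).to(query.device)
    idx = 0
    for q, k, v in zip(query_split, key_split, value_split):
        block_actual_size = q.size(1)
        attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
        if attention_mask is not None:
            attn_weights_split = attn_weights_split + attention_mask
        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
        attn_weights_split = torch.matmul(attn_weights_split, v)
        attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
        idx = idx + block_actual_size
    return attn_output.to(key.dtype), None

# example call matching the signature used in the hunk above
q = torch.randn(1, 32, 1024, 128, dtype=torch.float16)
out, _ = native_sdp_split_qkv_tensor(q, q.clone(), q.clone(), None,
                                     1, 1024, 1024, 128, 32)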