LLM: support llama split tensor for long context in transformers>=4.36. (#10844)

* LLM: support llama split tensor for long context in transformers>=4.36.
* fix dtype.
* fix style.
* fix style.
* fix style.
* fix style.
* fix dtype.
* fix style.
Parent: bce99a5b00
Commit: 763413b7e1
2 changed files with 34 additions and 28 deletions
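Why splitting helps here: for the first token of a long prompt, the quantized attention path materializes a score tensor of shape (bsz, num_heads, q_len, kv_seq_len), which grows quadratically with prompt length. The snippet below is only a back-of-the-envelope estimate with assumed sizes (32 heads, an 8K prompt, fp16 scores, 8-head blocks; none of these numbers come from the diff) to show why computing a few heads at a time caps the peak:

```python
# Back-of-the-envelope peak memory for the first-token attention scores.
# Assumed config (not from the diff): 32 heads, 8K prompt, fp16 scores,
# processed 8 heads at a time.
n_head, q_len, kv_len, bytes_per_elem = 32, 8192, 8192, 2
full_scores = n_head * q_len * kv_len * bytes_per_elem
per_block = 8 * q_len * kv_len * bytes_per_elem
print(f"full score matrix: {full_scores / 2**30:.1f} GiB")  # 4.0 GiB
print(f"8-head block:      {per_block / 2**30:.1f} GiB")    # 1.0 GiB
```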
@@ -258,8 +258,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
         query_split = torch.split(query_layer, block_size, dim=1)
         key_split = torch.split(key, block_size, dim=1)
         value_split = torch.split(value, block_size, dim=1)
-        context_layer = torch.empty(batch_size, n_head,
-                                    seq_len, head_dim).to(query_layer.device)
+        context_layer = torch.empty(batch_size, n_head, seq_len,
+                                    head_dim, dtype=key.dtype).to(query_layer.device)
         idx = 0
         for q, k, v in zip(query_split, key_split, value_split):
             if attention_mask is None:
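A side note on the dtype part of this fix: torch.empty without an explicit dtype follows the global default (normally fp32), so the old buffer did not match the fp16/bf16 key/value blocks written into it. The snippet below is a minimal illustration with made-up shapes, not repo code:

```python
import torch

key = torch.randn(1, 2, 8, 16, dtype=torch.float16)   # stand-in kv tensor
old_buf = torch.empty(1, 2, 8, 16)                     # default dtype, normally fp32
new_buf = torch.empty(1, 2, 8, 16, dtype=key.dtype)    # matches the kv dtype up front
print(old_buf.dtype, new_buf.dtype)                    # torch.float32 torch.float16
```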
@@ -543,7 +543,7 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         value_split = torch.split(value_layer, block_size, dim=1)
         batch_size, n_head, seq_len, head_dim = query_layer.shape
         context_layer = torch.empty(batch_size, n_head, seq_len,
-                                    head_dim).to(query_layer.device).to(key_layer.dtype)
+                                    head_dim, dtype=key_layer.dtype).to(query_layer.device)
         idx = 0
         for q, k, v in zip(query_split, key_split, value_split):
             result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
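The pattern both split paths follow can be written as a small standalone function. This is a hedged sketch that mirrors the structure of core_attn_forward_8eb45c shown above, not the repo's actual helper; the tensor layout matches the shape unpacking in the hunk, while the 8-head block size is an assumption:

```python
# Sketch: split Q/K/V along the head dimension (dim=1 for a
# (bsz, n_head, seq_len, head_dim) layout) and run scaled_dot_product_attention
# block by block into a buffer allocated with the kv dtype, so the full
# per-head score matrix for all heads never exists at once.
import torch
import torch.nn.functional as F


def split_head_sdpa(query_layer, key_layer, value_layer, block_size=8):
    batch_size, n_head, seq_len, head_dim = query_layer.shape
    context_layer = torch.empty(batch_size, n_head, seq_len, head_dim,
                                dtype=key_layer.dtype, device=query_layer.device)
    idx = 0
    for q, k, v in zip(torch.split(query_layer, block_size, dim=1),
                       torch.split(key_layer, block_size, dim=1),
                       torch.split(value_layer, block_size, dim=1)):
        result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
        context_layer[:, idx:idx + q.size(1), :, :] = result
        idx += q.size(1)
    return context_layer
```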
@@ -1028,35 +1028,41 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size "
-                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                invalidInputError(
-                    False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-
-        if kv_seq_len >= 2048 or bsz >= 64:
-            # for memory considerations, do not upcast attention to fp32
-            # for long sequences or large batches
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        else:
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                                 dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, repeated_value_states)
+        if should_split_qkv_tensor(query_states, output_attentions):
+            attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
+                                                         repeated_value_states, attention_mask,
+                                                         bsz, q_len, kv_seq_len, self.head_dim,
+                                                         self.num_heads)
+        else:
+            attn_weights = torch.matmul(query_states, repeated_key_states
+                                        .transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention weights should be of size "
+                    f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    invalidInputError(
+                        False,
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                        f" but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            if kv_seq_len >= 2048 or bsz >= 64:
+                # for memory considerations, do not upcast attention to fp32
+                # for long sequences or large batches
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            else:
+                # upcast attention to fp32
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                     dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, repeated_value_states)
         if use_cache:
             cache_kwargs = None
             key_states, value_states = past_key_value.update(key_states, value_states,
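In this llama 4.36 quantized path, the new branch only runs when should_split_qkv_tensor(query_states, output_attentions) says so; the helper's body is not part of this diff, so the gate below is a guess at the kind of check it performs (splitting cannot be used when attention weights must be returned, and it only pays off for long first-token prefill). The worker sketch next to it reimplements the blockwise softmax(QK^T / sqrt(d) + mask) @ V contract for illustration; both carry a _sketch suffix to make clear they are not the repo's functions, and the 4096 threshold and 8-head block size are assumptions:

```python
# Hedged sketches of the two helpers referenced in the hunk above.
import math
import torch


def should_split_qkv_tensor_sketch(query_states, output_attentions,
                                   seq_threshold=4096):
    # Splitting never keeps the full weight matrix, so it is only usable
    # when the caller does not ask for output_attentions.
    if output_attentions:
        return False
    return query_states.size(2) >= seq_threshold   # long first-token prefill


def native_sdp_split_qkv_tensor_sketch(query, key, value, attention_mask,
                                       bsz, q_len, kv_seq_len, head_dim,
                                       num_heads, block_size=8):
    # Blockwise softmax(QK^T / sqrt(d) + mask) @ V over groups of heads,
    # written into a preallocated output so peak memory stays bounded.
    attn_output = torch.empty(bsz, num_heads, q_len, head_dim,
                              dtype=key.dtype, device=query.device)
    idx = 0
    for q, k, v in zip(torch.split(query, block_size, dim=1),
                       torch.split(key, block_size, dim=1),
                       torch.split(value, block_size, dim=1)):
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask       # mask broadcasts over heads
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
        attn_output[:, idx:idx + q.size(1), :, :] = torch.matmul(probs, v)
        idx += q.size(1)
    return attn_output.to(key.dtype), None         # weights are not kept
```

The real native_sdp_split_qkv_tensor also tracks block_actual_size, visible in the last hunk, which appears to be the number of heads in the current block when the head count is not divisible by the block size.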
@@ -1438,7 +1444,7 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
         attn_weights_split = torch.matmul(attn_weights_split, v)
         attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
         idx = idx + block_actual_size
-    return attn_output, None
+    return attn_output.to(key.dtype), None
 
 
 def llama_model_selective_batching_forward_4_31(
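The added .to(key.dtype) on the return value pins the output dtype for the caller. A minimal illustration (not repo code) of why the explicit cast is the safer contract: writes into a preallocated buffer silently adopt the buffer's dtype, so without the final cast the function's output dtype would depend on how the buffer happened to be allocated.

```python
import torch

buf = torch.empty(2, 4, dtype=torch.float16)
block = torch.randn(2, 4, dtype=torch.float32)  # e.g. an fp32 softmax result
buf[:, :] = block                               # silently cast to the buffer's fp16
print(buf.dtype)                                # torch.float16
```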