LLM: support llama split tensor for long context in transformers>=4.36. (#10844)

* LLM: support llama split tensor for long context in transformers>=4.36.

* fix dtype.

* fix style.

* fix style.

* fix style.

* fix style.

* fix dtype.

* fix style.
Cengguang Zhang, 2024-04-23 16:13:25 +08:00 (committed by GitHub)
commit 763413b7e1, parent bce99a5b00
2 changed files with 34 additions and 28 deletions


@@ -258,8 +258,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
         query_split = torch.split(query_layer, block_size, dim=1)
         key_split = torch.split(key, block_size, dim=1)
         value_split = torch.split(value, block_size, dim=1)
-        context_layer = torch.empty(batch_size, n_head,
-                                    seq_len, head_dim).to(query_layer.device)
+        context_layer = torch.empty(batch_size, n_head, seq_len,
+                                    head_dim, dtype=key.dtype).to(query_layer.device)
         idx = 0
         for q, k, v in zip(query_split, key_split, value_split):
             if attention_mask is None:
@@ -543,7 +543,7 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         value_split = torch.split(value_layer, block_size, dim=1)
         batch_size, n_head, seq_len, head_dim = query_layer.shape
-        context_layer = torch.empty(batch_size, n_head, seq_len,
-                                    head_dim).to(query_layer.device).to(key_layer.dtype)
+        context_layer = torch.empty(batch_size, n_head, seq_len,
+                                    head_dim, dtype=key_layer.dtype).to(query_layer.device)
         idx = 0
         for q, k, v in zip(query_split, key_split, value_split):
             result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
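
Both chatglm2 hunks above make the same fix: the preallocated context_layer buffer now receives its dtype at allocation time, rather than being left in the default dtype (quantized path) or cast after an fp32 allocation (core_attn path). A standalone illustration of the difference, not taken from the repo:

import torch

key = torch.randn(1, 2, 8, 4).half()   # stand-in for the fp16 key tensor

# Old quantized path: buffer stays in the default dtype (fp32), so fp16 block
# results written into it are silently upcast and the output dtype is wrong.
old_quantized = torch.empty(1, 2, 8, 4).to(key.device)

# Old core_attn path: allocate fp32, then cast -- an extra full-size buffer.
old_core = torch.empty(1, 2, 8, 4).to(key.device).to(key.dtype)

# New: allocate directly in the key's dtype, no intermediate fp32 copy.
new = torch.empty(1, 2, 8, 4, dtype=key.dtype).to(key.device)

assert old_quantized.dtype == torch.float32
assert old_core.dtype == new.dtype == torch.float16

For the core_attn path the old code already ended up with the right dtype, so the change there is about skipping the intermediate fp32 allocation rather than correctness.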


@@ -1028,8 +1028,14 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if should_split_qkv_tensor(query_states, output_attentions):
+            attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
+                                                         repeated_value_states, attention_mask,
+                                                         bsz, q_len, kv_seq_len, self.head_dim,
+                                                         self.num_heads)
+        else:
+            attn_weights = torch.matmul(query_states, repeated_key_states
+                                        .transpose(2, 3)) / math.sqrt(self.head_dim)
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 invalidInputError(
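
should_split_qkv_tensor and native_sdp_split_qkv_tensor appear to predate this commit (the second hunk below only changes a return inside the helper); this hunk wires them into the transformers>=4.36 quantized attention path. A hypothetical sketch of the kind of gate such a helper applies; the threshold and the kv-length proxy are assumptions, not the repo's logic:

import torch

def should_split_qkv_tensor_sketch(query_states: torch.Tensor,
                                   output_attentions: bool,
                                   max_score_elements: int = 2 ** 28) -> bool:
    """Hypothetical gate for the split path (illustrative, not the repo's).

    The split path never materializes the full attention-weight matrix (it
    returns None for the weights), so it cannot honor output_attentions;
    beyond that, splitting only pays off when the
    (bsz, n_heads, q_len, kv_seq_len) score tensor would be large.
    """
    if output_attentions:
        return False
    bsz, n_heads, q_len, _ = query_states.shape
    # Prefill is the long-context case, where kv_seq_len is roughly q_len,
    # so q_len * q_len serves as a proxy for the score-matrix size.
    return bsz * n_heads * q_len * q_len > max_score_elements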
@@ -1438,7 +1444,7 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
         attn_weights_split = torch.matmul(attn_weights_split, v)
         attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
         idx = idx + block_actual_size
-    return attn_output, None
+    return attn_output.to(key.dtype), None
 
 def llama_model_selective_batching_forward_4_31(
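
The loop tail shown above gives the shape of native_sdp_split_qkv_tensor: the query is split along the head dimension, each slice's attention is computed separately, and the result is written into a preallocated output, so the full (bsz, n_heads, q_len, kv_seq_len) score tensor never exists at once. Below is a self-contained sketch of that pattern, assuming a (bsz, n_heads, seq_len, head_dim) layout and an additive mask; names and signature are illustrative, not the repo's:

import math
import torch

def sdp_split_heads_sketch(query, key, value, attn_mask=None, block_heads=8):
    """Scaled dot-product attention computed one slice of heads at a time.

    query/key/value: (bsz, n_heads, seq_len, head_dim). attn_mask, if given,
    must broadcast against (bsz, block_heads, q_len, kv_len). Splitting along
    the head dimension caps the score matrix at
    bsz * block_heads * q_len * kv_len elements instead of
    bsz * n_heads * q_len * kv_len.
    """
    bsz, n_heads, q_len, head_dim = query.shape
    out = torch.empty(bsz, n_heads, q_len, head_dim,
                      dtype=key.dtype, device=query.device)
    idx = 0
    for q, k, v in zip(query.split(block_heads, dim=1),
                       key.split(block_heads, dim=1),
                       value.split(block_heads, dim=1)):
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        if attn_mask is not None:
            scores = scores + attn_mask
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
        block = q.shape[1]
        out[:, idx:idx + block, :, :] = torch.matmul(probs, v)
        idx += block
    return out

With causal or padding masks folded into attn_mask, this is numerically equivalent to computing all heads at once; the saving is peak memory during long-context prefill, at the cost of a Python loop over head blocks.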