[WIP] Support llama2 with transformers==4.38.0 (#11024)
* support llama2 with transformers==4.38.0
* add support for quantize_qkv
* add original support for 4.38.0 now
* code style fix
This commit is contained in: parent 686f6038a8, commit 9942a4ba69
3 changed files with 123 additions and 64 deletions
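For context, the code path this commit targets is exercised by loading a Llama-2 checkpoint through ipex-llm's low-bit AutoModel wrapper while transformers==4.38.0 is installed. A minimal usage sketch; the model id, prompt, and generation settings are illustrative and not taken from this commit:

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Illustrative model id; any Llama-2 style checkpoint would follow the same path.
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

with torch.inference_mode():
    inputs = tokenizer("What is AI?", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))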
@@ -961,16 +961,24 @@ def _optimize_post(model, lightweight_bmm=False):
                         llama_decoder_forward)
         if version.parse(trans_version) >= version.parse("4.36.0"):
             # transformers version >= 4.36.0
-            from ipex_llm.transformers.models.llama import llama_attention_forward_4_36
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
             from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_36, )
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaModel,
-                llama_model_forward_4_36)
+            if version.parse(trans_version) >= version.parse("4.38.0"):
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_38_original
+                # Todo: support llama_model_forward with transformers version >= 4.38.0
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38_original)
+            else:
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaModel,
+                    llama_model_forward_4_36)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38)
         else:
             # transformers version between 4.31.0 - 4.35.2
             convert_forward(
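The hunk above leans on convert_forward to swap in the version-specific forward implementations. As a rough mental model only (not the actual helper in convert.py, whose details may differ), such a helper walks the model and rebinds forward on every sub-module of the target class:

import types


def convert_forward_sketch(model, target_cls, new_forward):
    # Hedged sketch only: rebind `forward` on every matching sub-module so the
    # patched attention/model functions run instead of the stock ones.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)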
@@ -333,6 +333,7 @@ def llama_attention_forward_4_31(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
@@ -348,6 +349,7 @@ def llama_attention_forward_4_31(
         output_attentions=output_attentions,
         use_cache=use_cache,
         padding_mask=padding_mask,
+        cache_position=cache_position,
         kwargs=kwargs
     )
@@ -361,6 +363,7 @@ def llama_attention_forward_4_31_quantized(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -437,7 +440,8 @@ def llama_attention_forward_4_31_quantized(
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
-                                               repeated_value_states, attention_mask,
+                                               repeated_value_states,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
         if use_cache:
@@ -462,7 +466,7 @@ def llama_attention_forward_4_31_quantized(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
     else:
@@ -498,6 +502,7 @@ def llama_attention_forward_4_31_original(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -683,7 +688,7 @@ def llama_attention_forward_4_31_original(
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
@@ -919,20 +924,21 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values


-def llama_attention_forward_4_36(
+def llama_attention_forward_4_38(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
-        forward_function = llama_attention_forward_4_36_quantized
+        forward_function = llama_attention_forward_4_38_quantized
     else:
-        forward_function = llama_attention_forward_4_36_original
+        forward_function = llama_attention_forward_4_38_original
     return forward_function(
         self=self,
         hidden_states=hidden_states,
@@ -941,20 +947,22 @@ def llama_attention_forward_4_36(
         past_key_value=past_key_value,
         output_attentions=output_attentions,
         use_cache=use_cache,
+        cache_position=cache_position,
         kwargs=kwargs
     )


-def llama_attention_forward_4_36_quantized(
+def llama_attention_forward_4_38_quantized(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1026,9 +1034,15 @@ def llama_attention_forward_4_36_quantized(
                                                          "llama",
                                                          rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")
     kv_seq_len = key_states.shape[-2]

     if len(past_key_value.key_cache) <= self.layer_idx:
@@ -1037,7 +1051,8 @@ def llama_attention_forward_4_36_quantized(
         if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                    q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
-                                                         repeated_value_states, attention_mask,
+                                                         repeated_value_states,
+                                                         attention_mask, cache_position,
                                                          bsz, q_len, kv_seq_len, self.head_dim,
                                                          self.num_heads)
         else:
@@ -1053,13 +1068,17 @@ def llama_attention_forward_4_36_quantized(
             )

             if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    invalidInputError(
-                        False,
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                        f" but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
+                if cache_position is not None:
+                    # for transformers 4.38.0
+                    causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                    attn_weights = attn_weights + causal_mask
+                else:
+                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                    if attention_mask.size() != attn_mask_size:
+                        invalidInputError(False,
+                                          f"Attention mask should be of size {attn_mask_size}, "
+                                          f"but is {attention_mask.size()}")
+                    attn_weights = attn_weights + attention_mask

             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
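The cache_position branch above exists because transformers 4.38 hands the attention layer a larger causal mask (with rows for every cached position) rather than a ready-made (bsz, 1, q_len, kv_seq_len) slice, so the kernel has to cut out the rows for the tokens currently being decoded. A toy shape check with made-up sizes that are not taken from the commit:

import torch

bsz, q_len, kv_seq_len, max_len = 1, 1, 5, 8
# 4.38-style mask: 0 where attention is allowed, a large negative value elsewhere.
full_mask = torch.triu(
    torch.full((bsz, 1, max_len, max_len), torch.finfo(torch.float32).min), diagonal=1)

cache_position = torch.tensor([4])           # the single new token sits at position 4
causal_mask = full_mask[:, :, cache_position, :kv_seq_len]
print(causal_mask.shape)                     # torch.Size([1, 1, 1, 5]) == (bsz, 1, q_len, kv_seq_len)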
@@ -1097,13 +1116,17 @@ def llama_attention_forward_4_36_quantized(
             )

             if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    invalidInputError(
-                        False,
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                        f" but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
+                if cache_position is not None:
+                    # for transformers 4.38.0
+                    causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                    attn_weights = attn_weights + causal_mask
+                else:
+                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                    if attention_mask.size() != attn_mask_size:
+                        invalidInputError(False,
+                                          f"Attention mask should be of size {attn_mask_size}, "
+                                          f"but is {attention_mask.size()}")
+                    attn_weights = attn_weights + attention_mask

             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
@@ -1146,16 +1169,17 @@ def llama_attention_forward_4_36_quantized(
     return attn_output, attn_weights, past_key_value


-def llama_attention_forward_4_36_original(
+def llama_attention_forward_4_38_original(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1293,9 +1317,15 @@ def llama_attention_forward_4_36_original(
                                                          "llama",
                                                          rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")

     if past_key_value is not None:
         # update the number of seen tokens
@@ -1335,8 +1365,13 @@ def llama_attention_forward_4_36_original(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states

+    if cache_position is not None:
+        new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
+    else:
+        new_attention_mask = attention_mask
+
     if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+            use_flash_attention(query_states, key_states, new_attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
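The non-quantized path keeps passing its mask to use_flash_attention, SDPA, and native_sdp, which all expect a (bsz, 1, q_len, kv_seq_len) window, so the hunk above carves that window out of the 4.38 mask once and reuses it as new_attention_mask everywhere downstream. A toy shape check with illustrative sizes:

import torch

bsz, q_len, kv_seq_len, max_len = 1, 3, 7, 16
big_mask = torch.zeros(bsz, 1, max_len, max_len)    # stand-in for the 4.38 causal mask

new_attention_mask = big_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
print(new_attention_mask.shape)                     # torch.Size([1, 1, 3, 7]) == (bsz, 1, q_len, kv_seq_len)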
@@ -1349,7 +1384,7 @@ def llama_attention_forward_4_36_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
@@ -1359,7 +1394,7 @@ def llama_attention_forward_4_36_original(
         # otherwise, use native attention
         if query_states.device.type == "xpu":
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   new_attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim, self.num_heads, output_attentions)
         else:
@@ -1369,16 +1404,16 @@ def llama_attention_forward_4_36_original(
                     query_states,
                     key_states,
                     value_states,
-                    attn_mask=attention_mask,
+                    attn_mask=new_attention_mask,
                     dropout_p=self.attention_dropout if self.training else 0.0,
                     # The q_len > 1 is necessary to match with
                     # AttentionMaskConverter.to_causal_4d that
                     # does not create a causal mask in case q_len == 1.
-                    is_causal=self.is_causal and attention_mask is None and q_len > 1,
+                    is_causal=self.is_causal and new_attention_mask is None and q_len > 1,
                 )
             else:
                 attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                       attention_mask,
+                                                       new_attention_mask, cache_position,
                                                        bsz, q_len, kv_seq_len,
                                                        self.head_dim,
                                                        self.num_heads, output_attentions)
@@ -1407,7 +1442,7 @@ def llama_attention_forward_4_36_original(
     return attn_output.to(original_dtype), attn_weights, past_key_value


-def native_sdp(query, key, value, attention_mask,
+def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
@@ -1423,12 +1458,17 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attn_weights.size()}")

     if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
+        if cache_position is not None:
+            # for transformers 4.38.0
+            causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+            attn_weights = attn_weights + causal_mask
+        else:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask

     if kv_seq_len >= 2048 or bsz >= 64:
         # for memory considerations, do not upcast attention to fp32
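native_sdp now threads cache_position through so the same helper serves both mask layouts. A compact, self-contained sketch of the mask-then-softmax core it implements (illustrative only, not the ipex-llm function itself):

import math

import torch


def sdp_core(query, key, value, attention_mask, cache_position, kv_seq_len):
    # query/key/value: (bsz, num_heads, seq, head_dim)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.size(-1))
    if attention_mask is not None:
        if cache_position is not None:
            # transformers 4.38: slice the big causal mask down to the live rows
            attn_weights = attn_weights + attention_mask[:, :, cache_position, :kv_seq_len]
        else:
            attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value.dtype)
    return torch.matmul(attn_weights, value)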
@@ -1442,7 +1482,7 @@ def native_sdp(query, key, value, attention_mask,
     return attn_output, attn_weights


-def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                 bsz, q_len, kv_seq_len, head_dim, num_heads):
     block_size = 8
     query_split = torch.split(query.to(key.dtype), block_size, dim=1)
@@ -1459,12 +1499,17 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                               f"{attn_weights_split_size}, but is {attn_weights_split.size()}")

         if attention_mask is not None:
-            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-            if attention_mask.size() != attn_mask_size:
-                invalidInputError(False,
-                                  f"Attention mask should be of size {attn_mask_size}, "
-                                  f"but is {attention_mask.size()}")
-            attn_weights_split = attn_weights_split + attention_mask
+            if cache_position is not None:
+                # for transformers 4.38.0
+                causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                attn_weights_split = attn_weights_split + causal_mask
+            else:
+                attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                if attention_mask.size() != attn_mask_size:
+                    invalidInputError(False,
+                                      f"Attention mask should be of size {attn_mask_size}, "
+                                      f"but is {attention_mask.size()}")
+                attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
@@ -178,6 +178,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
         return q_embed, k_embed
+    elif model_family == "llama2":
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
     elif model_family == "gptj":
         q_embed = (q * cos) + (rotate_every_two(q) * sin)
         k_embed = (k * cos) + (rotate_every_two(k) * sin)
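The new "llama2" branch in apply_rotary_pos_emb reflects the rotary-embedding API change in transformers 4.38: rotary_emb(value_states, position_ids) already returns cos/sin gathered per position with a batch dimension, so only an unsqueeze(1) is needed to broadcast over heads, whereas the older seq_len=-based tables still have to be indexed by position_ids first. A small shape sketch with illustrative dimensions:

import torch

bsz, num_heads, q_len, head_dim = 2, 8, 5, 64
q = torch.randn(bsz, num_heads, q_len, head_dim)

# 4.38-style rotary output: already shaped (bsz, q_len, head_dim)
cos = torch.randn(bsz, q_len, head_dim)
sin = torch.randn(bsz, q_len, head_dim)

cos = cos.unsqueeze(1)               # (bsz, 1, q_len, head_dim), broadcasts over heads
sin = sin.unsqueeze(1)
print((q * cos + q * sin).shape)     # torch.Size([2, 8, 5, 64])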