[WIP] Support llama2 with transformers==4.38.0 (#11024)
* support llama2 with transformers==4.38.0
* add support for quantize_qkv
* add original support for 4.38.0 for now
* code style fix
parent 686f6038a8
commit 9942a4ba69
3 changed files with 123 additions and 64 deletions
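Note: the diff below gates the replacement forwards on the installed transformers version. A minimal sketch of that kind of version gate (illustrative only; the forward names come from the diff, everything else is standard packaging/transformers usage):

from packaging import version
import transformers

trans_version = transformers.__version__
if version.parse(trans_version) >= version.parse("4.38.0"):
    # 4.38 adds cache_position and changes the rotary-embedding call,
    # so it gets its own attention forward
    chosen = "llama_attention_forward_4_38_original"
elif version.parse(trans_version) >= version.parse("4.36.0"):
    chosen = "llama_attention_forward_4_38"
else:
    chosen = "llama_attention_forward_4_31"
print(chosen)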
@@ -961,16 +961,24 @@ def _optimize_post(model, lightweight_bmm=False):
                             llama_decoder_forward)
         if version.parse(trans_version) >= version.parse("4.36.0"):
             # transformers version >= 4.36.0
-            from ipex_llm.transformers.models.llama import llama_attention_forward_4_36
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
             from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_36, )
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaModel,
-                llama_model_forward_4_36)
+            if version.parse(trans_version) >= version.parse("4.38.0"):
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_38_original
+                # Todo: support llama_model_forward with transformers version >= 4.38.0
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38_original)
+            else:
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaModel,
+                    llama_model_forward_4_36)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38)
         else:
             # transformers version between 4.31.0 - 4.35.2
             convert_forward(
@@ -333,6 +333,7 @@ def llama_attention_forward_4_31(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
@@ -348,6 +349,7 @@ def llama_attention_forward_4_31(
         output_attentions=output_attentions,
         use_cache=use_cache,
         padding_mask=padding_mask,
+        cache_position=cache_position,
         kwargs=kwargs
     )

@@ -361,6 +363,7 @@ def llama_attention_forward_4_31_quantized(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -437,7 +440,8 @@ def llama_attention_forward_4_31_quantized(
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
-                                               repeated_value_states, attention_mask,
+                                               repeated_value_states,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
         if use_cache:
@@ -462,7 +466,7 @@ def llama_attention_forward_4_31_quantized(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
     else:
@@ -498,6 +502,7 @@ def llama_attention_forward_4_31_original(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -683,7 +688,7 @@ def llama_attention_forward_4_31_original(
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
         attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
@@ -919,20 +924,21 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values


-def llama_attention_forward_4_36(
+def llama_attention_forward_4_38(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
-        forward_function = llama_attention_forward_4_36_quantized
+        forward_function = llama_attention_forward_4_38_quantized
     else:
-        forward_function = llama_attention_forward_4_36_original
+        forward_function = llama_attention_forward_4_38_original
     return forward_function(
         self=self,
         hidden_states=hidden_states,
@@ -941,20 +947,22 @@ def llama_attention_forward_4_36(
         past_key_value=past_key_value,
         output_attentions=output_attentions,
         use_cache=use_cache,
+        cache_position=cache_position,
         kwargs=kwargs
     )


-def llama_attention_forward_4_36_quantized(
+def llama_attention_forward_4_38_quantized(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1026,9 +1034,15 @@ def llama_attention_forward_4_36_quantized(
                                                               "llama",
                                                               rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")
     kv_seq_len = key_states.shape[-2]

     if len(past_key_value.key_cache) <= self.layer_idx:
@@ -1037,7 +1051,8 @@ def llama_attention_forward_4_36_quantized(
         if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                    q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
-                                                         repeated_value_states, attention_mask,
+                                                         repeated_value_states,
+                                                         attention_mask, cache_position,
                                                          bsz, q_len, kv_seq_len, self.head_dim,
                                                          self.num_heads)
         else:
@@ -1053,13 +1068,17 @@ def llama_attention_forward_4_36_quantized(
             )

             if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    invalidInputError(
-                        False,
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                        f" but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
+                if cache_position is not None:
+                    # for transformers 4.38.0
+                    causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                    attn_weights = attn_weights + causal_mask
+                else:
+                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                    if attention_mask.size() != attn_mask_size:
+                        invalidInputError(False,
+                                          f"Attention mask should be of size {attn_mask_size}, "
+                                          f"but is {attention_mask.size()}")
+                    attn_weights = attn_weights + attention_mask

             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
@@ -1097,13 +1116,17 @@ def llama_attention_forward_4_36_quantized(
             )

             if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    invalidInputError(
-                        False,
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                        f" but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
+                if cache_position is not None:
+                    # for transformers 4.38.0
+                    causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                    attn_weights = attn_weights + causal_mask
+                else:
+                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                    if attention_mask.size() != attn_mask_size:
+                        invalidInputError(False,
+                                          f"Attention mask should be of size {attn_mask_size}, "
+                                          f"but is {attention_mask.size()}")
+                    attn_weights = attn_weights + attention_mask

             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
@@ -1146,16 +1169,17 @@ def llama_attention_forward_4_36_quantized(
     return attn_output, attn_weights, past_key_value


-def llama_attention_forward_4_36_original(
+def llama_attention_forward_4_38_original(
     self,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1293,9 +1317,15 @@ def llama_attention_forward_4_36_original(
                                                               "llama",
                                                               rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")

     if past_key_value is not None:
         # update the number of seen tokens
@@ -1335,8 +1365,13 @@ def llama_attention_forward_4_36_original(
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states

+    if cache_position is not None:
+        new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
+    else:
+        new_attention_mask = attention_mask
+
     if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+            use_flash_attention(query_states, key_states, new_attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
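Note: the hunk above derives new_attention_mask because the 4D causal mask prepared by transformers >= 4.38 can span more positions than the current attention window; slicing keeps only the last q_len query rows and the first kv_seq_len key columns. A quick shape check with dummy sizes (all values here are illustrative, not taken from this patch):

import torch

bsz, q_len, kv_seq_len, mask_len = 2, 1, 8, 16
attention_mask = torch.zeros(bsz, 1, mask_len, mask_len)  # stand-in for the 4.38 causal mask

new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
print(new_attention_mask.shape)  # torch.Size([2, 1, 1, 8]) -> (bsz, 1, q_len, kv_seq_len)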
@@ -1349,7 +1384,7 @@ def llama_attention_forward_4_36_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
@@ -1359,7 +1394,7 @@ def llama_attention_forward_4_36_original(
         # otherwise, use native attention
         if query_states.device.type == "xpu":
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   new_attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim, self.num_heads, output_attentions)
         else:
@@ -1369,16 +1404,16 @@ def llama_attention_forward_4_36_original(
                 query_states,
                 key_states,
                 value_states,
-                attn_mask=attention_mask,
+                attn_mask=new_attention_mask,
                 dropout_p=self.attention_dropout if self.training else 0.0,
                 # The q_len > 1 is necessary to match with
                 # AttentionMaskConverter.to_causal_4d that
                 # does not create a causal mask in case q_len == 1.
-                is_causal=self.is_causal and attention_mask is None and q_len > 1,
+                is_causal=self.is_causal and new_attention_mask is None and q_len > 1,
             )
         else:
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   new_attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim,
                                                    self.num_heads, output_attentions)
@@ -1407,7 +1442,7 @@ def llama_attention_forward_4_36_original(
     return attn_output.to(original_dtype), attn_weights, past_key_value


-def native_sdp(query, key, value, attention_mask,
+def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
@@ -1423,12 +1458,17 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attn_weights.size()}")

     if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
+        if cache_position is not None:
+            # for transformers 4.38.0
+            causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+            attn_weights = attn_weights + causal_mask
+        else:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask

     if kv_seq_len >= 2048 or bsz >= 64:
         # for memory considerations, do not upcast attention to fp32
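Note: native_sdp now also takes cache_position; when it is not None (transformers >= 4.38), the causal mask is indexed by the absolute positions of the current query tokens instead of being assumed to already have shape (bsz, 1, q_len, kv_seq_len). A small illustration of that indexing (dummy shapes, not from this patch):

import torch

bsz, q_len, kv_seq_len, mask_len = 2, 1, 8, 16
attention_mask = torch.zeros(bsz, 1, mask_len, mask_len)  # stand-in for the 4.38 causal mask
cache_position = torch.arange(kv_seq_len - q_len, kv_seq_len)  # positions of the current query tokens

causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]
print(causal_mask.shape)  # torch.Size([2, 1, 1, 8]) -> (bsz, 1, q_len, kv_seq_len)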
@@ -1442,7 +1482,7 @@ def native_sdp(query, key, value, attention_mask,
     return attn_output, attn_weights


-def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                 bsz, q_len, kv_seq_len, head_dim, num_heads):
     block_size = 8
     query_split = torch.split(query.to(key.dtype), block_size, dim=1)
@@ -1459,12 +1499,17 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                               f"{attn_weights_split_size}, but is {attn_weights_split.size()}")

         if attention_mask is not None:
-            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-            if attention_mask.size() != attn_mask_size:
-                invalidInputError(False,
-                                  f"Attention mask should be of size {attn_mask_size}, "
-                                  f"but is {attention_mask.size()}")
-            attn_weights_split = attn_weights_split + attention_mask
+            if cache_position is not None:
+                # for transformers 4.38.0
+                causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                attn_weights_split = attn_weights_split + causal_mask
+            else:
+                attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                if attention_mask.size() != attn_mask_size:
+                    invalidInputError(False,
+                                      f"Attention mask should be of size {attn_mask_size}, "
+                                      f"but is {attention_mask.size()}")
+                attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
@@ -178,6 +178,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
         return q_embed, k_embed
+    elif model_family == "llama2":
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
     elif model_family == "gptj":
         q_embed = (q * cos) + (rotate_every_two(q) * sin)
         k_embed = (k * cos) + (rotate_every_two(k) * sin)
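Note: the new "llama2" branch matches the transformers 4.38 rotary API, where self.rotary_emb(value_states, position_ids) returns cos/sin already gathered per position (roughly (bsz, seq_len, head_dim)); unsqueeze(1) adds the head axis so they broadcast against q/k of shape (bsz, num_heads, seq_len, head_dim). A quick standalone shape check (dummy tensors, rotate_half written out in the usual convention):

import torch

def rotate_half(x):
    # split the last dim into two halves and swap them with a sign flip
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

bsz, num_heads, seq_len, head_dim = 2, 4, 8, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
cos = torch.randn(bsz, seq_len, head_dim)   # as returned by the 4.38 rotary embedding
sin = torch.randn(bsz, seq_len, head_dim)

cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # (bsz, 1, seq_len, head_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
print(q_embed.shape)  # torch.Size([2, 4, 8, 64])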