[WIP] Support llama2 with transformers==4.38.0 (#11024)

* support llama2 with transformers==4.38.0

* add support for quantize_qkv

* add support for the original (non-quantized) path on 4.38.0

* code style fix
SONG Ge 2024-05-15 18:07:00 +08:00 committed by GitHub
parent 686f6038a8
commit 9942a4ba69
3 changed files with 123 additions and 64 deletions
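
For context, a minimal usage sketch of the path this commit targets: loading a Llama 2 checkpoint through ipex-llm's low-bit AutoModelForCausalLM with transformers==4.38.0 installed, so that the forward-patching step (_optimize_post, modified below) rebinds the attention forwards. The checkpoint path and prompt are placeholders, and the exact loading flags may differ between ipex-llm versions.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

# Placeholder checkpoint; any Llama 2 model directory works the same way.
model_path = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# load_in_4bit=True converts the Linear layers to low-bit weights and runs the
# forward-patching step that this commit extends for transformers 4.38.0.
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)

with torch.inference_mode():
    inputs = tokenizer("What is AI?", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))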


@@ -961,16 +961,24 @@ def _optimize_post(model, lightweight_bmm=False):
llama_decoder_forward)
if version.parse(trans_version) >= version.parse("4.36.0"):
# transformers version >= 4.36.0
from ipex_llm.transformers.models.llama import llama_attention_forward_4_36
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
from ipex_llm.transformers.models.llama import llama_model_forward_4_36
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_36, )
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_36)
if version.parse(trans_version) >= version.parse("4.38.0"):
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38_original
# Todo: support llama_model_forward with transformers version >= 4.38.0
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38_original)
else:
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaModel,
llama_model_forward_4_36)
convert_forward(
model,
transformers.models.llama.modeling_llama.LlamaAttention,
llama_attention_forward_4_38)
else:
# transformers version between 4.31.0 - 4.35.2
convert_forward(
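
The branch above relies on convert_forward to walk the loaded model and rebind forward on every module of the target class; for transformers >= 4.38.0 only LlamaAttention is patched for now (the model forward is still a TODO), while 4.36.x-4.37.x patches both LlamaModel and LlamaAttention. A minimal sketch of the convert_forward pattern, not necessarily the exact ipex-llm implementation:

import torch.nn as nn

def convert_forward(model: nn.Module, target_class, new_forward):
    # Recursively visit submodules and rebind `forward` on every instance of
    # the target class, so the patched function runs with `self` bound to it.
    for _, sub_module in model.named_children():
        if isinstance(sub_module, target_class):
            sub_module.forward = new_forward.__get__(sub_module, sub_module.__class__)
        convert_forward(sub_module, target_class, new_forward)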


@@ -333,6 +333,7 @@ def llama_attention_forward_4_31(
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
@@ -348,6 +349,7 @@ def llama_attention_forward_4_31(
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cache_position=cache_position,
kwargs=kwargs
)
@@ -361,6 +363,7 @@ def llama_attention_forward_4_31_quantized(
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, hidden_size = hidden_states.size()
@@ -437,7 +440,8 @@ def llama_attention_forward_4_31_quantized(
repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
repeated_value_states, attention_mask,
repeated_value_states,
attention_mask, cache_position,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads, output_attentions)
if use_cache:
@@ -462,7 +466,7 @@ def llama_attention_forward_4_31_quantized(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
attention_mask, cache_position,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads, output_attentions)
else:
@@ -498,6 +502,7 @@ def llama_attention_forward_4_31_original(
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, hidden_size = hidden_states.size()
@@ -683,7 +688,7 @@ def llama_attention_forward_4_31_original(
value_states = repeat_kv(value_states, self.num_key_value_groups)
# otherwise, use native attention
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
attention_mask, cache_position,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads, output_attentions)
attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
@@ -919,20 +924,21 @@ def llama_attention_selective_batching_forward_4_31(
return attn_output.to(original_dtype), attn_weights, updated_past_key_values
def llama_attention_forward_4_36(
def llama_attention_forward_4_38(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_value: Optional[List[torch.FloatTensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if use_quantize_kv_cache(self.q_proj, hidden_states):
forward_function = llama_attention_forward_4_36_quantized
forward_function = llama_attention_forward_4_38_quantized
else:
forward_function = llama_attention_forward_4_36_original
forward_function = llama_attention_forward_4_38_original
return forward_function(
self=self,
hidden_states=hidden_states,
@@ -941,20 +947,22 @@ def llama_attention_forward_4_36(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
kwargs=kwargs
)
def llama_attention_forward_4_36_quantized(
def llama_attention_forward_4_38_quantized(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_value: Optional[List[torch.FloatTensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1026,9 +1034,15 @@ def llama_attention_forward_4_36_quantized(
"llama",
rope_theta=rope_theta)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
if cache_position is not None:
# for transformers 4.38.0
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama2")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
kv_seq_len = key_states.shape[-2]
if len(past_key_value.key_cache) <= self.layer_idx:
@@ -1037,7 +1051,8 @@ def llama_attention_forward_4_36_quantized(
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
q_len, kv_seq_len, output_attentions):
attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
repeated_value_states, attention_mask,
repeated_value_states,
attention_mask, cache_position,
bsz, q_len, kv_seq_len, self.head_dim,
self.num_heads)
else:
@@ -1053,13 +1068,17 @@ def llama_attention_forward_4_36_quantized(
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if cache_position is not None:
# for transformers 4.38.0
causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
attn_weights = attn_weights + causal_mask
else:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
if kv_seq_len >= 2048 or bsz >= 64:
# for memory considerations, do not upcast attention to fp32
@@ -1097,13 +1116,17 @@ def llama_attention_forward_4_36_quantized(
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
invalidInputError(
False,
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
f" but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if cache_position is not None:
# for transformers 4.38.0
causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
attn_weights = attn_weights + causal_mask
else:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
if kv_seq_len >= 2048 or bsz >= 64:
# for memory considerations, do not upcast attention to fp32
@@ -1146,16 +1169,17 @@ def llama_attention_forward_4_36_quantized(
return attn_output, attn_weights, past_key_value
def llama_attention_forward_4_36_original(
def llama_attention_forward_4_38_original(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
past_key_value: Optional[List[torch.FloatTensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1293,9 +1317,15 @@ def llama_attention_forward_4_36_original(
"llama",
rope_theta=rope_theta)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
if cache_position is not None:
# for transformers 4.38.0
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama2")
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids, "llama")
if past_key_value is not None:
# update the number of seen tokens
@@ -1335,8 +1365,13 @@ def llama_attention_forward_4_36_original(
past_key_value.key_cache[self.layer_idx] = key_states
past_key_value.value_cache[self.layer_idx] = value_states
if cache_position is not None:
new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
else:
new_attention_mask = attention_mask
if not self.training and not hidden_states.requires_grad and \
use_flash_attention(query_states, key_states, attention_mask):
use_flash_attention(query_states, key_states, new_attention_mask):
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -1349,7 +1384,7 @@ def llama_attention_forward_4_36_original(
elif not self.training and not hidden_states.requires_grad and \
use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, new_attention_mask)
attn_output = attn_output.view(query_states.shape)
attn_weights = None
else:
@@ -1359,7 +1394,7 @@ def llama_attention_forward_4_36_original(
# otherwise, use native attention
if query_states.device.type == "xpu":
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
new_attention_mask, cache_position,
bsz, q_len, kv_seq_len,
self.head_dim, self.num_heads, output_attentions)
else:
@@ -1369,16 +1404,16 @@ def llama_attention_forward_4_36_original(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
attn_mask=new_attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with
# AttentionMaskConverter.to_causal_4d that
# does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
is_causal=self.is_causal and new_attention_mask is None and q_len > 1,
)
else:
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
attention_mask,
new_attention_mask, cache_position,
bsz, q_len, kv_seq_len,
self.head_dim,
self.num_heads, output_attentions)
@@ -1407,7 +1442,7 @@ def llama_attention_forward_4_36_original(
return attn_output.to(original_dtype), attn_weights, past_key_value
def native_sdp(query, key, value, attention_mask,
def native_sdp(query, key, value, attention_mask, cache_position,
bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
@@ -1423,12 +1458,17 @@ def native_sdp(query, key, value, attention_mask,
f"but is {attn_weights.size()}")
if attention_mask is not None:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
if cache_position is not None:
# for transformers 4.38.0
causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
attn_weights = attn_weights + causal_mask
else:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
if kv_seq_len >= 2048 or bsz >= 64:
# for memory considerations, do not upcast attention to fp32
@@ -1442,7 +1482,7 @@ def native_sdp(query, key, value, attention_mask,
return attn_output, attn_weights
def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
bsz, q_len, kv_seq_len, head_dim, num_heads):
block_size = 8
query_split = torch.split(query.to(key.dtype), block_size, dim=1)
@@ -1459,12 +1499,17 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
if attention_mask is not None:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights_split = attn_weights_split + attention_mask
if cache_position is not None:
# for transformers 4.38.0
causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
attn_weights_split = attn_weights_split + causal_mask
else:
attn_mask_size = (bsz, 1, q_len, kv_seq_len)
if attention_mask.size() != attn_mask_size:
invalidInputError(False,
f"Attention mask should be of size {attn_mask_size}, "
f"but is {attention_mask.size()}")
attn_weights_split = attn_weights_split + attention_mask
attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
attn_outputs.append(torch.matmul(attn_weights_split, v))
attn_output = torch.cat(attn_outputs, dim=1)
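
The recurring cache_position branch above mirrors the slicing transformers 4.38.0 applies in its own eager attention path: the model-level causal mask is cut down to the rows of the tokens currently being processed (cache_position) and the first kv_seq_len key columns before being added to the raw attention scores. A small, self-contained illustration of that slicing; all shapes are made up for the example:

import torch

bsz, num_heads, head_dim = 1, 2, 4
q_len, kv_seq_len, max_cache_len = 3, 8, 16

# A full-size additive causal mask, as prepared once at the model level.
causal = torch.full((max_cache_len, max_cache_len), float("-inf")).triu(1)
attention_mask = causal[None, None, :, :].expand(bsz, 1, -1, -1)

# Positions of the q_len tokens being decoded in this step (5, 6, 7).
cache_position = torch.arange(kv_seq_len - q_len, kv_seq_len)

attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len)

# Same slicing as the new 4.38.0 branches: keep only the current rows and the
# first kv_seq_len key columns, then add the mask to the scores.
causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]
attn_weights = attn_weights + causal_mask
print(causal_mask.shape)  # torch.Size([1, 1, 3, 8]), broadcasts over heads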


@@ -178,6 +178,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
elif model_family == "llama2":
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
elif model_family == "gptj":
q_embed = (q * cos) + (rotate_every_two(q) * sin)
k_embed = (k * cos) + (rotate_every_two(k) * sin)
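
The new "llama2" branch above accounts for a signature change in transformers 4.38.0: the rotary embedding is now called as rotary_emb(value_states, position_ids) and already returns cos/sin gathered per position with a leading batch dimension, so applying them only needs an unsqueeze(1) to broadcast over the head dimension, whereas the older "llama" branch still receives [seq_len, head_dim] tables that must be indexed by position_ids first. A shape-only sketch of the new branch; the tensors are random and purely illustrative:

import torch

def rotate_half(x):
    # Standard llama rotation: negate the second half and swap the halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bsz, num_heads, q_len, head_dim = 2, 4, 5, 8
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)

# With transformers >= 4.38.0, cos/sin arrive shaped [bsz, q_len, head_dim].
cos = torch.randn(bsz, q_len, head_dim)
sin = torch.randn(bsz, q_len, head_dim)

# "llama2" branch: unsqueeze to [bsz, 1, q_len, head_dim] so cos/sin broadcast
# over the head dimension, then apply the usual rotation.
cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 5, 8]) each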