[WIP] Support llama2 with transformers==4.38.0 (#11024)

* support llama2 with transformers==4.38.0

* add support for quantize_qkv

* add original support for 4.38.0 now

* code style fix
SONG Ge 2024-05-15 18:07:00 +08:00 committed by GitHub
parent 686f6038a8
commit 9942a4ba69
3 changed files with 123 additions and 64 deletions


@@ -961,16 +961,24 @@ def _optimize_post(model, lightweight_bmm=False):
                             llama_decoder_forward)
         if version.parse(trans_version) >= version.parse("4.36.0"):
             # transformers version >= 4.36.0
-            from ipex_llm.transformers.models.llama import llama_attention_forward_4_36
+            from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
             from ipex_llm.transformers.models.llama import llama_model_forward_4_36
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaAttention,
-                llama_attention_forward_4_36, )
-            convert_forward(
-                model,
-                transformers.models.llama.modeling_llama.LlamaModel,
-                llama_model_forward_4_36)
+            if version.parse(trans_version) >= version.parse("4.38.0"):
+                from ipex_llm.transformers.models.llama import llama_attention_forward_4_38_original
+                # Todo: support llama_model_forward with transformers version >= 4.38.0
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38_original)
+            else:
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaModel,
+                    llama_model_forward_4_36)
+                convert_forward(
+                    model,
+                    transformers.models.llama.modeling_llama.LlamaAttention,
+                    llama_attention_forward_4_38)
         else:
             # transformers version between 4.31.0 - 4.35.2
             convert_forward(
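
For orientation, here is a minimal self-contained sketch of the dispatch the hunk above implements (not ipex-llm code; pick_llama_forwards is a made-up helper): on transformers >= 4.38.0 only the attention forward is replaced for now, while 4.36.x/4.37.x patch both the model forward and the attention forward.

    from packaging import version

    def pick_llama_forwards(trans_version: str) -> dict:
        # Hypothetical helper: returns which patched forwards the hunk above
        # would install for a given transformers version (names only).
        if version.parse(trans_version) >= version.parse("4.38.0"):
            # 4.38.0+: only LlamaAttention is patched for now; the model
            # forward is still marked as a TODO in the diff.
            return {"LlamaAttention": "llama_attention_forward_4_38_original"}
        elif version.parse(trans_version) >= version.parse("4.36.0"):
            # 4.36.0 - 4.37.x: patch both LlamaModel and LlamaAttention
            return {"LlamaModel": "llama_model_forward_4_36",
                    "LlamaAttention": "llama_attention_forward_4_38"}
        else:
            # 4.31.0 - 4.35.2 fall back to the 4.31 implementations (not shown above)
            return {"LlamaAttention": "llama_attention_forward_4_31"}

    print(pick_llama_forwards("4.38.0"))  # {'LlamaAttention': 'llama_attention_forward_4_38_original'}
    print(pick_llama_forwards("4.36.2"))  # model forward and attention forward are both patched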


@@ -333,6 +333,7 @@ def llama_attention_forward_4_31(
         output_attentions: bool = False,
         use_cache: bool = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
@@ -348,6 +349,7 @@ def llama_attention_forward_4_31(
         output_attentions=output_attentions,
         use_cache=use_cache,
         padding_mask=padding_mask,
+        cache_position=cache_position,
         kwargs=kwargs
     )

@@ -361,6 +363,7 @@ def llama_attention_forward_4_31_quantized(
         output_attentions: bool = False,
         use_cache: bool = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -437,7 +440,8 @@ def llama_attention_forward_4_31_quantized(
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
-                                               repeated_value_states, attention_mask,
+                                               repeated_value_states,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
         if use_cache:
@@ -462,7 +466,7 @@ def llama_attention_forward_4_31_quantized(
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim, self.num_heads, output_attentions)
         else:
@@ -498,6 +502,7 @@ def llama_attention_forward_4_31_original(
         output_attentions: bool = False,
         use_cache: bool = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, hidden_size = hidden_states.size()
@@ -683,7 +688,7 @@ def llama_attention_forward_4_31_original(
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
+                                               attention_mask, cache_position,
                                                bsz, q_len, kv_seq_len,
                                                self.head_dim, self.num_heads, output_attentions)
         attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
@@ -919,20 +924,21 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values


-def llama_attention_forward_4_36(
+def llama_attention_forward_4_38(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
+        past_key_value: Optional[List[torch.FloatTensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
-        forward_function = llama_attention_forward_4_36_quantized
+        forward_function = llama_attention_forward_4_38_quantized
     else:
-        forward_function = llama_attention_forward_4_36_original
+        forward_function = llama_attention_forward_4_38_original
     return forward_function(
         self=self,
         hidden_states=hidden_states,
@@ -941,20 +947,22 @@ def llama_attention_forward_4_36(
         past_key_value=past_key_value,
         output_attentions=output_attentions,
         use_cache=use_cache,
+        cache_position=cache_position,
         kwargs=kwargs
     )


-def llama_attention_forward_4_36_quantized(
+def llama_attention_forward_4_38_quantized(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
+        past_key_value: Optional[List[torch.FloatTensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1026,9 +1034,15 @@ def llama_attention_forward_4_36_quantized(
                                                         "llama",
                                                         rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")

     kv_seq_len = key_states.shape[-2]
     if len(past_key_value.key_cache) <= self.layer_idx:
@@ -1037,7 +1051,8 @@ def llama_attention_forward_4_36_quantized(
         if should_split_qkv_tensor(query_states, bsz, self.num_heads,
                                    q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
-                                                         repeated_value_states, attention_mask,
+                                                         repeated_value_states,
+                                                         attention_mask, cache_position,
                                                          bsz, q_len, kv_seq_len, self.head_dim,
                                                          self.num_heads)
         else:
@@ -1053,13 +1068,17 @@ def llama_attention_forward_4_36_quantized(
             )

             if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    invalidInputError(
-                        False,
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                        f" but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
+                if cache_position is not None:
+                    # for transformers 4.38.0
+                    causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                    attn_weights = attn_weights + causal_mask
+                else:
+                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                    if attention_mask.size() != attn_mask_size:
+                        invalidInputError(False,
+                                          f"Attention mask should be of size {attn_mask_size}, "
+                                          f"but is {attention_mask.size()}")
+                    attn_weights = attn_weights + attention_mask

             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
@@ -1097,13 +1116,17 @@ def llama_attention_forward_4_36_quantized(
             )

             if attention_mask is not None:
-                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                    invalidInputError(
-                        False,
-                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                        f" but is {attention_mask.size()}"
-                    )
-                attn_weights = attn_weights + attention_mask
+                if cache_position is not None:
+                    # for transformers 4.38.0
+                    causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                    attn_weights = attn_weights + causal_mask
+                else:
+                    attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                    if attention_mask.size() != attn_mask_size:
+                        invalidInputError(False,
+                                          f"Attention mask should be of size {attn_mask_size}, "
+                                          f"but is {attention_mask.size()}")
+                    attn_weights = attn_weights + attention_mask

             if kv_seq_len >= 2048 or bsz >= 64:
                 # for memory considerations, do not upcast attention to fp32
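
The quantized path above replaces the strict (bsz, 1, q_len, kv_seq_len) mask-size check with a slice driven by cache_position. A standalone illustration of that slicing, with shapes assumed from the 4.38 behaviour this patch targets (not code copied from transformers or ipex-llm): the prepared causal mask covers every position, and the slice keeps only the rows of the tokens being decoded and the columns of the keys cached so far.

    import torch

    bsz, num_heads, q_len, kv_seq_len, total_len = 1, 2, 1, 5, 8
    min_val = torch.finfo(torch.float32).min

    # full causal mask over all positions: 0 where attention is allowed, min_val elsewhere
    full_mask = torch.full((total_len, total_len), min_val).triu(diagonal=1)
    attention_mask = full_mask[None, None, :, :].expand(bsz, 1, total_len, total_len)

    # decoding token index 4 with 5 keys already in the cache
    cache_position = torch.tensor([kv_seq_len - 1])                   # shape (q_len,)
    causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]   # (bsz, 1, q_len, kv_seq_len)

    attn_weights = torch.zeros(bsz, num_heads, q_len, kv_seq_len)
    attn_weights = attn_weights + causal_mask                         # broadcasts over the head dim
    print(attn_weights.shape)      # torch.Size([1, 2, 1, 5])
    print(causal_mask[0, 0, 0])    # all zeros: token 4 may attend to keys 0..4
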
@@ -1146,16 +1169,17 @@ def llama_attention_forward_4_36_quantized(
     return attn_output, attn_weights, past_key_value


-def llama_attention_forward_4_36_original(
+def llama_attention_forward_4_38_original(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
+        past_key_value: Optional[List[torch.FloatTensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1293,9 +1317,15 @@ def llama_attention_forward_4_36_original(
                                                         "llama",
                                                         rope_theta=rope_theta)
     else:
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                        cos, sin, position_ids, "llama")
+        if cache_position is not None:
+            # for transformers 4.38.0
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama2")
+        else:
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                            cos, sin, position_ids, "llama")

     if past_key_value is not None:
         # update the number of seen tokens
@@ -1335,8 +1365,13 @@ def llama_attention_forward_4_36_original(
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states

+    if cache_position is not None:
+        new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
+    else:
+        new_attention_mask = attention_mask
+
     if not self.training and not hidden_states.requires_grad and \
-            use_flash_attention(query_states, key_states, attention_mask):
+            use_flash_attention(query_states, key_states, new_attention_mask):
         # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
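
The new_attention_mask slice above keeps the last q_len rows of the 4.38-style mask; when cache_position holds exactly those positions (the usual decoding case), it selects the same rows as the cache_position indexing used elsewhere in this patch. A quick standalone check under that assumption:

    import torch

    q_len, kv_seq_len, total_len = 2, 6, 8
    min_val = torch.finfo(torch.float32).min

    full_mask = torch.full((total_len, total_len), min_val).triu(diagonal=1)
    attention_mask = full_mask[None, None, :, :]           # (1, 1, total_len, total_len)

    # the last q_len query positions, as assumed to be produced during decoding
    cache_position = torch.arange(kv_seq_len - q_len, kv_seq_len)

    by_slice = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
    by_cache_position = attention_mask[:, :, cache_position, 0:kv_seq_len]

    print(by_slice.shape)                              # torch.Size([1, 1, 2, 6])
    print(torch.equal(by_slice, by_cache_position))    # True
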
@@ -1349,7 +1384,7 @@ def llama_attention_forward_4_36_original(
     elif not self.training and not hidden_states.requires_grad and \
             use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, new_attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
@@ -1359,7 +1394,7 @@ def llama_attention_forward_4_36_original(
         # otherwise, use native attention
         if query_states.device.type == "xpu":
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   new_attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim, self.num_heads, output_attentions)
         else:
@@ -1369,16 +1404,16 @@ def llama_attention_forward_4_36_original(
                 query_states,
                 key_states,
                 value_states,
-                attn_mask=attention_mask,
+                attn_mask=new_attention_mask,
                 dropout_p=self.attention_dropout if self.training else 0.0,
                 # The q_len > 1 is necessary to match with
                 # AttentionMaskConverter.to_causal_4d that
                 # does not create a causal mask in case q_len == 1.
-                is_causal=self.is_causal and attention_mask is None and q_len > 1,
+                is_causal=self.is_causal and new_attention_mask is None and q_len > 1,
             )
         else:
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                                   attention_mask,
+                                                   new_attention_mask, cache_position,
                                                    bsz, q_len, kv_seq_len,
                                                    self.head_dim,
                                                    self.num_heads, output_attentions)
@@ -1407,7 +1442,7 @@ def llama_attention_forward_4_36_original(
     return attn_output.to(original_dtype), attn_weights, past_key_value


-def native_sdp(query, key, value, attention_mask,
+def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
-        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
@@ -1423,12 +1458,17 @@ def native_sdp(query, key, value, attention_mask,
                           f"but is {attn_weights.size()}")

     if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
+        if cache_position is not None:
+            # for transformers 4.38.0
+            causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+            attn_weights = attn_weights + causal_mask
+        else:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask

     if kv_seq_len >= 2048 or bsz >= 64:
         # for memory considerations, do not upcast attention to fp32
@@ -1442,7 +1482,7 @@ def native_sdp(query, key, value, attention_mask,
     return attn_output, attn_weights


-def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                 bsz, q_len, kv_seq_len, head_dim, num_heads):
     block_size = 8
     query_split = torch.split(query.to(key.dtype), block_size, dim=1)
@@ -1459,12 +1499,17 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                           f"{attn_weights_split_size}, but is {attn_weights_split.size()}")

         if attention_mask is not None:
-            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-            if attention_mask.size() != attn_mask_size:
-                invalidInputError(False,
-                                  f"Attention mask should be of size {attn_mask_size}, "
-                                  f"but is {attention_mask.size()}")
-            attn_weights_split = attn_weights_split + attention_mask
+            if cache_position is not None:
+                # for transformers 4.38.0
+                causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
+                attn_weights_split = attn_weights_split + causal_mask
+            else:
+                attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+                if attention_mask.size() != attn_mask_size:
+                    invalidInputError(False,
+                                      f"Attention mask should be of size {attn_mask_size}, "
+                                      f"but is {attention_mask.size()}")
+                attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)


@@ -178,6 +178,12 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, model_family):
         q_embed = (q * cos) + (rotate_half(q) * sin)
         k_embed = (k * cos) + (rotate_half(k) * sin)
         return q_embed, k_embed
+    elif model_family == "llama2":
+        cos = cos.unsqueeze(1)
+        sin = sin.unsqueeze(1)
+        q_embed = (q * cos) + (rotate_half(q) * sin)
+        k_embed = (k * cos) + (rotate_half(k) * sin)
+        return q_embed, k_embed
     elif model_family == "gptj":
         q_embed = (q * cos) + (rotate_every_two(q) * sin)
         k_embed = (k * cos) + (rotate_every_two(k) * sin)
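
The new "llama2" branch differs from the "llama" branch only in how cos/sin are broadcast: with the 4.38 rotary-embedding refactor, self.rotary_emb(value_states, position_ids) is understood to return cos/sin already gathered per position, shaped (batch, seq_len, head_dim), so only an unsqueeze over the head dimension is needed. A standalone sketch of that broadcast (shapes assumed; rotate_half re-implemented here, random cos/sin used purely to show shapes):

    import torch

    def rotate_half(x):
        # standard RoPE helper: rotate the two halves of the last dimension
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    bsz, num_heads, seq_len, head_dim = 2, 4, 3, 8
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)

    # assumed 4.38-style rotary output: already gathered per position
    cos = torch.randn(bsz, seq_len, head_dim)
    sin = torch.randn(bsz, seq_len, head_dim)

    # the "llama2" branch: insert the head dimension so cos/sin broadcast against q/k
    cos_b = cos.unsqueeze(1)          # (bsz, 1, seq_len, head_dim)
    sin_b = sin.unsqueeze(1)
    q_embed = (q * cos_b) + (rotate_half(q) * sin_b)
    k_embed = (k * cos_b) + (rotate_half(k) * sin_b)
    print(q_embed.shape, k_embed.shape)   # torch.Size([2, 4, 3, 8]) torch.Size([2, 4, 3, 8])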