fix mistral for transformers>=4.39 (#11191)
* fix mistral for transformers>=4.39
This commit is contained in:
parent
67a1e05876
commit
c44b1942ed
2 changed files with 244 additions and 9 deletions
|
|
@ -1400,15 +1400,23 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
module.MistralRMSNorm,
|
module.MistralRMSNorm,
|
||||||
llama_rms_norm_forward)
|
llama_rms_norm_forward)
|
||||||
else:
|
else:
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
if version.parse(trans_version) >= version.parse("4.36.0"):
|
if version.parse(trans_version) >= version.parse("4.36.0"):
|
||||||
modeling_module_name = model.__class__.__module__
|
|
||||||
module = importlib.import_module(modeling_module_name)
|
|
||||||
from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
|
|
||||||
from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
|
from ipex_llm.transformers.models.mistral import mistral_model_forward_4_36
|
||||||
convert_forward(model,
|
if version.parse(trans_version) >= version.parse("4.39.0"):
|
||||||
module.MistralAttention,
|
from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_39
|
||||||
mistral_attention_forward_4_36
|
convert_forward(model,
|
||||||
)
|
module.MistralAttention,
|
||||||
|
mistral_attention_forward_4_39
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from ipex_llm.transformers.models.mistral import mistral_attention_forward_4_36
|
||||||
|
|
||||||
|
convert_forward(model,
|
||||||
|
module.MistralAttention,
|
||||||
|
mistral_attention_forward_4_36
|
||||||
|
)
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.MistralModel,
|
module.MistralModel,
|
||||||
mistral_model_forward_4_36
|
mistral_model_forward_4_36
|
||||||
|
|
@ -1420,8 +1428,6 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
module.MistralMLP,
|
module.MistralMLP,
|
||||||
llama_mlp_forward)
|
llama_mlp_forward)
|
||||||
else:
|
else:
|
||||||
modeling_module_name = model.__class__.__module__
|
|
||||||
module = importlib.import_module(modeling_module_name)
|
|
||||||
from ipex_llm.transformers.models.mistral import mistral_attention_forward
|
from ipex_llm.transformers.models.mistral import mistral_attention_forward
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.MistralAttention,
|
module.MistralAttention,
|
||||||
|
|
|
||||||
|
|
@ -1074,3 +1074,232 @@ def mistral_attention_forward_4_36_original(
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
return attn_output.to(original_dtype), attn_weights, past_key_value
|
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def mistral_attention_forward_4_39(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor]=None,
|
||||||
|
position_ids: Optional[torch.LongTensor]=None,
|
||||||
|
past_key_value: Optional[Cache]=None,
|
||||||
|
output_attentions: bool=False,
|
||||||
|
use_cache: bool=False,
|
||||||
|
**kwargs
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||||
|
if use_quantize_kv_cache(self.q_proj, hidden_states):
|
||||||
|
forward_function = mistral_attention_forward_4_36_quantized
|
||||||
|
else:
|
||||||
|
forward_function = mistral_attention_forward_4_39_original
|
||||||
|
return forward_function(
|
||||||
|
self=self,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
kwargs=kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mistral_attention_forward_4_39_original(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor]=None,
|
||||||
|
position_ids: Optional[torch.LongTensor]=None,
|
||||||
|
past_key_value: Optional[Cache]=None,
|
||||||
|
output_attentions: bool=False,
|
||||||
|
use_cache: bool=False,
|
||||||
|
**kwargs
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
|
||||||
|
bsz, q_len, hidden_size = hidden_states.size()
|
||||||
|
device = hidden_states.device
|
||||||
|
# for flash attention
|
||||||
|
original_dtype = hidden_states.dtype
|
||||||
|
|
||||||
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
|
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
||||||
|
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||||
|
use_fuse_rope,
|
||||||
|
enough_kv_room,
|
||||||
|
bsz * q_len)
|
||||||
|
decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
|
||||||
|
|
||||||
|
if decoding_fast_path:
|
||||||
|
hidden_states = hidden_states.view(1, -1)
|
||||||
|
|
||||||
|
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||||
|
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||||
|
|
||||||
|
kv_seq_len = cache_k.shape[-2]
|
||||||
|
|
||||||
|
import xe_linear
|
||||||
|
query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
|
||||||
|
self.q_proj.weight,
|
||||||
|
self.k_proj.weight,
|
||||||
|
self.v_proj.weight,
|
||||||
|
position_ids,
|
||||||
|
cache_k, cache_v,
|
||||||
|
self.q_proj.weight.qtype,
|
||||||
|
self.v_proj.weight.qtype,
|
||||||
|
kv_seq_len,
|
||||||
|
self.head_dim)
|
||||||
|
kv_seq_len += 1
|
||||||
|
|
||||||
|
# update past_key_value's seem_tokens and kv caches.
|
||||||
|
if self.layer_idx == 0:
|
||||||
|
past_key_value._seen_tokens = kv_seq_len
|
||||||
|
past_key_value.key_cache[self.layer_idx] = key_states
|
||||||
|
past_key_value.value_cache[self.layer_idx] = value_states
|
||||||
|
|
||||||
|
else:
|
||||||
|
if should_use_xetla_mm_qkv(self, device):
|
||||||
|
if not hasattr(self, "qkv_proj_qweight"):
|
||||||
|
self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
|
||||||
|
self.k_proj,
|
||||||
|
self.v_proj,
|
||||||
|
self.q_proj.qtype)
|
||||||
|
import xe_linear
|
||||||
|
q_out_len = self.q_proj.out_len
|
||||||
|
k_out_len = self.k_proj.out_len
|
||||||
|
v_out_len = self.v_proj.out_len
|
||||||
|
qkv_states = xe_linear.mm_xetla(hidden_states,
|
||||||
|
self.qkv_proj_qweight,
|
||||||
|
self.q_proj.qtype)
|
||||||
|
query_states = qkv_states[:, :, :q_out_len]
|
||||||
|
key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
|
||||||
|
value_states = qkv_states[:, :, q_out_len + k_out_len:]
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len,
|
||||||
|
self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
if self.layer_idx is None:
|
||||||
|
invalidInputError(False,
|
||||||
|
"The cache structure has changed since version v4.36. "
|
||||||
|
f"If you are using {self.__class__.__name__} for "
|
||||||
|
"auto-regressive decodingwith k/v caching, please make sure "
|
||||||
|
"to initialize the attention class with a layer index.")
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
|
||||||
|
if use_fuse_rope:
|
||||||
|
query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
|
||||||
|
key_states,
|
||||||
|
position_ids,
|
||||||
|
"mistral")
|
||||||
|
else:
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids, "mistral")
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# update the number of seen tokens
|
||||||
|
if self.layer_idx == 0:
|
||||||
|
past_key_value._seen_tokens += key_states.shape[-2]
|
||||||
|
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
# update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
|
||||||
|
if len(past_key_value.key_cache) <= self.layer_idx:
|
||||||
|
past_key_value.key_cache.append(key_states)
|
||||||
|
past_key_value.value_cache.append(value_states)
|
||||||
|
else:
|
||||||
|
cache_k = past_key_value.key_cache[self.layer_idx]
|
||||||
|
cache_v = past_key_value.value_cache[self.layer_idx]
|
||||||
|
|
||||||
|
if not enough_kv_room:
|
||||||
|
# allocate new
|
||||||
|
new_c_k, new_c_v = extend_kv_cache(bsz,
|
||||||
|
self.num_key_value_heads, # Support GQA
|
||||||
|
self.head_dim,
|
||||||
|
cache_k.size(2),
|
||||||
|
kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
|
||||||
|
dtype=cache_k.dtype,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
new_c_k[:] = cache_k
|
||||||
|
new_c_v[:] = cache_v
|
||||||
|
cache_k = new_c_k
|
||||||
|
cache_v = new_c_v
|
||||||
|
|
||||||
|
key_states, value_states = append_kv_cache(cache_k, cache_v,
|
||||||
|
key_states, value_states)
|
||||||
|
|
||||||
|
# update past_key_value
|
||||||
|
past_key_value.key_cache[self.layer_idx] = key_states
|
||||||
|
past_key_value.value_cache[self.layer_idx] = value_states
|
||||||
|
|
||||||
|
if not self.training and not hidden_states.requires_grad:
|
||||||
|
fsdp_flag = use_flash_attention(query_states, key_states)
|
||||||
|
else:
|
||||||
|
fsdp_flag = False
|
||||||
|
if fsdp_flag:
|
||||||
|
attention_dtype = torch.float16 # use fp16 for flash attention
|
||||||
|
else:
|
||||||
|
attention_dtype = original_dtype
|
||||||
|
|
||||||
|
if fsdp_flag:
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
is_causal=True)
|
||||||
|
attn_weights = None
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
|
# new fp16 sdp doesn't require repeat_kv
|
||||||
|
import xe_addons
|
||||||
|
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
|
||||||
|
attn_output = attn_output.view(query_states.shape)
|
||||||
|
attn_weights = None
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
else:
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||||
|
dtype=attention_dtype)
|
||||||
|
if should_split_qkv_tensor(query_states, bsz, self.num_heads,
|
||||||
|
q_len, kv_seq_len, output_attentions):
|
||||||
|
attn_output, attn_weights = compute_attn_outputs_weights_split_tensor(query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
bsz,
|
||||||
|
q_len,
|
||||||
|
kv_seq_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
attention_mask)
|
||||||
|
else:
|
||||||
|
attn_output, attn_weights = compute_attn_outputs_weights(query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
bsz,
|
||||||
|
q_len,
|
||||||
|
kv_seq_len,
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.hidden_size,
|
||||||
|
attention_mask)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output.to(original_dtype), attn_weights, past_key_value
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue