Add support for llama2 quantize_kv with transformers 4.38.0 (#11054)
* add support for llama2 quantize_kv with transformers 4.38.0 * fix code style * fix code style
This commit is contained in:
parent
16b2a418be
commit
192ae35012
2 changed files with 175 additions and 5 deletions
|
|
@ -964,15 +964,18 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
if version.parse(trans_version) >= version.parse("4.36.0"):
|
if version.parse(trans_version) >= version.parse("4.36.0"):
|
||||||
# transformers version >= 4.36.0
|
# transformers version >= 4.36.0
|
||||||
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
|
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38
|
||||||
from ipex_llm.transformers.models.llama import llama_model_forward_4_36
|
|
||||||
if version.parse(trans_version) >= version.parse("4.38.0"):
|
if version.parse(trans_version) >= version.parse("4.38.0"):
|
||||||
from ipex_llm.transformers.models.llama import llama_attention_forward_4_38_original
|
from ipex_llm.transformers.models.llama import llama_model_forward_4_38
|
||||||
# Todo: support llama_model_forward with transformers version >= 4.38.0
|
convert_forward(
|
||||||
|
model,
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel,
|
||||||
|
llama_model_forward_4_38)
|
||||||
convert_forward(
|
convert_forward(
|
||||||
model,
|
model,
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention,
|
transformers.models.llama.modeling_llama.LlamaAttention,
|
||||||
llama_attention_forward_4_38_original)
|
llama_attention_forward_4_38)
|
||||||
else:
|
else:
|
||||||
|
from ipex_llm.transformers.models.llama import llama_model_forward_4_36
|
||||||
convert_forward(
|
convert_forward(
|
||||||
model,
|
model,
|
||||||
transformers.models.llama.modeling_llama.LlamaModel,
|
transformers.models.llama.modeling_llama.LlamaModel,
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,40 @@ def llama_model_forward_4_36(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def llama_model_forward_4_38(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
from ipex_llm.transformers.kv import DynamicFp8Cache
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
input = input_ids if input_ids is not None else inputs_embeds
|
||||||
|
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input):
|
||||||
|
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||||
|
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||||
|
return llama_model_forward_4_38_internal(
|
||||||
|
self=self,
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def llama_rms_norm_forward(self, hidden_states):
|
def llama_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
|
|
@ -1143,8 +1177,12 @@ def llama_attention_forward_4_38_quantized(
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
else:
|
else:
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
|
if cache_position is not None:
|
||||||
|
new_attn_mask = attention_mask[:, :, kv_seq_len-q_len:kv_seq_len, 0:kv_seq_len]
|
||||||
|
else:
|
||||||
|
new_attn_mask = attention_mask
|
||||||
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
|
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
|
||||||
attention_mask)
|
new_attn_mask)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
|
@ -1802,6 +1840,135 @@ def llama_attention_fast_forward(
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def llama_model_forward_4_38_internal(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else \
|
||||||
|
self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else
|
||||||
|
self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
|
invalidInputError(False,
|
||||||
|
f"You cannot specify both input_ids and inputs_embeds at the same time,"
|
||||||
|
f" and must specify either one")
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training and use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. "
|
||||||
|
"Setting `use_cache=False`."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
past_seen_tokens = 0
|
||||||
|
if use_cache: # kept for BC (cache positions)
|
||||||
|
if not isinstance(past_key_values, Cache):
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length()
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1],
|
||||||
|
device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
|
||||||
|
|
||||||
|
# embed positions
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
causal_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
use_cache,
|
||||||
|
cache_position,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# bigdl-llm changes:
|
||||||
|
curr_device = decoder_layer.input_layernorm.weight.device
|
||||||
|
if causal_mask is not None:
|
||||||
|
causal_mask = causal_mask.to(curr_device)
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.to(curr_device)
|
||||||
|
# bigdl-llm changes end
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=causal_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
from ipex_llm.transformers.kv import DynamicFp8Cache
|
||||||
|
if use_cache:
|
||||||
|
next_cache = (
|
||||||
|
next_decoder_cache.to_legacy_cache()
|
||||||
|
if not isinstance(next_decoder_cache, DynamicFp8Cache)
|
||||||
|
else next_decoder_cache
|
||||||
|
)
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states,
|
||||||
|
all_self_attns] if v is not None)
|
||||||
|
return BaseModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def llama_model_forward_4_36_internal(
|
def llama_model_forward_4_36_internal(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor = None,
|
input_ids: torch.LongTensor = None,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue