add basic support for llama3.2 (#12125)

parent 66f419f8b7
commit 584c3489e7
3 changed files with 271 additions and 2 deletions
@@ -1267,7 +1267,19 @@ def _optimize_post(model, lightweight_bmm=False):
     from ipex_llm.transformers.models.llama import llama_rms_norm_forward
     from ipex_llm.transformers.models.llama import llama_mlp_forward

-    if model.config.model_type == "llama":
+    if model.config.model_type == "llama" and model.config.rope_scaling is not None:
+        # llama 3.2
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        from ipex_llm.transformers.models.common import rms_norm_forward
+        from ipex_llm.transformers.models.common import mlp_silu_forward
+        from ipex_llm.transformers.models.llama32 import llama_model_forward
+        from ipex_llm.transformers.models.llama32 import llama_attention_forward
+        convert_forward(model, module.LlamaRMSNorm, rms_norm_forward)
+        convert_forward(model, module.LlamaMLP, mlp_silu_forward)
+        convert_forward(model, module.LlamaModel, llama_model_forward)
+        convert_forward(model, module.LlamaAttention, llama_attention_forward)
+    elif model.config.model_type == "llama":
         from transformers.models.llama.modeling_llama import LlamaRMSNorm
         from transformers.models.llama.modeling_llama import LlamaMLP
         from transformers.models.llama.modeling_llama import LlamaAttention
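For orientation, a minimal usage sketch of how this new branch is reached; the model id, generation settings, and prompt are illustrative assumptions and are not part of this change. Loading a Llama 3.2 checkpoint through ipex_llm calls _optimize_post, and because Llama 3.2 configs set rope_scaling, the model now takes the llama32 path above instead of the older llama path.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-3.2-1B-Instruct"  # assumption: any Llama 3.2 checkpoint

# load_in_4bit / optimize_model trigger _optimize_post(), which swaps in
# rms_norm_forward, mlp_silu_forward, llama_model_forward and llama_attention_forward
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = model.half().to("xpu")

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))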
@@ -24,12 +24,16 @@ from .models.utils import (
     init_fp8_kv_cache, append_fp8_kv_cache,
     init_kv_cache, append_kv_cache, extend_kv_cache
 )
-from typing import Optional, Dict, Tuple, Any
+from typing import Optional, Dict, Tuple, Any, List
 from transformers.cache_utils import DynamicCache
 from ipex_llm.utils.common.log4Error import invalidInputError


 class DynamicFp8Cache(DynamicCache):
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+        # ignore num_hidden_layers to fix transformers >= 4.45
+        super().__init__()
+
     def update(
         self,
         key_states: torch.Tensor,
@@ -37,6 +41,9 @@ class DynamicFp8Cache(DynamicCache):
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # fix converting empty DynamicCache in transformers >= 4.45
+        if key_states == []:
+            return key_states, value_states

         batch_size, num_heads, seq_len, head_dim = key_states.shape

@@ -71,6 +78,10 @@ class DynamicFp8Cache(DynamicCache):
 class DynamicNormalCache(DynamicCache):
     KV_ALLOC_BLOCK_LENGTH = 256

+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+        # ignore num_hidden_layers to fix transformers >= 4.45
+        super().__init__()
+
     def update(
         self,
         key_states: torch.Tensor,
@@ -78,6 +89,9 @@ class DynamicNormalCache(DynamicCache):
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # fix converting empty DynamicCache in transformers >= 4.45
+        if key_states == []:
+            return key_states, value_states

         batch_size, num_heads, seq_len, head_dim = key_states.shape

@@ -257,6 +271,9 @@ class DynamicCompressCache(DynamicCache):
         KV_CACHE_ALLOC_BLOCK_LENGTH: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # fix converting empty DynamicCache in transformers >= 4.45
+        if key_states == []:
+            return key_states, value_states

         bsz, num_heads, seq_len, head_dim = key_states.shape

@@ -354,6 +371,10 @@ class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
         KV_CACHE_ALLOC_BLOCK_LENGTH: int,
         cache_kwargs: Optional[Dict[str, Any]]=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # fix converting empty DynamicCache in transformers >= 4.45
+        if key_states == []:
+            return key_states, value_states
+
         bsz, num_heads, seq_len, head_dim = key_states.shape

         if layer_idx == 0:
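A small illustration of what the new guards tolerate, not part of the diff; it assumes the transformers >= 4.45 behavior that motivates them (the cache constructor being handed a num_hidden_layers argument, and update() being called on a still-empty placeholder layer whose states are plain empty lists rather than tensors).

from ipex_llm.transformers.kv import DynamicNormalCache

# The overridden __init__ accepts and ignores the extra argument that
# transformers >= 4.45 may pass when it builds caches (assumption).
cache = DynamicNormalCache(num_hidden_layers=16)

# An empty placeholder layer passes straight through the `if key_states == []`
# guard instead of failing on `.shape`.
k, v = cache.update([], [], layer_idx=0)
assert k == [] and v == []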
python/llm/src/ipex_llm/transformers/models/llama32.py (new file, 236 lines)
@@ -0,0 +1,236 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# which is licensed under Apache License 2.0:
#
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch

from typing import Optional, Tuple, Union
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import repeat_kv
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache


def llama_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    # IPEX-LLM OPT start: kv cache and quantize kv cache
    inputs = input_ids if input_ids is not None else inputs_embeds
    use_cache = True if inputs.device.type == "xpu" else use_cache
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
    if use_cache:
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
    # IPEX-LLM OPT end

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    invalidInputError((input_ids is None) ^ (inputs_embeds is None),
                      "You cannot specify both input_ids and inputs_embeds at the same time, "
                      "and must specify either one")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )
    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(
        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    )
    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None

    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                     if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def llama_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]]=None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    qkv = self.qkv_proj(hidden_states)
    qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.split([self.num_heads,
                                                        self.num_key_value_heads,
                                                        self.num_key_value_heads], dim=1)

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, None)

    kv_seq_len = key_states.size(2)
    if attention_mask is not None:  # no matter the length, we just slice it
        causal_mask = attention_mask[:, :, :, :kv_seq_len]

    attn_weights = None
    if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
        import xe_addons
        if isinstance(past_key_value, DynamicFp8Cache):
            attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, causal_mask)
        else:
            attn_output = xe_addons.sdp(query_states, key_states, value_states, causal_mask)
    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
        import xe_addons
        if isinstance(past_key_value, DynamicFp8Cache):
            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
                                                   value_states, causal_mask)
        else:
            attn_output = xe_addons.sdp_causal(query_states, key_states,
                                               value_states, causal_mask)
    else:
        if isinstance(past_key_value, DynamicFp8Cache):
            key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                            query_states.dtype)
        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states,
                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if causal_mask is not None:
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = attention_softmax(attn_weights, self.training)
        attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, -1)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
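As a standalone shape check of the fused-QKV split used in llama_attention_forward above: the head counts below assume the published Llama 3.2 1B geometry (32 query heads, 8 KV heads, head_dim 64), an illustrative assumption rather than anything taken from this diff.

import torch

bsz, q_len = 1, 5
num_heads, num_kv_heads, head_dim = 32, 8, 64  # assumed Llama 3.2 1B values

# stand-in for self.qkv_proj(hidden_states): one tensor holding Q, K and V
qkv = torch.randn(bsz, q_len, (num_heads + 2 * num_kv_heads) * head_dim)

qkv = qkv.view(bsz, q_len, num_heads + 2 * num_kv_heads, head_dim)
qkv = qkv.transpose(1, 2)
q, k, v = qkv.split([num_heads, num_kv_heads, num_kv_heads], dim=1)

print(q.shape)  # torch.Size([1, 32, 5, 64])
print(k.shape)  # torch.Size([1, 8, 5, 64])
print(v.shape)  # torch.Size([1, 8, 5, 64])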