Add lm_head optimization on NPU (#11903)
parent 23631cd357
commit 303a090a6b
4 changed files with 167 additions and 0 deletions
@@ -30,3 +30,13 @@ def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
     new_linear.in_features = new_weight.size(1)
     new_linear.out_features = new_weight.size(0)
     return new_linear
+
+
+def reshape_lm_head_input(x):
+    if x.dim() > 3:
+        x = x.reshape([-1, x.shape[-2], x.shape[-1]])
+    shape = list(x.size())
+    if shape[1] > 10:
+        shape[1] = 1
+        x = x[:, -1, :].view(shape)
+    return x
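Why the new helper pays off on the NPU: during prefill the decoder returns hidden states for every prompt token, but only the last position is needed to predict the next token, so trimming the lm_head input shrinks that matmul from [seq_len, hidden] x [hidden, vocab] to [1, hidden] x [hidden, vocab]. A standalone sketch of the equivalence (plain PyTorch; the sizes are invented for illustration and are not taken from this commit):

import torch

hidden_size, vocab_size, seq_len = 4096, 32000, 512      # illustrative sizes only
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(1, seq_len, hidden_size)     # prefill output: [batch, seq_len, hidden]
last_token = hidden_states[:, -1:, :]                    # what reshape_lm_head_input keeps

full_logits = lm_head(hidden_states)                     # [1, 512, 32000], mostly unused for generation
next_logits = lm_head(last_token)                        # [1, 1, 32000], enough to pick the next token
assert torch.allclose(full_logits[:, -1:, :], next_logits, atol=1e-5)

The shape[1] > 10 check leaves short decode-step inputs untouched and only trims long prefill sequences.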
@@ -54,6 +54,9 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, LlamaModel, llama_model_forward)
+        from transformers.models.llama.modeling_llama import LlamaForCausalLM
+        from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
+        convert_forward(model, LlamaForCausalLM, llama2_casullm_forward)
     elif model.config.model_type == "qwen2" and model.config.intermediate_size == 8960:
         # for qwen2-1.5B
         from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
@@ -77,3 +80,6 @@ def optimize_llm(
             prefill_runner=prefill_runner, decode_runner=decode_runner
         )
         convert_forward(model, Qwen2Model, qwen2_model_forward)
+        from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
+        from ipex_llm.transformers.npu_models.qwen2_mp import qwen2_casullm_forward
+        convert_forward(model, Qwen2ForCausalLM, qwen2_casullm_forward)
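These two hunks register the new causal-LM forwards with convert_forward, next to the existing model-level replacements. convert_forward itself is defined elsewhere in ipex_llm and is not part of this diff; as a rough mental model only (an assumed sketch, not the library's implementation), it behaves like a helper that walks the module tree and rebinds forward on every instance of the target class:

import types
import torch

def convert_forward_sketch(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    # Rebind `forward` on every matching submodule, including the root model itself.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)

With LlamaForCausalLM and Qwen2ForCausalLM patched this way, calls into the top-level model dispatch to llama2_casullm_forward / qwen2_casullm_forward, which is where the lm_head input gets trimmed.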
@@ -39,6 +39,9 @@ from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss
 
 
 class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
@@ -944,3 +947,79 @@ def gen_llama_fused_model_forward(prefill_runner, decode_runner):
         )
 
     return llama_fused_model_forward
+
+
+def llama2_casullm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+    # ipex-llm change start
+    hidden_states = reshape_lm_head_input(hidden_states)
+    # ipex-llm change end
+    if self.config.pretraining_tp > 1:
+        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp,
+                                                   dim=0)
+        logits = [F.linear(hidden_states, lm_head_slices[i])
+                  for i in range(self.config.pretraining_tp)]
+        logits = torch.cat(logits, dim=-1)
+    else:
+        logits = self.lm_head(hidden_states)
+    logits = logits.float()
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
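The pretraining_tp branch above splits the lm_head weight along the vocab dimension and concatenates the per-slice logits. A small self-contained check (plain PyTorch, with toy sizes chosen here) that the sliced path matches a single lm_head call:

import torch
import torch.nn.functional as F

hidden, vocab, tp = 64, 128, 4                       # toy sizes, divisible so the split is even
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
x = torch.randn(1, 1, hidden)                        # decode-shaped input after reshape_lm_head_input

slices = lm_head.weight.split(vocab // tp, dim=0)    # tp slices of shape [vocab // tp, hidden]
sliced = torch.cat([F.linear(x, w) for w in slices], dim=-1)
assert torch.allclose(sliced, lm_head(x), atol=1e-6)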
@@ -39,6 +39,9 @@ from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss
 
 
 class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
@@ -981,3 +984,72 @@ def gen_qwen2_fused_model_forward(prefill_runner, decode_runner):
         )
 
     return qwen2_fused_model_forward
+
+
+def qwen2_casullm_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        # cache_position=cache_position,
+    )
+
+    hidden_states = outputs[0]
+    # ipex-llm change start
+    hidden_states = reshape_lm_head_input(hidden_states)
+    # ipex-llm change end
+    logits = self.lm_head(hidden_states)
+    logits = logits.float()
+
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
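For reference, the label handling in both new forwards is the standard causal-LM shift: logits at position i are scored against the token at position i+1. A tiny standalone illustration with made-up shapes:

import torch
from torch.nn import CrossEntropyLoss

vocab = 10
logits = torch.randn(2, 5, vocab)                    # [batch, seq_len, vocab]
labels = torch.randint(0, vocab, (2, 5))             # [batch, seq_len]

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)   # drop the last position
shift_labels = labels[..., 1:].contiguous().view(-1)              # drop the first token
loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())

Generation never takes this branch (labels is None there), which is why trimming the lm_head input to the last token is safe for the NPU inference path this commit targets.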