refactor qwen2 forward to enable XPU (#10409)

* refactor qwen2 forward to enable XPU

* Update qwen2.py
Heyang Sun 2024-03-14 11:03:05 +08:00 committed by GitHub
parent f36224aac4
commit 7d29765092
2 changed files with 6 additions and 10 deletions

@@ -1075,6 +1075,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
+        from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
@@ -1084,13 +1085,6 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.Qwen2MLP,
                         llama_mlp_forward)
-        if model.device.type == 'cpu':
-            from bigdl.llm.transformers.models.qwen2 import qwen2_sdpa_attention_forward
-            convert_forward(model,
-                            module.Qwen2SdpaAttention,
-                            qwen2_sdpa_attention_forward)
-        else:
-            from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
-            convert_forward(model,
-                            module.Qwen2Attention,
-                            qwen2_attention_forward)
+        convert_forward(model,
+                        module.Qwen2Attention,
+                        qwen2_attention_forward)
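
For reference, convert_forward replaces the forward method of every module of the given class with the BigDL-LLM implementation, so the hunks above now register qwen2_attention_forward for Qwen2Attention unconditionally instead of branching on model.device at conversion time. A minimal sketch of that pattern is shown below; it is an illustration only, and the real BigDL-LLM helper may differ in detail.

    # Sketch (assumption): rebind `forward` on every instance of the target class.
    from types import MethodType
    import torch.nn as nn

    def convert_forward(model: nn.Module, target_class: type, new_forward) -> None:
        for module in model.modules():
            if isinstance(module, target_class):
                # Bind the replacement function as this module's forward method.
                module.forward = MethodType(new_forward, module)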

bigdl/llm/transformers/models/qwen2.py

@@ -106,6 +106,8 @@ def qwen2_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
         forward_function = qwen2_attention_forward_quantized
+    elif hidden_states.device.type == "cpu":
+        forward_function = qwen2_sdpa_attention_forward
     else:
         forward_function = qwen2_attention_forward_origin
     return forward_function(
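
With this refactor, the CPU vs. XPU decision moves from conversion time (model.device) to call time (hidden_states.device), so a single converted model can serve both back ends. A hypothetical usage sketch follows; the model id and loading options are illustrative assumptions, not part of this commit.

    # Assumed usage: optimize once, then move to an Intel XPU; the attention
    # forward now picks the right path per call from hidden_states.device.type.
    from bigdl.llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat",
                                                 load_in_4bit=True)
    model = model.to("xpu")  # requires intel_extension_for_pytorch with XPU support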