From 7d29765092b9676709605d09ef1f50407d4232ab Mon Sep 17 00:00:00 2001
From: Heyang Sun <60865256+Uxito-Ada@users.noreply.github.com>
Date: Thu, 14 Mar 2024 11:03:05 +0800
Subject: [PATCH] refactor qwen2 forward to enable XPU (#10409)

* refactor qwen2 forward to enable XPU

* Update qwen2.py
---
 python/llm/src/bigdl/llm/transformers/convert.py       | 14 ++++----------
 .../llm/src/bigdl/llm/transformers/models/qwen2.py     |  2 ++
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index ad672926..f42c7d49 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -1075,6 +1075,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
+        from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
@@ -1084,16 +1085,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.Qwen2MLP,
                         llama_mlp_forward)
-        if model.device.type == 'cpu':
-            from bigdl.llm.transformers.models.qwen2 import qwen2_sdpa_attention_forward
-            convert_forward(model,
-                            module.Qwen2SdpaAttention,
-                            qwen2_sdpa_attention_forward)
-        else:
-            from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
-            convert_forward(model,
-                            module.Qwen2Attention,
-                            qwen2_attention_forward)
+        convert_forward(model,
+                        module.Qwen2Attention,
+                        qwen2_attention_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
index ac96f815..e3482119 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
@@ -106,6 +106,8 @@ def qwen2_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
         forward_function = qwen2_attention_forward_quantized
+    elif hidden_states.device.type == "cpu":
+        forward_function = qwen2_sdpa_attention_forward
     else:
         forward_function = qwen2_attention_forward_origin
     return forward_function(
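
The core of this change is that the CPU/XPU branch is no longer picked once at conversion time in convert.py but resolved inside qwen2_attention_forward from the device of the incoming activations. Below is a minimal, self-contained sketch of that dispatch pattern; the handler functions and use_quantize_kv_cache() here are hypothetical stand-ins for the real qwen2_attention_forward_quantized, qwen2_sdpa_attention_forward and qwen2_attention_forward_origin, not the BigDL source.

# Sketch only: illustrates per-call dispatch on KV-cache quantization and device type.
import torch


def forward_quantized(hidden_states: torch.Tensor) -> str:
    return "quantized KV-cache path"


def forward_sdpa(hidden_states: torch.Tensor) -> str:
    return "SDPA path (CPU)"


def forward_origin(hidden_states: torch.Tensor) -> str:
    return "origin path (XPU and other devices)"


def use_quantize_kv_cache(hidden_states: torch.Tensor) -> bool:
    # Stand-in for the real heuristic in bigdl.llm.transformers.models.qwen2.
    return False


def attention_forward(hidden_states: torch.Tensor) -> str:
    # Dispatch is resolved on every call from the activations themselves, so the
    # same converted module works whether the model ends up on CPU or XPU.
    if use_quantize_kv_cache(hidden_states):
        forward_function = forward_quantized
    elif hidden_states.device.type == "cpu":
        forward_function = forward_sdpa
    else:
        forward_function = forward_origin
    return forward_function(hidden_states)


if __name__ == "__main__":
    x = torch.zeros(1, 4, 8)        # CPU tensor -> takes the SDPA path
    print(attention_forward(x))     # prints "SDPA path (CPU)"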