refactor qwen2 forward to enable XPU (#10409)
* refactor qwen2 forward to enable XPU
* Update qwen2.py
parent f36224aac4
commit 7d29765092

2 changed files with 6 additions and 10 deletions
@@ -1075,6 +1075,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen2 import qwen2_model_forward
+        from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
         convert_forward(model,
                         module.Qwen2Model,
                         qwen2_model_forward)
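
For context, convert_forward is the patching helper used throughout _optimize_post: it walks the module tree and rebinds forward on every submodule of the target class. A minimal sketch of that pattern (an approximation for illustration, not necessarily the exact BigDL implementation):

    import torch.nn as nn

    def convert_forward(m: nn.Module, target_cls: type, new_forward) -> None:
        # Walk every submodule and rebind `forward` on instances of target_cls.
        for sub_m in m.modules():
            if isinstance(sub_m, target_cls):
                # __get__ turns the plain function into a method bound to sub_m.
                sub_m.forward = new_forward.__get__(sub_m, sub_m.__class__)
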
@@ -1084,13 +1085,6 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.Qwen2MLP,
                         llama_mlp_forward)
-        if model.device.type == 'cpu':
-            from bigdl.llm.transformers.models.qwen2 import qwen2_sdpa_attention_forward
-            convert_forward(model,
-                            module.Qwen2SdpaAttention,
-                            qwen2_sdpa_attention_forward)
-        else:
-            from bigdl.llm.transformers.models.qwen2 import qwen2_attention_forward
         convert_forward(model,
                         module.Qwen2Attention,
                         qwen2_attention_forward)
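
The point of this hunk: the CPU-vs-XPU choice is no longer baked in at conversion time, where model.device.type only reflects wherever the weights happened to be loaded; the single patched qwen2_attention_forward now decides per call (see the qwen2.py hunk below). A hedged usage sketch, assuming the bigdl-llm AutoModelForCausalLM loader (the checkpoint name is illustrative):

    from bigdl.llm.transformers import AutoModelForCausalLM

    # Loaded on CPU; _optimize_post patches Qwen2Attention.forward here.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat",
                                                 load_in_4bit=True)

    # Moving to XPU afterwards now works: the patched forward inspects
    # hidden_states.device at call time and takes the non-CPU path.
    model = model.to("xpu")
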
@@ -106,6 +106,8 @@ def qwen2_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if use_quantize_kv_cache(self.q_proj, hidden_states):
         forward_function = qwen2_attention_forward_quantized
+    elif hidden_states.device.type == "cpu":
+        forward_function = qwen2_sdpa_attention_forward
     else:
         forward_function = qwen2_attention_forward_origin
     return forward_function(
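
The dispatch idiom above, picking an implementation from the input tensor's device on every call, is small enough to show in isolation. A self-contained toy version (hypothetical names, not the BigDL code):

    import torch
    import torch.nn as nn

    def _cpu_forward(self, x):
        return self.linear(x)   # stand-in for an SDPA-based CPU path

    def _xpu_forward(self, x):
        return self.linear(x)   # stand-in for an XPU-optimized path

    def dispatching_forward(self, x):
        # Choose from where the *input* lives, so the same patched module
        # keeps working after model.to(...) moves it between devices.
        fn = _cpu_forward if x.device.type == "cpu" else _xpu_forward
        return fn(self, x)

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

    m = Toy()
    m.forward = dispatching_forward.__get__(m, Toy)  # same rebinding as convert_forward
    print(m(torch.randn(2, 4)).shape)  # torch.Size([2, 4])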