LLM: update rms related usage to suport ipex 2.1 new api (#9466)

* update rms related usage

* fix style
This commit is contained in:
Ruonan Wang 2023-11-16 11:21:50 +08:00 committed by GitHub
parent 731b0aaade
commit d2c064124a
3 changed files with 28 additions and 8 deletions

View file

@ -30,6 +30,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
from transformers.utils import logging, ContextManagers
from bigdl.llm.transformers.models.llama import get_ipex_version
logger = logging.get_logger(__name__)
try:
@ -47,8 +48,16 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
def baichuan_13b_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
if get_ipex_version() <= "2.0.110+xpu":
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)], self.weight)
[self.weight.size(0)],
self.weight)
else:
hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None,
self.epsilon)
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)

View file

@ -22,6 +22,7 @@ from typing import Optional, Tuple, List
import torch.nn.functional as F
from transformers.modeling_outputs import BaseModelOutputWithPast
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
from bigdl.llm.transformers.models.llama import get_ipex_version
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@ -77,8 +78,16 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
def chatglm_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
if get_ipex_version() <= "2.0.110+xpu":
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)], self.weight)
else:
# for ipex >= 2.1
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None, # bias
self.eps)
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)

View file

@ -75,12 +75,14 @@ def get_ipex_version():
def llama_rms_norm_forward(self, hidden_states):
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
if get_ipex_version() == "2.0.110+xpu":
if get_ipex_version() <= "2.0.110+xpu":
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)], self.weight)
else:
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
[self.weight.size(0)], self.weight,
hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
[self.weight.size(0)],
self.weight,
None,
self.variance_epsilon)
else:
input_dtype = hidden_states.dtype