LLM: update RMS-related usage to support the new IPEX 2.1 API (#9466)
* update rms related usage
* fix style
Parent: 731b0aaade
Commit: d2c064124a

3 changed files with 28 additions and 8 deletions
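Across three model files (baichuan-13b, chatglm, and llama), the XPU fast path in the RMS-norm forward now dispatches on the installed IPEX version: the old torch.ops.torch_ipex.rms_norm call is kept for IPEX <= 2.0.110+xpu, while IPEX 2.1 and later use the new torch.ops.torch_ipex.fast_rms_norm op, which additionally takes a bias argument (passed as None here) and an explicit epsilon. The get_ipex_version() helper is imported from bigdl.llm.transformers.models.llama; its body is not part of this diff, but a minimal sketch of such a probe, assuming it simply returns the installed IPEX version string, would be:

import intel_extension_for_pytorch as ipex

def get_ipex_version():
    # IPEX version strings look like "2.0.110+xpu" or "2.1.10+xpu";
    # the real helper lives in bigdl.llm.transformers.models.llama.
    return ipex.__version__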
@@ -30,6 +30,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from transformers.utils import logging, ContextManagers
+from bigdl.llm.transformers.models.llama import get_ipex_version
 logger = logging.get_logger(__name__)
 
 try:
@@ -47,8 +48,16 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 def baichuan_13b_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                         [self.weight.size(0)], self.weight)
+        if get_ipex_version() <= "2.0.110+xpu":
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)],
+                                                             self.weight)
+        else:
+            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                                  [self.weight.size(0)],
+                                                                  self.weight,
+                                                                  None,
+                                                                  self.epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
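The non-XPU branch is truncated by the hunk context after the float32 up-cast. For orientation, a plain-PyTorch RMS norm of the kind such a fallback computes (a reference sketch, not the repo's exact lines):

import torch

def rms_norm_reference(hidden_states, weight, eps):
    # Up-cast for a stable reduction, normalize by the root mean square
    # over the last dimension, then re-scale and restore the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)

This is also the semantics the fused rms_norm/fast_rms_norm kernels above are expected to match.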
@@ -22,6 +22,7 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.llama import get_ipex_version
 
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -77,8 +78,16 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                         [self.weight.size(0)], self.weight)
+        if get_ipex_version() <= "2.0.110+xpu":
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)], self.weight)
+        else:
+            # for ipex >= 2.1
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,  # bias
+                                                               self.eps)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
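As in the baichuan file, chatglm gains the same version dispatch; note that it assigns the fast_rms_norm result directly rather than unpacking a tuple. These replacement forwards only take effect once bound onto the model's norm modules; that plumbing is outside this diff, but a hypothetical illustration (the module path is invented for the example):

import types

def patch_rms_norm(norm_module, new_forward):
    # Bind the replacement as this instance's forward so that self.weight
    # and self.eps resolve to the module's own parameters.
    norm_module.forward = types.MethodType(new_forward, norm_module)

# e.g. patch_rms_norm(model.transformer.encoder.layers[0].input_layernorm,
#                     chatglm_rms_norm_forward)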
@@ -75,13 +75,15 @@ def get_ipex_version():
 
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() == "2.0.110+xpu":
+        if get_ipex_version() <= "2.0.110+xpu":
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)], self.weight,
-                                                             self.variance_epsilon)
+            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                                  [self.weight.size(0)],
+                                                                  self.weight,
+                                                                  None,
+                                                                  self.variance_epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
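One caveat worth noting: get_ipex_version() <= "2.0.110+xpu" (relaxed here from ==) compares version strings lexicographically, which happens to order the releases involved correctly but is not numerically sound in general (e.g. "2.0.99+xpu" sorts after "2.0.110+xpu"). A parse-based check would be sturdier; a sketch using packaging, offered as an assumption rather than what the repo does:

from packaging.version import parse

def ipex_at_least_2_1(ipex_version: str) -> bool:
    # Drop the local "+xpu" segment and compare the release numerically.
    return parse(ipex_version.split("+")[0]) >= parse("2.1.0")

# ipex_at_least_2_1("2.0.110+xpu")  -> False
# ipex_at_least_2_1("2.1.10+xpu")   -> True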