Commit e1bc18f8eb (parent 9f3d4676c6) — 1 changed file with 10 additions and 8 deletions.
Diff hunk: @@ -59,21 +59,23 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
||||||
# Allocation granularity for the KV cache: presumably the number of token
# slots reserved per growth step of the cache buffer — TODO confirm against
# the kv-cache allocation call sites.
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
# Cached intel_extension_for_pytorch version string; populated lazily by
# get_ipex_version() so the heavy extension module is imported at most once.
_ipex_version = None


def get_ipex_version():
    """Return the installed intel_extension_for_pytorch version string.

    The version is looked up once and memoized in the module-level
    ``_ipex_version`` global; subsequent calls return the cached value
    without touching the import machinery again.

    Returns:
        str | None: ``ipex.__version__`` when intel_extension_for_pytorch
        is installed, otherwise ``None``.
    """
    global _ipex_version
    if _ipex_version is not None:
        return _ipex_version

    # Probe for the package before importing it so that machines without
    # IPEX get None back instead of an ImportError (preserves the guarded
    # behavior of the pre-caching implementation).
    import importlib.util
    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
        return None

    import intel_extension_for_pytorch as ipex
    _ipex_version = ipex.__version__
    return _ipex_version
def llama_rms_norm_forward(self, hidden_states):
|
def llama_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
if ipex_version == "2.0.110+xpu":
|
if get_ipex_version() == "2.0.110+xpu":
|
||||||
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
|
hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
|
||||||
[self.weight.size(0)], self.weight)
|
[self.weight.size(0)], self.weight)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue