LLM: update rms related usage to suport ipex 2.1 new api (#9466)
* update rms related usage * fix style
This commit is contained in:
		
							parent
							
								
									731b0aaade
								
							
						
					
					
						commit
						d2c064124a
					
				
					 3 changed files with 28 additions and 8 deletions
				
			
		| 
						 | 
				
			
			@ -30,6 +30,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 | 
			
		|||
from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 | 
			
		||||
from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 | 
			
		||||
from transformers.utils import logging, ContextManagers
 | 
			
		||||
from bigdl.llm.transformers.models.llama import get_ipex_version
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
| 
						 | 
				
			
			@ -47,8 +48,16 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		|||
 | 
			
		||||
def baichuan_13b_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                         [self.weight.size(0)], self.weight)
 | 
			
		||||
        if get_ipex_version() <= "2.0.110+xpu":
 | 
			
		||||
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                             [self.weight.size(0)],
 | 
			
		||||
                                                             self.weight)
 | 
			
		||||
        else:
 | 
			
		||||
            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
 | 
			
		||||
                                                                  [self.weight.size(0)],
 | 
			
		||||
                                                                  self.weight,
 | 
			
		||||
                                                                  None,
 | 
			
		||||
                                                                  self.epsilon)
 | 
			
		||||
    else:
 | 
			
		||||
        input_dtype = hidden_states.dtype
 | 
			
		||||
        hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,6 +22,7 @@ from typing import Optional, Tuple, List
 | 
			
		|||
import torch.nn.functional as F
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from bigdl.llm.transformers.models.llama import get_ipex_version
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
			
		||||
| 
						 | 
				
			
			@ -77,8 +78,16 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
 | 
			
		|||
 | 
			
		||||
def chatglm_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                         [self.weight.size(0)], self.weight)
 | 
			
		||||
        if get_ipex_version() <= "2.0.110+xpu":
 | 
			
		||||
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                             [self.weight.size(0)], self.weight)
 | 
			
		||||
        else:
 | 
			
		||||
            # for ipex >= 2.1
 | 
			
		||||
            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
 | 
			
		||||
                                                               [self.weight.size(0)],
 | 
			
		||||
                                                               self.weight,
 | 
			
		||||
                                                               None,  # bias
 | 
			
		||||
                                                               self.eps)
 | 
			
		||||
    else:
 | 
			
		||||
        input_dtype = hidden_states.dtype
 | 
			
		||||
        hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -75,13 +75,15 @@ def get_ipex_version():
 | 
			
		|||
 | 
			
		||||
def llama_rms_norm_forward(self, hidden_states):
 | 
			
		||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
			
		||||
        if get_ipex_version() == "2.0.110+xpu":
 | 
			
		||||
        if get_ipex_version() <= "2.0.110+xpu":
 | 
			
		||||
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                             [self.weight.size(0)], self.weight)
 | 
			
		||||
        else:
 | 
			
		||||
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
			
		||||
                                                             [self.weight.size(0)], self.weight,
 | 
			
		||||
                                                             self.variance_epsilon)
 | 
			
		||||
            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
 | 
			
		||||
                                                                  [self.weight.size(0)],
 | 
			
		||||
                                                                  self.weight,
 | 
			
		||||
                                                                  None,
 | 
			
		||||
                                                                  self.variance_epsilon)
 | 
			
		||||
    else:
 | 
			
		||||
        input_dtype = hidden_states.dtype
 | 
			
		||||
        hidden_states = hidden_states.to(torch.float32)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue