Support deepspeed AutoTP (#9230)
* Support deepspeed * add test script * refactor convert * refine example * refine * refine example * fix style * refine example and adapte latest ipex * fix style
This commit is contained in:
		
							parent
							
								
									a6a8afc47e
								
							
						
					
					
						commit
						067c7e8098
					
				
					 5 changed files with 190 additions and 9 deletions
				
			
		
							
								
								
									
										103
									
								
								python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								python/llm/example/GPU/Deepspeed-AutoTP/deepspeed_autotp.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,103 @@
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Copyright 2016 The BigDL Authors.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import transformers
 | 
				
			||||||
 | 
					import deepspeed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					local_rank = int(os.getenv("LOCAL_RANK", "0"))
 | 
				
			||||||
 | 
					world_size = int(os.getenv("WORLD_SIZE", "1"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from bigdl.llm import optimize_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import intel_extension_for_pytorch as ipex
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from transformers import AutoModelForCausalLM  # export AutoModelForCausalLM from transformers so that deepspeed use it
 | 
				
			||||||
 | 
					from transformers import LlamaTokenizer, AutoTokenizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == '__main__':
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
 | 
				
			||||||
 | 
					    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
 | 
				
			||||||
 | 
					                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Llama-2-13b-chat-hf`) to be downloaded'
 | 
				
			||||||
 | 
					                             ', or the path to the huggingface checkpoint folder')
 | 
				
			||||||
 | 
					    parser.add_argument('--prompt', type=str, default="Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun",
 | 
				
			||||||
 | 
					                        help='Prompt to infer')
 | 
				
			||||||
 | 
					    parser.add_argument('--n-predict', type=int, default=32,
 | 
				
			||||||
 | 
					                        help='Max tokens to predict')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					    model_path = args.repo_id_or_model_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = AutoModelForCausalLM.from_pretrained(args.repo_id_or_model_path,
 | 
				
			||||||
 | 
					                                                 low_cpu_mem_usage=True,
 | 
				
			||||||
 | 
					                                                 torch_dtype=torch.float16,
 | 
				
			||||||
 | 
					                                                 trust_remote_code=True,
 | 
				
			||||||
 | 
					                                                 use_cache=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model = deepspeed.init_inference(
 | 
				
			||||||
 | 
					        model,
 | 
				
			||||||
 | 
					        mp_size=world_size,
 | 
				
			||||||
 | 
					        dtype=torch.float16,
 | 
				
			||||||
 | 
					        replace_method="auto",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # move model to cpu and use bigdl-llm `optimize_model` to convert the
 | 
				
			||||||
 | 
					    # model into optimized low bit format
 | 
				
			||||||
 | 
					    # convert the rest of the model into float16 to reduce allreduce traffic
 | 
				
			||||||
 | 
					    model = optimize_model(model.module.to(f'cpu'), low_bit='sym_int4').to(torch.float16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # move model back to xpu
 | 
				
			||||||
 | 
					    model = model.to(f'xpu:{local_rank}')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Load tokenizer
 | 
				
			||||||
 | 
					    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Generate predicted tokens
 | 
				
			||||||
 | 
					    with torch.inference_mode():
 | 
				
			||||||
 | 
					        # prompt = get_prompt(args.prompt, [], system_prompt=DEFAULT_SYSTEM_PROMPT)
 | 
				
			||||||
 | 
					        prompt = args.prompt
 | 
				
			||||||
 | 
					        # input_str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
 | 
				
			||||||
 | 
					        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
 | 
				
			||||||
 | 
					        # ipex model needs a warmup, then inference time can be accurate
 | 
				
			||||||
 | 
					        output = model.generate(input_ids,
 | 
				
			||||||
 | 
					                                max_new_tokens=args.n_predict,
 | 
				
			||||||
 | 
					                                use_cache=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # start inference
 | 
				
			||||||
 | 
					        st = time.time()
 | 
				
			||||||
 | 
					        # if your selected model is capable of utilizing previous key/value attentions
 | 
				
			||||||
 | 
					        # to enhance decoding speed, but has `"use_cache": false` in its model config,
 | 
				
			||||||
 | 
					        # it is important to set `use_cache=True` explicitly in the `generate` function
 | 
				
			||||||
 | 
					        # to obtain optimal performance with BigDL-LLM INT4 optimizations
 | 
				
			||||||
 | 
					        output = model.generate(input_ids,
 | 
				
			||||||
 | 
					                                do_sample=False,
 | 
				
			||||||
 | 
					                                max_new_tokens=args.n_predict)
 | 
				
			||||||
 | 
					        torch.xpu.synchronize()
 | 
				
			||||||
 | 
					        end = time.time()
 | 
				
			||||||
 | 
					        if local_rank == 0:
 | 
				
			||||||
 | 
					            output = output.cpu()
 | 
				
			||||||
 | 
					            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
				
			||||||
 | 
					            print(f'Inference time: {end-st} s')
 | 
				
			||||||
 | 
					            print('-'*20, 'Prompt', '-'*20)
 | 
				
			||||||
 | 
					            print(prompt)
 | 
				
			||||||
 | 
					            print('-'*20, 'Output', '-'*20)
 | 
				
			||||||
 | 
					            print(output_str)
 | 
				
			||||||
							
								
								
									
										12
									
								
								python/llm/example/GPU/Deepspeed-AutoTP/run.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								python/llm/example/GPU/Deepspeed-AutoTP/run.sh
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,12 @@
 | 
				
			||||||
 | 
					source bigdl-llm-init -t -g
 | 
				
			||||||
 | 
					export MASTER_ADDR=127.0.0.1
 | 
				
			||||||
 | 
					export CCL_ZE_IPC_EXCHANGE=sockets
 | 
				
			||||||
 | 
					if [[ -n $OMP_NUM_THREADS ]]; then
 | 
				
			||||||
 | 
					    export OMP_NUM_THREADS=$(($OMP_NUM_THREADS / 4))
 | 
				
			||||||
 | 
					else
 | 
				
			||||||
 | 
					    export OMP_NUM_THREADS=$(($(nproc) / 4))
 | 
				
			||||||
 | 
					fi
 | 
				
			||||||
 | 
					torchrun --standalone \
 | 
				
			||||||
 | 
					         --nnodes=1 \
 | 
				
			||||||
 | 
					         --nproc-per-node 4 \
 | 
				
			||||||
 | 
					         deepspeed_autotp.py --repo-id-or-model-path "meta-llama/Llama-2-7b-hf"
 | 
				
			||||||
| 
						 | 
					@ -45,6 +45,42 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 | 
				
			||||||
from .utils import logger
 | 
					from .utils import logger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_deepspeed_available():
 | 
				
			||||||
 | 
					    return importlib.util.find_spec("deepspeed") is not None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def is_linear_module(module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    in_features = None
 | 
				
			||||||
 | 
					    out_features = None
 | 
				
			||||||
 | 
					    mp_group = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if isinstance(module, nn.Linear):
 | 
				
			||||||
 | 
					        in_features = module.in_features
 | 
				
			||||||
 | 
					        out_features = module.out_features
 | 
				
			||||||
 | 
					        mp_group = None
 | 
				
			||||||
 | 
					        result = True
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        if is_deepspeed_available():
 | 
				
			||||||
 | 
					            from deepspeed.module_inject.layers import LinearLayer, LinearAllreduce
 | 
				
			||||||
 | 
					            if isinstance(module, LinearLayer):
 | 
				
			||||||
 | 
					                in_features = module.weight.shape[1]
 | 
				
			||||||
 | 
					                out_features = module.weight.shape[0]
 | 
				
			||||||
 | 
					                mp_group = None
 | 
				
			||||||
 | 
					                result = True
 | 
				
			||||||
 | 
					            elif isinstance(module, LinearAllreduce):
 | 
				
			||||||
 | 
					                in_features = module.weight.shape[1]
 | 
				
			||||||
 | 
					                out_features = module.weight.shape[0]
 | 
				
			||||||
 | 
					                mp_group = module.mp_group
 | 
				
			||||||
 | 
					                result = True
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                result = False
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            result = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return result, (in_features, out_features, mp_group)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
					def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
                                 current_key_name=None, convert_shape_only=False):
 | 
					                                 current_key_name=None, convert_shape_only=False):
 | 
				
			||||||
    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
 | 
					    from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
 | 
				
			||||||
| 
						 | 
					@ -54,17 +90,20 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
        if current_key_name is None:
 | 
					        if current_key_name is None:
 | 
				
			||||||
            current_key_name = []
 | 
					            current_key_name = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
 | 
					        is_linear, linear_args = is_linear_module(module)
 | 
				
			||||||
 | 
					        if is_linear and name not in modules_to_not_convert:
 | 
				
			||||||
            # Check if the current key is not in the `modules_to_not_convert`
 | 
					            # Check if the current key is not in the `modules_to_not_convert`
 | 
				
			||||||
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
 | 
					            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
 | 
				
			||||||
 | 
					                in_features, out_features, mp_group = linear_args
 | 
				
			||||||
                with init_empty_weights():
 | 
					                with init_empty_weights():
 | 
				
			||||||
                    new_linear = None
 | 
					                    new_linear = None
 | 
				
			||||||
                    if qtype != ggml_tensor_qtype["fp16"]:
 | 
					                    if qtype != ggml_tensor_qtype["fp16"]:
 | 
				
			||||||
                        new_linear = LowBitLinear(
 | 
					                        new_linear = LowBitLinear(
 | 
				
			||||||
                            module.in_features,
 | 
					                            in_features,
 | 
				
			||||||
                            module.out_features,
 | 
					                            out_features,
 | 
				
			||||||
                            qtype,
 | 
					                            qtype,
 | 
				
			||||||
                            module.bias is not None,
 | 
					                            module.bias is not None,
 | 
				
			||||||
 | 
					                            mp_group=mp_group,
 | 
				
			||||||
                        )
 | 
					                        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        device_type = module.weight.data.device.type
 | 
					                        device_type = module.weight.data.device.type
 | 
				
			||||||
| 
						 | 
					@ -82,10 +121,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
				
			||||||
                        if module.in_features in [4096, 11008]:
 | 
					                        if module.in_features in [4096, 11008]:
 | 
				
			||||||
                            # esimd fp16 path
 | 
					                            # esimd fp16 path
 | 
				
			||||||
                            new_linear = FP16Linear(
 | 
					                            new_linear = FP16Linear(
 | 
				
			||||||
                                module.in_features,
 | 
					                                in_features,
 | 
				
			||||||
                                module.out_features,
 | 
					                                out_features,
 | 
				
			||||||
                                qtype,
 | 
					                                qtype,
 | 
				
			||||||
                                module.bias is not None,
 | 
					                                module.bias is not None,
 | 
				
			||||||
 | 
					                                mp_group=mp_group,
 | 
				
			||||||
                            )
 | 
					                            )
 | 
				
			||||||
                            device_type = module.weight.data.device.type
 | 
					                            device_type = module.weight.data.device.type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -328,7 +328,7 @@ class MatMulLowBit(torch.autograd.Function):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LowBitLinear(nn.Linear):
 | 
					class LowBitLinear(nn.Linear):
 | 
				
			||||||
    def __init__(self, input_features, output_features, qtype, bias=True,
 | 
					    def __init__(self, input_features, output_features, qtype, bias=True,
 | 
				
			||||||
                 conver_to_half=True):
 | 
					                 conver_to_half=True, mp_group=None):
 | 
				
			||||||
        super().__init__(input_features, output_features, bias)
 | 
					        super().__init__(input_features, output_features, bias)
 | 
				
			||||||
        self.weight = FP4Params(self.weight.data,
 | 
					        self.weight = FP4Params(self.weight.data,
 | 
				
			||||||
                                requires_grad=False,
 | 
					                                requires_grad=False,
 | 
				
			||||||
| 
						 | 
					@ -339,6 +339,7 @@ class LowBitLinear(nn.Linear):
 | 
				
			||||||
        self.weight_length = self.out_len * self.in_len
 | 
					        self.weight_length = self.out_len * self.in_len
 | 
				
			||||||
        self.qtype = qtype
 | 
					        self.qtype = qtype
 | 
				
			||||||
        self.conver_to_half = conver_to_half
 | 
					        self.conver_to_half = conver_to_half
 | 
				
			||||||
 | 
					        self.mp_group = mp_group
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x: torch.Tensor):
 | 
					    def forward(self, x: torch.Tensor):
 | 
				
			||||||
        if self.bias is not None and self.bias.dtype != x.dtype:
 | 
					        if self.bias is not None and self.bias.dtype != x.dtype:
 | 
				
			||||||
| 
						 | 
					@ -378,6 +379,9 @@ class LowBitLinear(nn.Linear):
 | 
				
			||||||
                                                     input_seq_size)
 | 
					                                                     input_seq_size)
 | 
				
			||||||
            new_shape = x_shape[:-1] + (self.out_len,)
 | 
					            new_shape = x_shape[:-1] + (self.out_len,)
 | 
				
			||||||
            result = result.view(new_shape)
 | 
					            result = result.view(new_shape)
 | 
				
			||||||
 | 
					            if self.mp_group is not None:
 | 
				
			||||||
 | 
					                from deepspeed import comm as dist
 | 
				
			||||||
 | 
					                dist.inference_all_reduce(result, group=self.mp_group)
 | 
				
			||||||
            if self.bias is not None:
 | 
					            if self.bias is not None:
 | 
				
			||||||
                result += self.bias
 | 
					                result += self.bias
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
| 
						 | 
					@ -400,7 +404,7 @@ class LowBitLinear(nn.Linear):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FP16Linear(nn.Linear):
 | 
					class FP16Linear(nn.Linear):
 | 
				
			||||||
    def __init__(self, input_features, output_features, qtype, bias=True,
 | 
					    def __init__(self, input_features, output_features, qtype, bias=True,
 | 
				
			||||||
                 conver_to_half=True):
 | 
					                 conver_to_half=True, mp_group=None):
 | 
				
			||||||
        super().__init__(input_features, output_features, bias)
 | 
					        super().__init__(input_features, output_features, bias)
 | 
				
			||||||
        self.in_len = input_features
 | 
					        self.in_len = input_features
 | 
				
			||||||
        self.out_len = output_features
 | 
					        self.out_len = output_features
 | 
				
			||||||
| 
						 | 
					@ -408,6 +412,7 @@ class FP16Linear(nn.Linear):
 | 
				
			||||||
        self.weight_length = self.out_len * self.in_len
 | 
					        self.weight_length = self.out_len * self.in_len
 | 
				
			||||||
        self.qtype = qtype
 | 
					        self.qtype = qtype
 | 
				
			||||||
        self.conver_to_half = conver_to_half
 | 
					        self.conver_to_half = conver_to_half
 | 
				
			||||||
 | 
					        self.mp_group = mp_group
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x: torch.Tensor):
 | 
					    def forward(self, x: torch.Tensor):
 | 
				
			||||||
        if self.bias is not None and self.bias.dtype != x.dtype:
 | 
					        if self.bias is not None and self.bias.dtype != x.dtype:
 | 
				
			||||||
| 
						 | 
					@ -442,6 +447,9 @@ class FP16Linear(nn.Linear):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        new_shape = x_shape[:-1] + (self.out_len,)
 | 
					        new_shape = x_shape[:-1] + (self.out_len,)
 | 
				
			||||||
        result = result.view(new_shape)
 | 
					        result = result.view(new_shape)
 | 
				
			||||||
 | 
					        if self.mp_group is not None:
 | 
				
			||||||
 | 
					            from deepspeed import comm as dist
 | 
				
			||||||
 | 
					            dist.inference_all_reduce(result, group=self.mp_group)
 | 
				
			||||||
        if self.bias is not None:
 | 
					        if self.bias is not None:
 | 
				
			||||||
            result += self.bias
 | 
					            result += self.bias
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,6 +32,7 @@
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					import importlib
 | 
				
			||||||
import torch.nn as nn
 | 
					import torch.nn as nn
 | 
				
			||||||
from typing import Optional, Tuple
 | 
					from typing import Optional, Tuple
 | 
				
			||||||
import math
 | 
					import math
 | 
				
			||||||
| 
						 | 
					@ -58,10 +59,27 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
				
			||||||
KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
					KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_ipex_version():
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
 | 
				
			||||||
 | 
					        import intel_extension_for_pytorch as ipex
 | 
				
			||||||
 | 
					        return ipex.__version__
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ipex_version = get_ipex_version()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def llama_rms_norm_forward(self, hidden_states):
 | 
					def llama_rms_norm_forward(self, hidden_states):
 | 
				
			||||||
    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
					    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
 | 
				
			||||||
 | 
					        if ipex_version == "2.0.110+xpu":
 | 
				
			||||||
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
					            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
				
			||||||
                                                             [self.weight.size(0)], self.weight)
 | 
					                                                             [self.weight.size(0)], self.weight)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
 | 
				
			||||||
 | 
					                                                             [self.weight.size(0)], self.weight,
 | 
				
			||||||
 | 
					                                                             self.variance_epsilon)
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        input_dtype = hidden_states.dtype
 | 
					        input_dtype = hidden_states.dtype
 | 
				
			||||||
        hidden_states = hidden_states.to(torch.float32)
 | 
					        hidden_states = hidden_states.to(torch.float32)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue