[LLM] Replace Embedding layer to fix it on CPU (#9254)
parent e1bc18f8eb
commit 726203d778

4 changed files with 59 additions and 7 deletions
@@ -192,7 +192,8 @@ def load_low_bit(model, model_path):
     return model
 
 
-def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None):
+def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
+                   replace_embedding=False):
     """
     A method to optimize any pytorch model.
 
@@ -202,6 +203,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     :param optimize_llm: Whether to further optimize llm model.
     :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped
         when conducting model optimizations. Default to be None.
+    :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+        to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
 
     :return: The optimized model.
 
@@ -227,7 +230,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
                                  optimize_model=optimize_llm,
-                                 modules_to_not_convert=modules_to_not_convert)
+                                 modules_to_not_convert=modules_to_not_convert,
+                                 replace_embedding=replace_embedding)
     # add save_low_bit to pretrained model dynamically
     import types
     model._bigdl_config = dict()
@@ -35,6 +35,7 @@
 # limitations under the License.
 
 
+import platform
 import torch
 import torch.nn as nn
 from accelerate import init_empty_weights
@@ -82,8 +83,10 @@ def is_linear_module(module):
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None, convert_shape_only=False):
+                                 current_key_name=None, convert_shape_only=False,
+                                 replace_embedding=False):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
+    from bigdl.llm.transformers.embedding import LLMEmbedding
     has_been_replaced = False
 
     for name, module in model.named_children():
@@ -147,6 +150,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         model._modules[name].requires_grad_(False)
 
                         module.weight = None
+        elif replace_embedding and type(module) == nn.Embedding:
+            # skip user-defined Embedding layer
+            if platform.system().lower() == 'windows':
+                model._modules[name] = LLMEmbedding(
+                    num_embeddings=module.num_embeddings,
+                    embedding_dim=module.embedding_dim,
+                    padding_idx=module.padding_idx,
+                    max_norm=module.max_norm,
+                    norm_type=module.norm_type,
+                    scale_grad_by_freq=module.scale_grad_by_freq,
+                    sparse=module.sparse,
+                    _weight=module.weight.data,
+                )
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
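A standalone sanity-check sketch of the replacement above (not part of the commit; it assumes the new bigdl.llm.transformers.embedding module from this change is importable). Passing the existing tensor as `_weight` makes the replacement reuse the original weight storage rather than re-allocating it.

import torch
import torch.nn as nn
from bigdl.llm.transformers.embedding import LLMEmbedding  # module added by this commit

orig = nn.Embedding(num_embeddings=100, embedding_dim=16)
# Mirror the construction used in _replace_with_low_bit_linear.
repl = LLMEmbedding(
    num_embeddings=orig.num_embeddings,
    embedding_dim=orig.embedding_dim,
    padding_idx=orig.padding_idx,
    max_norm=orig.max_norm,
    norm_type=orig.norm_type,
    scale_grad_by_freq=orig.scale_grad_by_freq,
    sparse=orig.sparse,
    _weight=orig.weight.data,
)
assert repl.weight.data_ptr() == orig.weight.data_ptr()  # weight storage is shared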
@@ -156,6 +172,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 current_key_name,
                 convert_shape_only,
+                replace_embedding,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -185,7 +202,7 @@ def _optimize_pre(model):
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None):
+                         modules_to_not_convert=None, replace_embedding=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -196,7 +213,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only,
+        None, convert_shape_only, replace_embedding,
     )
     if not has_been_replaced:
         warnings.warn(
python/llm/src/bigdl/llm/transformers/embedding.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+from torch import Tensor
+
+
+class LLMEmbedding(torch.nn.Embedding):
+    def forward(self, x: Tensor):
+        x_shape = x.shape
+        return self.weight[x.reshape(-1)].reshape(*x_shape, -1)
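The overridden forward bypasses F.embedding and performs a plain index into self.weight, reshaped back to the input shape. For layers converted with the arguments above, and in particular when max_norm is None, this is equivalent to the stock embedding lookup at inference time, as a quick standalone sketch (not part of the commit) illustrates:

import torch
import torch.nn.functional as F

weight = torch.randn(100, 16)
ids = torch.randint(0, 100, (2, 8))
# LLMEmbedding.forward is just this gather, reshaped back to the input shape:
manual = weight[ids.reshape(-1)].reshape(*ids.shape, -1)
assert torch.equal(manual, F.embedding(ids, weight))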
@@ -68,6 +68,8 @@ class _BaseAutoModelClass:
                                Default to be True.
         :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
                                        conducting model optimizations. Default to be None.
+        :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+            to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
 
         :return: a model instance
         """
@@ -118,6 +120,7 @@ class _BaseAutoModelClass:
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
+        replace_embedding = kwargs.pop("replace_embedding", False)
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         try:
@@ -130,7 +133,8 @@ class _BaseAutoModelClass:
             model.config.update({"bigdl_lcmu_enabled": False})
         model = model.to("cpu")
         model = ggml_convert_low_bit(model, qtype, optimize_model,
-                                     modules_to_not_convert=modules_to_not_convert)
+                                     modules_to_not_convert=modules_to_not_convert,
+                                     replace_embedding=replace_embedding)
         model.config.update({"bigdl_transformers_low_bit": q_k})
         model.config.update({"tie_word_embeddings": False})
 
@@ -167,6 +171,7 @@ class _BaseAutoModelClass:
         import os
 
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
+        replace_embedding = kwargs.pop("replace_embedding", False)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
         kwargs_orig = copy.deepcopy(kwargs)
@@ -277,7 +282,8 @@ class _BaseAutoModelClass:
         # Loading args may differ based on their usage
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
-                                     modules_to_not_convert=modules_to_not_convert)
+                                     modules_to_not_convert=modules_to_not_convert,
+                                     replace_embedding=replace_embedding)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
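End to end, the new kwarg flows through the `from_pretrained` path shown above. A hedged usage sketch, not part of this diff: the model path is a placeholder and `load_in_4bit` is the usual BigDL-LLM loading flag, assumed here rather than shown in these hunks.

from bigdl.llm.transformers import AutoModelForCausalLM

# replace_embedding is popped from kwargs by _BaseAutoModelClass.from_pretrained
# and forwarded to ggml_convert_low_bit(), as in the hunks above.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/model",            # placeholder
    load_in_4bit=True,          # assumed low-bit loading flag
    trust_remote_code=True,
    replace_embedding=True,     # new in this commit; default False
)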