[LLM] Replace Embedding layer to fix it on CPU (#9254)
parent e1bc18f8eb
commit 726203d778

4 changed files with 59 additions and 7 deletions
```diff
@@ -192,7 +192,8 @@ def load_low_bit(model, model_path):
     return model
 
 
-def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None):
+def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_convert=None,
+                   replace_embedding=False):
     """
     A method to optimize any pytorch model.
 
@@ -202,6 +203,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     :param optimize_llm: Whether to further optimize llm model.
     :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped
         when conducting model optimizations. Default to be None.
+    :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+        to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
 
     :return: The optimized model.
 
```
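For context, a minimal usage sketch of the extended API follows. The checkpoint id is a placeholder, and the top-level `bigdl.llm` export of `optimize_model` is assumed:

```python
# Usage sketch; "meta-llama/Llama-2-7b-hf" is a placeholder checkpoint and the
# top-level bigdl.llm export of optimize_model is assumed.
from transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype="auto",
                                             low_cpu_mem_usage=True)
# Per the new docstring, replace_embedding=True may be needed when running
# BigDL-LLM on GPU on Windows; matching nn.Embedding layers are then swapped
# for the LLMEmbedding introduced by this commit.
model = optimize_model(model, low_bit='sym_int4', replace_embedding=True)
```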
```diff
@@ -227,7 +230,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
                                  optimize_model=optimize_llm,
-                                 modules_to_not_convert=modules_to_not_convert)
+                                 modules_to_not_convert=modules_to_not_convert,
+                                 replace_embedding=replace_embedding)
     # add save_low_bit to pretrained model dynamically
     import types
     model._bigdl_config = dict()
```
```diff
@@ -35,6 +35,7 @@
 # limitations under the License.
 
 
+import platform
 import torch
 import torch.nn as nn
 from accelerate import init_empty_weights
@@ -82,8 +83,10 @@ def is_linear_module(module):
 
 
 def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
-                                 current_key_name=None, convert_shape_only=False):
+                                 current_key_name=None, convert_shape_only=False,
+                                 replace_embedding=False):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, FP16Linear
+    from bigdl.llm.transformers.embedding import LLMEmbedding
     has_been_replaced = False
 
     for name, module in model.named_children():
```
```diff
@@ -147,6 +150,19 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                         model._modules[name].requires_grad_(False)
 
                         module.weight = None
+        elif replace_embedding and type(module) == nn.Embedding:
+            # skip user-defined Embedding layer
+            if platform.system().lower() == 'windows':
+                model._modules[name] = LLMEmbedding(
+                    num_embeddings=module.num_embeddings,
+                    embedding_dim=module.embedding_dim,
+                    padding_idx=module.padding_idx,
+                    max_norm=module.max_norm,
+                    norm_type=module.norm_type,
+                    scale_grad_by_freq=module.scale_grad_by_freq,
+                    sparse=module.sparse,
+                    _weight=module.weight.data,
+                )
 
         # Remove the last key for recursion
         if len(list(module.children())) > 0:
```
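The exact-type comparison `type(module) == nn.Embedding`, rather than `isinstance`, is what implements the "skip user-defined Embedding layer" comment: subclasses fail the check and keep their own forward. A standalone sketch of the distinction:

```python
# Standalone sketch (not part of the commit): exact-type vs. isinstance
# checks; MyEmbedding stands in for any user-defined subclass.
import torch.nn as nn

class MyEmbedding(nn.Embedding):
    pass

plain = nn.Embedding(10, 4)
custom = MyEmbedding(10, 4)

print(type(plain) == nn.Embedding)       # True: would be replaced
print(type(custom) == nn.Embedding)      # False: skipped, as the comment intends
print(isinstance(custom, nn.Embedding))  # True: isinstance would not skip it
```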
```diff
@@ -156,6 +172,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 modules_to_not_convert,
                 current_key_name,
                 convert_shape_only,
+                replace_embedding,
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -185,7 +202,7 @@ def _optimize_pre(model):
 
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
-                         modules_to_not_convert=None):
+                         modules_to_not_convert=None, replace_embedding=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -196,7 +213,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
 
     model, has_been_replaced = _replace_with_low_bit_linear(
         model, qtype, modules_to_not_convert,
-        None, convert_shape_only,
+        None, convert_shape_only, replace_embedding,
     )
     if not has_been_replaced:
         warnings.warn(
```
python/llm/src/bigdl/llm/transformers/embedding.py (new file, 25 additions)
```diff
@@ -0,0 +1,25 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import torch
+from torch import Tensor
+
+
+class LLMEmbedding(torch.nn.Embedding):
+    def forward(self, x: Tensor):
+        x_shape = x.shape
+        return self.weight[x.reshape(-1)].reshape(*x_shape, -1)
```
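The overridden `forward` is a plain advanced-indexing gather: it flattens the index tensor, selects rows of `self.weight`, and restores the batch shape. For an embedding without `max_norm` this matches the default `nn.Embedding` output (the override deliberately bypasses the `max_norm` renormalization path of `F.embedding`), which the following sketch checks:

```python
# Equivalence sketch (my own check, not part of the commit): the gather-based
# forward matches torch.nn.functional.embedding for a plain weight matrix.
import torch
import torch.nn.functional as F

weight = torch.randn(100, 16)
x = torch.randint(0, 100, (2, 8))

gathered = weight[x.reshape(-1)].reshape(*x.shape, -1)
reference = F.embedding(x, weight)
assert torch.equal(gathered, reference)  # both have shape (2, 8, 16)
```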
```diff
@@ -68,6 +68,8 @@ class _BaseAutoModelClass:
                                Default to be True.
         :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
                                        conducting model optimizations. Default to be None.
+        :param replace_embedding: Whether to replace the Embedding layer, may need to set it
+            to `True` when running BigDL-LLM on GPU on Windows. Default to be `False`.
 
         :return: a model instance
         """
```
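A usage sketch for this transformers-style entry point (the checkpoint id is again a placeholder):

```python
# Usage sketch for the AutoModel path; the checkpoint id is a placeholder.
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_4bit=True,
                                             replace_embedding=True)
```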
```diff
@@ -118,6 +120,7 @@ class _BaseAutoModelClass:
         # `from_pretrained`` may pop items out in dict
         # and lead to args missing.
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
+        replace_embedding = kwargs.pop("replace_embedding", False)
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         try:
@@ -130,7 +133,8 @@ class _BaseAutoModelClass:
             model.config.update({"bigdl_lcmu_enabled": False})
         model = model.to("cpu")
         model = ggml_convert_low_bit(model, qtype, optimize_model,
-                                     modules_to_not_convert=modules_to_not_convert)
+                                     modules_to_not_convert=modules_to_not_convert,
+                                     replace_embedding=replace_embedding)
         model.config.update({"bigdl_transformers_low_bit": q_k})
         model.config.update({"tie_word_embeddings": False})
 
@@ -167,6 +171,7 @@ class _BaseAutoModelClass:
         import os
 
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", None)
+        replace_embedding = kwargs.pop("replace_embedding", False)
         # Autofactory
         trust_remote_code = kwargs.pop("trust_remote_code", None)
         kwargs_orig = copy.deepcopy(kwargs)
@@ -277,7 +282,8 @@ class _BaseAutoModelClass:
         # Loading args may differ based on their usage
         quant_device = "meta" if bigdl_lcmu_enabled else "cpu"
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
-                                     modules_to_not_convert=modules_to_not_convert)
+                                     modules_to_not_convert=modules_to_not_convert,
+                                     replace_embedding=replace_embedding)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
```
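The same kwarg also flows through the low-bit reload path, as the last hunk shows; a sketch, assuming a directory previously written by `save_low_bit`:

```python
# Reload sketch; "./llama-2-7b-sym_int4" is a placeholder directory assumed to
# have been produced earlier by model.save_low_bit(...).
from bigdl.llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.load_low_bit("./llama-2-7b-sym_int4",
                                          replace_embedding=True)
```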