diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 441e7e4d..e3072179 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -202,7 +202,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         is_linear, linear_args = is_linear_module(module)
         if is_linear and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
+                    module.weight.data.device.type != 'meta' and
+                    not isinstance(module, LowBitLinear)):
                 in_features, out_features, mp_group = linear_args
                 with init_empty_weights():
                     new_linear = None
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
index 2c4ace53..ba1feae4 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
@@ -25,6 +25,7 @@ from ..gguf import GGUFFileLoader
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
+
 def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
                       low_bit='sym_int4'):
     config = loader.config
@@ -78,7 +79,7 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
         else:
             set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
         model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
-        
+
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mistral)
 
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
index 4ab221b8..216ceccb 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
@@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
 
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                      low_bit='sym_int4'):
     # mixtral enjoys a general architecture of llma
     # e.g. it applies llama tokenizer
     config = loader.config
@@ -33,6 +36,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
     hidden_size = config['llama.embedding_length']
+    qtype = ggml_tensor_qtype[low_bit]
 
     mixtral_config = MixtralConfig(
         vocab_size=len(config['tokenizer.ggml.tokens']),
@@ -81,6 +85,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
                                     "cpu",
                                     tensor,
                                     dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
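
For context, a minimal usage sketch of the new `low_bit` path added above, assuming `GGUFFileLoader` is constructed from the `.gguf` file path and that `load_gguf_mixtral` hands back the converted model; the exact import locations and return signature are inferred from the file paths in this diff and may differ in the package.

```python
import torch

# Import paths inferred from the files touched in this diff; adjust if the
# package re-exports these elsewhere.
from bigdl.llm.transformers.gguf.gguf import GGUFFileLoader
from bigdl.llm.transformers.gguf.models.mixtral import load_gguf_mixtral

# Assumption: GGUFFileLoader is constructed from the path of a GGUF checkpoint.
loader = GGUFFileLoader("mixtral-8x7b-v0.1.Q4_0.gguf")

# low_bit selects the ggml_tensor_qtype passed to replace_with_low_bit_linear_for_module.
# With the meta-device / LowBitLinear guard added in convert.py, each Linear is
# converted exactly once, right after its tensor is materialized on CPU, so the
# full-precision weights never need to be resident all at once.
model = load_gguf_mixtral(loader, dtype=torch.float, low_bit='sym_int4')
```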