From e1e921f425201222143f09288637af091784d645 Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Thu, 21 Dec 2023 09:33:40 +0800 Subject: [PATCH] LLM: gguf other model using dtype (#9729) --- python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py | 2 +- python/llm/src/bigdl/llm/transformers/gguf/models/llama.py | 2 +- python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py index 52441b01..fcaef116 100644 --- a/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py +++ b/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py @@ -80,7 +80,7 @@ def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float) model = BaiChuanForCausalLM(baichuan_config) for name, weight in state_dict.items(): - set_module_tensor_to_device(model, name, "cpu", weight) + set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype) model = model.cpu() diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py b/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py index 91353f44..f35551f9 100644 --- a/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py @@ -77,7 +77,7 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float): model = LlamaForCausalLM(llama_config) for name, weight in state_dict.items(): - set_module_tensor_to_device(model, name, "cpu", weight) + set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype) model = model.cpu() diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py index 5bac931c..8add59bb 100644 --- a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py +++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py @@ -77,7 +77,7 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float): model = MistralForCausalLM(mistral_config) for name, weight in state_dict.items(): - set_module_tensor_to_device(model, name, "cpu", weight) + set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype) model = model.cpu()