LLM: gguf other model using dtype (#9729)

This commit is contained in:
Wang, Jian4 2023-12-21 09:33:40 +08:00 committed by GitHub
parent 13ea6330bd
commit e1e921f425
3 changed files with 3 additions and 3 deletions

View file

@@ -80,7 +80,7 @@ def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float)
model = BaiChuanForCausalLM(baichuan_config)
for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight)
+        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
model = model.cpu()

View file

@@ -77,7 +77,7 @@ def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
model = LlamaForCausalLM(llama_config)
for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight)
+        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
model = model.cpu()

View file

@@ -77,7 +77,7 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
model = MistralForCausalLM(mistral_config)
for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight)
+        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
model = model.cpu()