Fix Mixtral GGUF Wrong Output Issue (#9930)

* Fix Mixtral GGUF Wrong Output Issue

* fix style

* fix style
Heyang Sun 2024-01-18 14:11:27 +08:00 committed by GitHub
parent 453df868c9
commit 5184f400f9
3 changed files with 28 additions and 15 deletions


@@ -28,7 +28,7 @@ conda create -n llm python=3.9 # recommend to use Python 3.9
 conda activate llm
 pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
-pip install transformers==4.34.0 # upgrade transformers
+pip install transformers==4.36.0 # upgrade transformers
 ```
 ### 2. Run
 After setting up the Python environment, you could run the example by following steps.


@@ -28,7 +28,7 @@ conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
-pip install transformers==4.34.0 # upgrade transformers
+pip install transformers==4.36.0 # upgrade transformers
 ```
 ### 2. Configures OneAPI environment variables


@@ -17,7 +17,7 @@
 import os
 import torch
 from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device as fill_model
+from accelerate.utils import set_module_tensor_to_device
 from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
@@ -53,21 +53,34 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     with init_empty_weights():
         model = MixtralForCausalLM(mixtral_config)
     # define an operator function that passed to low-level gguf API
     def process_mixtral(name, tensor):
         # prepare module's name in transformers
         module_name = get_mixtral_module_name(name)
         # prepare module's weight in transformers
         if 'ffn_gate_inp' in name:
             # gguf weight needs to reshape for ffn_gate_inp
-            fill_model(model,
-                       module_name,
-                       "cpu",
-                       tensor.reshape(num_local_experts, hidden_size),
-                       dtype=dtype)
-        else:
-            fill_model(model,
-                       module_name,
-                       "cpu",
-                       tensor,
-                       dtype=dtype)
+            tensor = tensor.reshape(num_local_experts, hidden_size)
+        elif name.endswith("attn_q.weight"):
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            tensor = (tensor.reshape(n_head,
+                                     head // n_head // 2,
+                                     2,
+                                     *hd_size)
+                      .swapaxes(1, 2)
+                      .reshape(tensor.shape))
+        elif name.endswith("attn_k.weight"):
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            tensor = (tensor.reshape(n_head_kv,
+                                     head // n_head_kv // 2,
+                                     2,
+                                     *hd_size)
+                      .swapaxes(1, 2)
+                      .reshape(tensor.shape))
+        set_module_tensor_to_device(model,
+                                    module_name,
+                                    "cpu",
+                                    tensor,
+                                    dtype=dtype)
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
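
The wrong output traced back to a weight-layout mismatch: GGUF files follow the llama.cpp convention of storing the rotary attention weights (`attn_q.weight`, `attn_k.weight`) with the two halves of each head interleaved, while the Hugging Face Mixtral modules expect the de-interleaved layout, so the patch permutes those tensors before filling the model. Below is a minimal standalone sketch of that permutation; the `permute_gguf_to_hf` helper name and the dimensions are illustrative only, and the real loader reads `n_head` from the GGUF metadata.

```python
import torch

def permute_gguf_to_hf(tensor: torch.Tensor, n_head: int) -> torch.Tensor:
    # Group each head's rows into (pairs, 2), then swap the two axes so the
    # interleaved rows become two contiguous half-blocks per head; the final
    # reshape restores the original 2-D weight shape.
    head, hd_size = tensor.shape[0], tensor.shape[1:]
    return (tensor.reshape(n_head, head // n_head // 2, 2, *hd_size)
                  .swapaxes(1, 2)
                  .reshape(tensor.shape))

if __name__ == "__main__":
    n_head, head_dim, hidden_size = 32, 128, 4096    # illustrative sizes
    w = torch.randn(n_head * head_dim, hidden_size)  # stand-in attn_q.weight
    assert permute_gguf_to_hf(w, n_head).shape == w.shape  # layout changes, shape does not
```

For `attn_k.weight` the same permutation runs with `n_head_kv` (the number of key/value heads) in place of `n_head`, since Mixtral uses grouped-query attention and has fewer key/value heads than query heads.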