[NPU] Groupwise (#12241)

* dq divide

* fix

* support attn divide

* update qwen2 7b

* divide down_proj & other linear

* use concat & reduce sum

* support scale after

* support qwen2

* w/ mm

* update reshape

* sdpa

* split

* split 2+

* update

* lm head-> 28

* no scale

* update

* update

* update

* fix style

* fix style

* to split linear

* update

* update code

* address comments

* fix style & remove redundant code & revert benchmark scripts

* fix style & remove code

* update save & load

---------

Co-authored-by: Yang Wang <yang3.wang@intel.com>
Yina Chen 2024-10-23 09:10:58 +03:00 committed by GitHub
parent aedc4edfba
commit e37f951cce
9 changed files with 493 additions and 165 deletions


@@ -30,7 +30,9 @@ current_dir = os.path.dirname(os.path.realpath(__file__))
 def save_npu_model_in_low_bit(repo_id,
                               local_model_hub,
                               low_bit,
-                              max_output_len, max_prompt_len, intra_pp, inter_pp, disable_transpose_value_cache):
+                              max_output_len, max_prompt_len, intra_pp, inter_pp,
+                              disable_transpose_value_cache,
+                              quantization_group_size):
     model_path = get_model_path(repo_id, local_model_hub)
     # Load model in 4 bit,
     # which convert the relevant layers in the model into INT4 format
@@ -47,6 +49,7 @@ def save_npu_model_in_low_bit(repo_id,
         intra_pp=intra_pp,
         inter_pp=inter_pp,
         transpose_value_cache=not disable_transpose_value_cache,
+        quantization_group_size=quantization_group_size
     )
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     end = time.perf_counter()
@@ -54,6 +57,7 @@ def save_npu_model_in_low_bit(repo_id,
     model.save_low_bit(model_path+'-npu-'+low_bit)
     tokenizer.save_pretrained(model_path+'-npu-'+low_bit)
+    print(f"Model saved to {model_path+'-npu-'+low_bit}")

 if __name__ == "__main__":
@@ -65,6 +69,7 @@ if __name__ == "__main__":
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--intra-pp", type=int, default=2)
     parser.add_argument("--inter-pp", type=int, default=2)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     args = parser.parse_args()

     from omegaconf import OmegaConf
@@ -78,5 +83,6 @@ if __name__ == "__main__":
                                  max_prompt_len=args.max_prompt_len,
                                  intra_pp=args.intra_pp,
                                  inter_pp=args.inter_pp,
-                                 disable_transpose_value_cache=args.disable_transpose_value_cache
+                                 disable_transpose_value_cache=args.disable_transpose_value_cache,
+                                 quantization_group_size=args.quantization_group_size,
                                  )


@@ -81,6 +81,8 @@ class _BaseAutoModelClass:
         :param mixed_precision: boolean value, Whether to use mixed precision quantization.
             Default to be False. If set to ``True``, we will use ``'sym_int8'`` for lm_head when
             ``load_in_low_bit`` is '``sym_int4``' for certain models.
+        :param quantization_group_size: int, quantization group size, The recommended
+            quantization_group_size are 0, 32, 64 or 128
         :return: a model instance
         """
         if kwargs.get("device_map", None) not in [None, "cpu", "auto"]:
@@ -126,6 +128,15 @@ class _BaseAutoModelClass:
         transpose_value_cache = kwargs.pop("transpose_value_cache", True)
         modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         mixed_precision = kwargs.pop('mixed_precision', False)
+        quantization_group_size = kwargs.pop("quantization_group_size", 0)
+
+        invalidInputError(
+            quantization_group_size in [0, 32, 64, 128],
+            (
+                "The recommended quantization_group_size are 0, 32, 64 or 128,"
+                f"but got {quantization_group_size}"
+            )
+        )

         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
@@ -162,8 +173,11 @@ class _BaseAutoModelClass:
             with torch.no_grad():
                 model.config.update({"mixed_precision": mixed_precision})
-                optimize_llm_pre(model, qtype, mixed_precision)
-                cls.load_convert(qtype, model, "cpu", modules_to_not_convert, *args, **kwargs)
+                model.config.update({"group_size": quantization_group_size})
+                optimize_llm_pre(model, qtype, mixed_precision,
+                                 quantization_group_size=quantization_group_size)
+                cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
+                                 quantization_group_size, *args, **kwargs)
                 create_npu_kernels(llm)
             model = model.eval()
             logger.info(f"Finish to convert model")
@@ -177,6 +191,7 @@ class _BaseAutoModelClass:
                 inter_pp=inter_pp,
                 intra_pp=intra_pp,
                 transpose_value_cache=transpose_value_cache,
+                group_size=quantization_group_size
             )
             model.save_low_bit = types.MethodType(save_low_bit, model)
         else:
@@ -197,11 +212,13 @@ class _BaseAutoModelClass:
         return model

     @classmethod
-    def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert, *arg, **kwarg):
+    def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert,
+                     group_size=0, *arg, **kwarg):
         from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear
         replace_with_QuantizedLinear(optimize_model, q_k, device=device,
-                                     modules_to_not_convert=modules_to_not_convert)
+                                     modules_to_not_convert=modules_to_not_convert,
+                                     group_size=group_size)

     @classmethod
     @patch("transformers.dynamic_module_utils.get_imports", patch_flash_attn_import)
@@ -214,6 +231,7 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "speculative")
         ignore_argument(kwargs, "pipeline_parallel_stages")
         ignore_argument(kwargs, "mixed_precision")
+        ignore_argument(kwargs, "quantization_group_size")
         optimize_model = kwargs.pop("optimize_model", False)
         max_output_len = kwargs.pop("max_output_len", 1024)
         max_prompt_len = kwargs.pop("max_prompt_len", 512)
@@ -264,6 +282,7 @@ class _BaseAutoModelClass:
         qtype = config_dict.pop("bigdl_transformers_low_bit", False)
         bigdl_lcmu_enabled = config_dict.pop("bigdl_lcmu_enabled", True)
         mixed_precision = config_dict.pop("mixed_precision", False)
+        quantization_group_size = config_dict.pop("group_size", 0)

         invalidInputError(
             qtype,
@@ -376,9 +395,10 @@ class _BaseAutoModelClass:
             llm = model

             with torch.no_grad():
-                optimize_llm_pre(model, qtype, mixed_precision)
+                optimize_llm_pre(model, qtype, mixed_precision,
+                                 quantization_group_size=quantization_group_size)
                 cls.load_convert(qtype, model, quant_device, modules_to_not_convert,
-                                 *model_args, **kwargs)
+                                 quantization_group_size, *model_args, **kwargs)
                 create_npu_kernels(llm)
         else:
@@ -458,6 +478,7 @@ class _BaseAutoModelClass:
             inter_pp=inter_pp,
             intra_pp=intra_pp,
             transpose_value_cache=transpose_value_cache,
+            group_size=quantization_group_size
         )

         return model
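For context, a hedged usage sketch of how the new argument is meant to be passed end to end; the import path, model id and local paths below are illustrative assumptions, only the keyword arguments themselves come from this change:

```python
# Minimal sketch, assuming the NPU AutoModel entry point is
# ipex_llm.transformers.npu_model and a Qwen2 checkpoint is being converted.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",            # illustrative model id
    load_in_low_bit="sym_int4",
    optimize_model=True,
    max_output_len=1024,
    max_prompt_len=512,
    quantization_group_size=128,         # must be one of 0, 32, 64, 128
    transpose_value_cache=True,
    trust_remote_code=True,
)

# from_pretrained stores "group_size" in model.config (see above), so a model
# saved with save_low_bit can later be reloaded without repeating the argument.
model.save_low_bit("./qwen2-7b-npu-sym_int4")
model = AutoModelForCausalLM.load_low_bit("./qwen2-7b-npu-sym_int4", trust_remote_code=True)
```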


@@ -16,6 +16,7 @@

 import torch
 from typing import List
+from ipex_llm.utils.common.log4Error import invalidInputError


 def merge_linear(linears: List[torch.nn.Linear]) -> torch.nn.Linear:
@@ -40,3 +41,21 @@ def reshape_lm_head_input(x):
         shape[1] = 1
         x = x[:, -1, :].view(shape)
     return x
+
+
+def split_linear(module, module_name, n_splits=2):
+    in_features = module.in_features
+    invalidInputError(in_features % n_splits == 0,
+                      f"in_features of the linear layer {module_name} must be divisible by"
+                      f" n_splits, but got in_features: {in_features}, n_splits: {n_splits}")
+    weight_split = torch.tensor_split(module.weight, n_splits, dim=1)
+    linear_list = torch.nn.ModuleList()
+    bias = module.bias
+    for idx, weight in enumerate(weight_split):
+        new_linear = torch.nn.Linear(weight.size(1),
+                                     weight.size(0),
+                                     bias=False if bias is None else True)
+        new_linear.bias = bias
+        new_linear.weight = torch.nn.Parameter(weight.contiguous(), requires_grad=False)
+        linear_list.add_module(f"{module_name}_dq_{idx}", new_linear)
+    return linear_list
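`split_linear` only re-partitions the input channels of a projection; summing the outputs of the per-chunk linears reproduces the original result, which is what the "concat & reduce sum" graphs below rely on. A small sanity check (not part of the commit):

```python
import torch

torch.manual_seed(0)
full = torch.nn.Linear(128, 64, bias=False)       # toy projection
n_splits = 4                                      # e.g. in_features // group_size
chunks = torch.tensor_split(full.weight, n_splits, dim=1)

x = torch.randn(1, 128)
group = 128 // n_splits
partial = [x[:, i * group:(i + 1) * group] @ w.T for i, w in enumerate(chunks)]

# The sum of the partial projections equals the unsplit projection.
assert torch.allclose(sum(partial), full(x), atol=1e-5)
```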


@@ -31,7 +31,8 @@ def module_optimization(func) -> torch.nn.Module:
         torch.nn.Module: optimized module
     """

-    def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert, *args, **kwargs):
+    def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert,
+                group_size=0, *args, **kwargs):
         """Recursively apply the optimization function.

         Args:
@@ -42,18 +43,22 @@ def module_optimization(func) -> torch.nn.Module:
         """
         for name, layer in model.named_children():
             if name not in modules_to_not_convert:
-                new_layer = func(layer, qtype, device, modules_to_not_convert, *args, **kwargs)
+                new_layer = func(layer, qtype, device, modules_to_not_convert,
+                                 group_size=group_size, *args, **kwargs)
                 if new_layer:
                     model.add_module(name, new_layer)
-                    wrapper(new_layer, qtype, device, modules_to_not_convert, *args, **kwargs)
+                    wrapper(new_layer, qtype, device, modules_to_not_convert,
+                            group_size=group_size, *args, **kwargs)
                 else:
-                    wrapper(layer, qtype, device, modules_to_not_convert, *args, **kwargs)
+                    wrapper(layer, qtype, device, modules_to_not_convert,
+                            group_size=group_size, *args, **kwargs)

     return wrapper


 @module_optimization
-def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert):
+def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
+                                 group_size):
     from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
     from ipex_llm.ggml.quantize import ggml_tensor_qtype
     iqtype = ggml_tensor_qtype[qtype]
@@ -66,7 +71,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert):
             iqtype = ggml_tensor_qtype[qtype]
         qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
                                              iqtype, device=device)
-        return QuantizedLinear(qweights, scale, layer.bias)
+        return QuantizedLinear(qweights, scale, layer.bias,
+                               group_size=group_size)


 def convert_forward(m, target_m, new_forward):


@@ -19,6 +19,7 @@ import importlib
 import numpy as np
 from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params
 from ipex_llm.transformers.npu_models.lm_head import LMHeadLinear, SlicedLMHead
+from ipex_llm.utils.common.log4Error import invalidInputError


 def convert_forward(m, target_m, new_forward):
@@ -29,7 +30,8 @@ def convert_forward(m, target_m, new_forward):
             convert_forward(sub_m, target_m, new_forward)


-def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision):
+def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
+                     quantization_group_size=0):
     if model.config.model_type == "baichuan":
         # process NormHead module in Baichuan2 7B
         if hasattr(model, 'lm_head') and model.lm_head is not None:
@@ -86,17 +88,40 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision):
         model = model.llm

     if model.config.model_type == "qwen2":
-        from ipex_llm.transformers.npu_models.qwen2_mp import split_mlp_down_proj
-        model.apply(split_mlp_down_proj)
+        from ipex_llm.transformers.npu_models.qwen2_mp import split_linears
+
+        if quantization_group_size == 0:
+            n_splits_linear = 1
+            n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
+        else:
+            invalidInputError(
+                model.config.hidden_size % quantization_group_size == 0 and
+                model.config.intermediate_size % quantization_group_size == 0,
+                "The model hidden_size and intermediate_size should be divisible by "
+                f"quantization_group_size, but got hidden_size: {model.config.hidden_size}, "
+                f"intermediate_size: {model.config.intermediate_size}, and "
+                f"quantization_group_size: {quantization_group_size}"
+            )
+            n_splits_linear = model.config.hidden_size // quantization_group_size
+            n_splits_down_proj = model.config.intermediate_size // quantization_group_size
+        model.apply(lambda m: split_linears(m, n_splits_hidden_size=n_splits_linear,
+                                            n_splits_down_proj=n_splits_down_proj))

         # for Qwen2-7B-Insturct, divide lm_head into 14 parts
         if model.config.hidden_size == 3584 and model.config.vocab_size == 152064 and \
                 not cpu_lm_head:
-            # Do not split lm_head and use sym_int8 instead when mixed_precison is True
-            is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
-            split_num = 14 if is_split else 1
-            new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
-                                       bias=model.lm_head.bias)
+            if quantization_group_size != 0:
+                split_num = model.config.hidden_size // quantization_group_size
+                new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
+                                           bias=model.lm_head.bias, use_split=True)
+            else:
+                # Do not split lm_head and use sym_int8 instead when mixed_precison is True
+                is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
+                split_num = 14 if is_split else 1
+                new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
+                                           bias=model.lm_head.bias, use_split=False)
             del model.lm_head
             model.lm_head = new_lm_head
@@ -132,6 +157,7 @@ def optimize_llm(
     inter_pp=None,
     intra_pp=None,
     transpose_value_cache=True,
+    group_size=0
 ):
     if model.config.model_type == "llama":
         if intra_pp is None:
@@ -168,7 +194,13 @@ def optimize_llm(
         if intra_pp is None:
            intra_pp = 2
         if inter_pp is None:
-            inter_pp = 2 if model.config.intermediate_size == 18944 else 1
+            if model.config.intermediate_size == 18944:
+                if group_size != 0:
+                    inter_pp = 5
+                else:
+                    inter_pp = 2
+            else:
+                inter_pp = 1

         from ipex_llm.transformers.npu_models.qwen2_mp import gen_qwen2_fused_model_forward
         from ipex_llm.transformers.npu_models.qwen2_mp import DecodeRunner, PrefillRunner
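A quick arithmetic check of the split counts chosen above, using the Qwen2-7B shapes already referenced in this file (hidden_size 3584, intermediate_size 18944):

```python
hidden_size, intermediate_size = 3584, 18944       # Qwen2-7B, as checked above
for group_size in (32, 64, 128):
    assert hidden_size % group_size == 0 and intermediate_size % group_size == 0
    print(group_size, hidden_size // group_size, intermediate_size // group_size)
# group_size 32 -> 112 / 592 splits, 64 -> 56 / 296, 128 -> 28 / 148
```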


@@ -130,6 +130,7 @@ class QuantizedLinear(torch.nn.Module):
         weight: torch.Tensor,
         scale: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
+        group_size: int = False,
     ):
         """Initialize the QuantizedLinear class.
@@ -154,8 +155,11 @@ class QuantizedLinear(torch.nn.Module):
             )
         )
         self.outC, self.inC = self.weight.shape
-        if self.weight.dtype == torch.uint8:
-            # In case is Int4 we need to double the input channels because weights are compressed
-            self.inC *= 2
-        self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
+        if group_size != 0:
+            self.scale = Parameter(scale, requires_grad=False)
+        else:
+            if self.weight.dtype == torch.uint8:
+                # Int4 we need to double the input channels because weights are compressed
+                self.inC *= 2
+            self.scale = Parameter(scale * math.sqrt(self.inC), requires_grad=False)
         self.bias = bias
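For intuition, a generic sketch of group-wise symmetric quantization, i.e. one scale per `group_size` input channels, which is the kind of per-group `scale` the branch above now stores unmodified; this is an illustration only, not the `ggml_convert_qtype` kernel or its packed memory layout:

```python
import torch

def quantize_groupwise_sym_int8(weight: torch.Tensor, group_size: int = 128):
    # weight: (out_features, in_features); returns int8 weights plus one scale
    # per (row, group) instead of a single scale for the whole tensor.
    out_c, in_c = weight.shape
    w = weight.reshape(out_c, in_c // group_size, group_size)
    scale = (w.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
    qweight = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return qweight.reshape(out_c, in_c), scale.squeeze(-1)

w = torch.randn(64, 256)
qw, scales = quantize_groupwise_sym_int8(w, group_size=128)
w_hat = (qw.reshape(64, 2, 128).float() * scales.unsqueeze(-1)).reshape(64, 256)
print((w - w_hat).abs().max())        # small per-group quantization error
```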


@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch
 from torch import nn
 import numpy as np
+from filelock import FileLock
 from intel_npu_acceleration_library.backend import NNFactory
 from intel_npu_acceleration_library.backend.bindings import lib as backend_lib
@@ -34,6 +34,7 @@ class LMHeadLinear(NNFactory):
         profile: bool = False,
         device: str = "NPU",
         dtype: np.dtype = np.int8,
+        use_split: bool = False,
     ):
         """Initialize the LMHeadLinear class.
@@ -51,9 +52,14 @@ class LMHeadLinear(NNFactory):
         self.inC, self.outC = inC, outC
         self.batch = batch

-        input = self.parameter((self.batch, self.inC))
-
         self.split_num = split_num
+        if use_split:
+            input = self.parameter((1, self.batch, self.inC))
+            res = self.dq_split_linear(input, self.split_num, self.outC, self.inC, wt_dtype=dtype,
+                                       scale_factor=False)
+        else:
+            input = self.parameter((self.batch, self.inC))
             split_size = self.inC // split_num // 2 * 2

             for i in range(self.split_num):
@@ -61,7 +67,8 @@ class LMHeadLinear(NNFactory):
                 end_idx = (i + 1) * split_size if i < self.split_num - 1 else self.inC
                 input_slice = self.slice(input, begin=[0, start_idx],
                                          end=[self.batch, end_idx])
-                linear_slice = self.linear(input_slice, outC, split_size, bias=False, wt_dtype=dtype)
+                linear_slice = self.linear(input_slice, outC, split_size, bias=False,
+                                           wt_dtype=dtype)
                 if i == 0:
                     res = linear_slice
                 else:
@@ -71,6 +78,14 @@ class LMHeadLinear(NNFactory):
         self.compile()
         print("end compiling lm_head")

+    def set_weights(self, op_id, weights):
+        self.set_weights_async(op_id, weights)
+        with FileLock(f"lmhead_run.lock"):
+            backend_lib.run(self._mm)
+
+    def set_weights_async(self, op_id, weights):
+        self.setWeights(1, op_id, *weights)
+
     def run(
         self, X: np.ndarray
     ) -> np.ndarray:
@@ -93,7 +108,7 @@ class LMHeadLinear(NNFactory):

 class SlicedLMHead(nn.Module):
-    def __init__(self, weight, bias, split_num):
+    def __init__(self, weight, bias, split_num, use_split=False):
         super().__init__()
         self.split_num = split_num
         self.outC, self.inC = weight.shape
@@ -110,6 +125,7 @@ class SlicedLMHead(nn.Module):
             new_linear.out_features = new_weight.size(0)
             self.lm_heads.append(new_linear)
         self.bias = bias
+        self.use_split = use_split

     def forward(self, hidden_states):
         if hidden_states.size(0) * hidden_states.size(1) == 1:
@@ -143,9 +159,19 @@ class SlicedLMHead(nn.Module):
     def get_fused_lm_head(self):
         np_dtype = np.uint8 if self.get_weight_dtype() == torch.uint8 else np.int8
         self.fused_lm_head = LMHeadLinear(self.inC, self.outC, 1, self.split_num,
-                                          False, "NPU", dtype=np_dtype)
-        fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
-                                  self.lm_heads[i].scale.data.numpy())
-                                 for i in range(self.split_num)]
-        self.fused_lm_head.setWeights(1, self.lm_heads[0].op_id,
-                                      *fused_lm_head_weights)
+                                          False, "NPU", dtype=np_dtype, use_split=self.use_split)
+        if self.use_split:
+            weights = []
+            scales = []
+            for i in range(self.split_num):
+                weights.append(self.lm_heads[i].weight)
+                scales.append(self.lm_heads[i].scale)
+            fused_lm_head_weights = (torch.stack(weights, axis=0).numpy(),
+                                     torch.stack(scales, axis=0).numpy())
+        else:
+            fused_lm_head_weights = [(self.lm_heads[i].weight.data.numpy(),
+                                      self.lm_heads[i].scale.data.numpy())
+                                     for i in range(self.split_num)]
+
+        self.fused_lm_head.set_weights(self.lm_heads[0].op_id,
+                                       fused_lm_head_weights)


@@ -27,6 +27,8 @@ from filelock import FileLock
 import ctypes
 import math
 import numpy as np
+from typing import Optional, Any, List
+import numpy.typing as npt

 logger = logging.get_logger(__name__)
@@ -60,6 +62,12 @@ def run_model(
             op_args.append((set_contiguous(w[0]).numpy(), set_contiguous(w[1]).numpy()))
             op_args_flatten.append(op_args[-1][0])
             op_args_flatten.append(op_args[-1][1])
+        elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
+            op_args.append(w.numpy())
+            op_args_flatten.append(op_args[-1])
+        elif isinstance(w, np.ndarray):  # scale
+            op_args.append(w)
+            op_args_flatten.append(op_args[-1])
         else:
             op_args.append(set_contiguous(w).to(torch.float16).numpy())
             op_args_flatten.append(op_args[-1])
@@ -94,7 +102,8 @@ def run_model(

 class LLMBaseNNFactory(NNFactory):

-    def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU"):
+    def __init__(self, max_seq_len, transpose_value, dtype, profile=False, device="NPU",
+                 n_splits_linear=1, n_splits_down_proj=1, group_size=False):
         super().__init__(profile, device)
         self.cache_parameter_ops = []
         self.input_ops = []
@@ -104,6 +113,9 @@ class LLMBaseNNFactory(NNFactory):
         self.max_seq_len = max_seq_len
         self.transpose_value = transpose_value
         self.dtype = dtype
+        self.n_splits_linear = n_splits_linear
+        self.n_splits_down_proj = n_splits_down_proj
+        self.group_size = group_size

     def attention(self,
                   *,
@@ -124,6 +136,8 @@ class LLMBaseNNFactory(NNFactory):
                   v_bias=None):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
+        groupsize = hidden_size // self.n_splits_linear
+        if self.n_splits_linear == 1:
             query_states = self.linear(
                 hidden_states,
                 num_heads * head_dim,
@@ -131,8 +145,7 @@ class LLMBaseNNFactory(NNFactory):
                 bias=False,
                 wt_dtype=self.dtype,
             )
-        if q_bias is not None:
-            query_states = query_states + q_bias
             key_states = self.linear(
                 hidden_states,
                 num_key_value_heads * head_dim,
@@ -140,8 +153,7 @@ class LLMBaseNNFactory(NNFactory):
                 bias=False,
                 wt_dtype=self.dtype,
             )
-        if k_bias is not None:
-            key_states = key_states + k_bias
             value_states = self.linear(
                 hidden_states,
                 num_key_value_heads * head_dim,
@@ -149,6 +161,67 @@ class LLMBaseNNFactory(NNFactory):
                 bias=False,
                 wt_dtype=self.dtype,
             )
+        else:
+            hidden_states = self.unsqueeze(hidden_states, axis=0)
+            if mode == "prefill":
+                query_states_to_concat = []
+                key_states_to_concat = []
+                value_states_to_concat = []
+                for i in range(self.n_splits_linear):
+                    sub_hidden_states = self.slice(hidden_states,
+                                                   begin=[0, 0, i * groupsize],
+                                                   end=[1, seq_len, (i + 1) * groupsize])
+                    query_states_to_concat.append(
+                        self.linear(
+                            sub_hidden_states,
+                            num_heads * head_dim,
+                            groupsize,
+                            bias=False,
+                            wt_dtype=self.dtype,
+                            scale_factor=(self.group_size == 0)
+                        )
+                    )
+                    key_states_to_concat.append(
+                        self.linear(
+                            sub_hidden_states,
+                            num_key_value_heads * head_dim,
+                            groupsize,
+                            bias=False,
+                            wt_dtype=self.dtype,
+                            scale_factor=(self.group_size == 0)
+                        )
+                    )
+                    value_states_to_concat.append(
+                        self.linear(
+                            sub_hidden_states,
+                            num_key_value_heads * head_dim,
+                            groupsize,
+                            bias=False,
+                            wt_dtype=self.dtype,
+                            scale_factor=(self.group_size == 0)
+                        )
+                    )
+                query_states = sum(query_states_to_concat)
+                key_states = sum(key_states_to_concat)
+                value_states = sum(value_states_to_concat)
+            else:
+                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                    hidden_size, self.n_splits_linear,
+                                                    wt_dtype=self.dtype,
+                                                    scale_factor=(self.group_size == 0))
+                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                  hidden_size, self.n_splits_linear,
+                                                  wt_dtype=self.dtype,
+                                                  scale_factor=(self.group_size == 0))
+                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                    hidden_size, self.n_splits_linear,
+                                                    wt_dtype=self.dtype,
+                                                    scale_factor=(self.group_size == 0))
+
+        if q_bias is not None:
+            query_states = query_states + q_bias
+        if k_bias is not None:
+            key_states = key_states + k_bias
         if v_bias is not None:
             value_states = value_states + v_bias
@@ -215,23 +288,100 @@ class LLMBaseNNFactory(NNFactory):
         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])

+        if self.n_splits_linear == 1:
             attn_output = self.linear(
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
+        else:
+            if mode == "prefill":
+                attn_output_to_concat = []
+                for i in range(self.n_splits_linear):
+                    sub_attn_output = self.slice(attn_output,
+                                                 begin=[0, 0, i * groupsize],
+                                                 end=[1, seq_len, (i + 1) * groupsize])
+                    attn_output_to_concat.append(
+                        self.linear(
+                            sub_attn_output, hidden_size, groupsize, bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                attn_output = sum(attn_output_to_concat)
+            else:
+                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                                   self.n_splits_linear, wt_dtype=self.dtype,
+                                                   scale_factor=(self.group_size == 0))

         return attn_output, new_key_states, new_value_states

-    def mlp(self, hidden_states):
+    def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
+        if self.n_splits_linear == 1:
             mm1 = self.linear(
-                hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+                hidden_states, self.intermediate_size, self.hidden_size, bias=False,
+                wt_dtype=self.dtype
             )
             mm2 = self.linear(
-                hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+                hidden_states, self.intermediate_size, self.hidden_size, bias=False,
+                wt_dtype=self.dtype
             )  # type: ignore[attr-defined]
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+        else:
+            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
+            if mode == "prefill":
+                gate_up_groupsize = self.hidden_size // self.n_splits_linear
+                mm1_to_concat = []
+                mm2_to_concat = []
+                for i in range(self.n_splits_linear):
+                    sub_hidden_states = self.slice(hidden_states,
+                                                   begin=[0, 0, i * gate_up_groupsize],
+                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
+                    mm1_to_concat.append(
+                        self.linear(
+                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
+                            bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                    mm2_to_concat.append(
+                        self.linear(
+                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
+                            bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                mm1 = sum(mm1_to_concat)
+                mm2 = sum(mm2_to_concat)
+            else:
+                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                           self.n_splits_linear, wt_dtype=self.dtype,
+                                           scale_factor=(self.group_size == 0))
+                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                           self.n_splits_linear, wt_dtype=self.dtype,
+                                           scale_factor=(self.group_size == 0))
+            mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+
+        if self.n_splits_down_proj == 1:
             hidden_states = self.linear(
                 mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
             )
+        else:
+            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
+            if mode == "prefill":
+                down_groupsize = self.intermediate_size // self.n_splits_down_proj
+                hidden_states_to_concat = []
+                for i in range(self.n_splits_down_proj):
+                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
+                                         end=[1, seq_len, (i + 1) * down_groupsize])
+                    hidden_states_to_concat.append(
+                        self.linear(
+                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
+                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
+                        )
+                    )
+                hidden_states = sum(hidden_states_to_concat)
+            else:
+                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                     scale_factor=(self.group_size == 0))
         return hidden_states

     def layer_norm(self, hidden_states, layernorm_weight):
@@ -341,6 +491,19 @@ class LLMBaseNNFactory(NNFactory):
         self.linear_ops.append(op)
         return op

+    def dq_split_linear(self,
+                        input_node: ctypes._Pointer,
+                        output_channels: int,
+                        input_channels: int,
+                        n_splits: int,
+                        act_dtype: npt.DTypeLike = np.float16,
+                        wt_dtype: npt.DTypeLike = np.float16,
+                        scale_factor: bool = False):
+        op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
+                                     False, act_dtype, wt_dtype, scale_factor)
+        self.linear_ops.append(op)
+        return op
+
     def parameter(self, shape):
         invalidInputError(False,
                           ("parameter should not be called directly, "


@@ -42,7 +42,27 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from torch.nn import CrossEntropyLoss
-from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP
+from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
+from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.npu_models.common import split_linear
+
+
+def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down_proj=2):
+    attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
+    mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
+    if isinstance(module, Qwen2Attention):
+        for name in attn_module_names:
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_hidden_size))
+            delattr(module, name)
+    elif isinstance(module, Qwen2MLP):
+        for name in mlp_module_names:
+            n_splits_mlp = n_splits_hidden_size
+            if name == 'down_proj':
+                n_splits_mlp = n_splits_down_proj
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_mlp))
+            delattr(module, name)


 def split_mlp_down_proj(module: torch.nn.Module):
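A hedged example of what `split_linears` does to a module tree, run on a tiny randomly initialized Qwen2 config (the config values are arbitrary and a `transformers` version with Qwen2 support is assumed). The child names it produces are what the `*_proj_dq_list.*_proj_dq_0.bias` accesses further down rely on:

```python
import torch
from transformers import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from ipex_llm.transformers.npu_models.qwen2_mp import split_linears

cfg = Qwen2Config(hidden_size=128, intermediate_size=256, num_hidden_layers=1,
                  num_attention_heads=4, num_key_value_heads=2, vocab_size=1000)
toy = Qwen2ForCausalLM(cfg)
toy.apply(lambda m: split_linears(m, n_splits_hidden_size=2, n_splits_down_proj=2))

attn = toy.model.layers[0].self_attn
print([name for name, _ in attn.q_proj_dq_list.named_children()])
# -> ['q_proj_dq_0', 'q_proj_dq_1']; the original q_proj attribute has been removed.
```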
@@ -94,12 +114,18 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         device: str = "NPU",
         rms_norm_eps,
         intermediate_size,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -221,32 +247,9 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
             new_key_states = self.convert_to_fp16(curr_key_values[i][0])
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])

-        print("start compiling")
+        print(f"{mode} start compiling")
         self.compile()
-        print("end compiling")
-
-    def mlp(self, hidden_states, seq_len):
-        mm1 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )
-        mm2 = self.linear(
-            hidden_states, self.intermediate_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )  # type: ignore[attr-defined]
-        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-        if self.intermediate_size == 18944:
-            # for qwen2-7b
-            mm1_0 = self.slice(mm1, begin=[0, 0, 0], end=[1, seq_len, 9472])
-            mm1_1 = self.slice(mm1, begin=[0, 0, 9472], end=[1, seq_len, 18944])
-            hidden_states_0 = self.linear(mm1_0, self.hidden_size, 9472,
-                                          bias=False, wt_dtype=self.dtype)
-            hidden_states_1 = self.linear(mm1_1, self.hidden_size, 9472,
-                                          bias=False, wt_dtype=self.dtype)
-            hidden_states = hidden_states_0 + hidden_states_1
-        else:
-            hidden_states = self.linear(
-                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
-            )
-        return hidden_states
+        print(f"{mode} end compiling")

     def build_decoder(
         self,
@@ -285,7 +288,7 @@ class LowBitQwenMultiDecoderlayer(LLMBaseNNFactory):
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states, self.seq_len)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual, hidden_states)

         hidden_states = self.convert_to_fp16(hidden_states)
@@ -314,6 +317,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0,
     ):
         super().__init__()
@@ -323,6 +329,10 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         for w in parameters:
             if isinstance(w, tuple):  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
+                op_parameters.append(w.numpy())
+            elif isinstance(w, np.ndarray):  # scale
+                op_parameters.append(w)
             else:
                 op_parameters.append(w.to(torch.float16).numpy())
         self.op_parameters = op_parameters
@@ -331,6 +341,10 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16
@@ -368,6 +382,9 @@ class FusedQwenLowBitMultiDecoderlayer(torch.nn.Module):
                 mode="decode",
                 transpose_value=self.transpose_value,
                 dtype=np_dtype,
+                n_splits_linear=n_splits_linear,
+                n_splits_down_proj=n_splits_down_proj,
+                group_size=group_size
             )
             self.backend_decoders.append(decoder)
@@ -450,6 +467,9 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
         intermediate_size,
         max_seq_len: int = 128,
         transpose_value: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -478,6 +498,9 @@ class FusedQwenLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -554,6 +577,7 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -561,34 +585,56 @@ def run_decode(
     k_biases = []
     v_biases = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        if model.config.intermediate_size == 8960:
-            # for qwen2-1.5b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-            ]
-        elif model.config.intermediate_size == 18944:
-            # for qwen2-7b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
-                (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
-            ]
+        weights = []
+        if n_splits_linear == 1:
+            for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+
+            for l in attn_layer.o_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_linear == 1:
+            for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -598,9 +644,9 @@ def run_decode(
         layer_weights.extend(weights)
         input_layer_norm_weights.append(layer_norm_0)
         post_attn_layernorm_weights.append(layer_norm_1)
-        q_biases.append(attn_layer.q_proj.bias.to(torch.float16))
-        k_biases.append(attn_layer.k_proj.bias.to(torch.float16))
-        v_biases.append(attn_layer.v_proj.bias.to(torch.float16))
+        q_biases.append(attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16))
+        k_biases.append(attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16))
+        v_biases.append(attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16))

         multi_decoder = FusedQwenLowBitMultiDecoderlayer(
             parameters=layer_weights,
@@ -621,6 +667,9 @@ def run_decode(
         max_seq_len=max_seq_len,
         transpose_value=transpose_value_cache,
         do_print=False,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )

     dist.barrier()
@@ -703,11 +752,15 @@ class DecodeRunner:

         self.forward_signal = torch.tensor(0, dtype=torch.int)

+        n_layers_per_rank = num_layers // (world_size - 1)
+        if num_layers % (world_size - 1) > 0:
+            n_layers_per_rank += 1
+
         for rank in range(1, world_size):
             input_q = mp.Queue()
             output_q = mp.Queue()
-            start_layer = (rank - 1) * (num_layers // (world_size - 1))
-            end_layer = (rank) * (num_layers // (world_size - 1))
+            start_layer = (rank - 1) * n_layers_per_rank
+            end_layer = (rank) * n_layers_per_rank
             if rank == world_size - 1:
                 end_layer = num_layers
             p = mp.Process(
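Rounding `n_layers_per_rank` up balances the spans handed to the worker ranks; for example with 28 decoder layers and 5 worker ranks (the `inter_pp = 5` case enabled above for group-wise Qwen2-7B, assuming one extra driver rank):

```python
def partition(num_layers, world_size):
    # Mirrors the DecodeRunner logic above: rank 0 is the driver, ranks
    # 1..world_size-1 each own a contiguous span of decoder layers.
    n_layers_per_rank = num_layers // (world_size - 1)
    if num_layers % (world_size - 1) > 0:
        n_layers_per_rank += 1
    spans = []
    for rank in range(1, world_size):
        start = (rank - 1) * n_layers_per_rank
        end = num_layers if rank == world_size - 1 else rank * n_layers_per_rank
        spans.append((start, end))
    return spans

print(partition(28, 6))   # [(0, 6), (6, 12), (12, 18), (18, 24), (24, 28)]
# With the old floor division the spans would have been 5/5/5/5/8 layers.
```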
@@ -787,39 +840,34 @@ def run_prefill(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        if model.config.intermediate_size == 8960:
-            # for qwen2-1.5b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-            ]
-        elif model.config.intermediate_size == 18944:
-            # for qwen2-7b
-            weights = [
-                (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-                (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-                (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-                (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-                (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-                (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-                (mlp_layer.down_proj_0.weight, mlp_layer.down_proj_0.scale),
-                (mlp_layer.down_proj_1.weight, mlp_layer.down_proj_1.scale)
-            ]
+        weights = []
+
+        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list):
+            weights.append((q.weight, q.scale))
+            weights.append((k.weight, k.scale))
+            weights.append((v.weight, v.scale))
+
+        for l in attn_layer.o_proj_dq_list:
+            weights.append((l.weight, l.scale))
+        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
+            weights.append((g.weight, g.scale))
+            weights.append((u.weight, u.scale))
+        for l in mlp_layer.down_proj_dq_list:
+            weights.append((l.weight, l.scale))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -835,14 +883,17 @@ def run_prefill(
             cached_sin=cached_sin,
             layer_norm_0=layer_norm_0,
             layer_norm_1=layer_norm_1,
-            q_bias=attn_layer.q_proj.bias.to(torch.float16),
-            k_bias=attn_layer.k_proj.bias.to(torch.float16),
-            v_bias=attn_layer.v_proj.bias.to(torch.float16),
+            q_bias=attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16),
+            k_bias=attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16),
+            v_bias=attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16),
             layer_idx=layer_idx,
             rms_norm_eps=rms_norm_eps,
             intermediate_size=intermediate_size,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )

         layer_weights.extend(weights)