From 51bcac12299a3271e9cb0fc1577c45edd4c0643a Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Tue, 13 Aug 2024 18:53:55 -0700 Subject: [PATCH] follow up on experimental support of fused decoder layer for llama2 (#11785) * clean up and support transpose value cache * refine * fix style * fix style --- .../HF-Transformers-AutoModels/LLM/llama2.py | 502 ++++++++---------- .../src/ipex_llm/transformers/npu_model.py | 2 +- .../ipex_llm/transformers/npu_models/kv.py | 124 +++-- 3 files changed, 301 insertions(+), 327 deletions(-) diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py index 55749fcf..9a384fe7 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/llama2.py @@ -15,7 +15,7 @@ # import os -os.environ["OMP_NUM_THREADS"] = "4" +os.environ["OMP_NUM_THREADS"] = "8" os.environ["IPEX_LLM_LAST_LM_HEAD"] = "1" import torch import time @@ -40,6 +40,7 @@ from functools import partial import torch.nn.functional as F import torch.nn.parallel import torch.distributed as dist +from filelock import FileLock from transformers.utils import logging logger = logging.get_logger(__name__) @@ -116,164 +117,12 @@ def run_model( return results -class LowBitLlamaDecoderlayer(NNFactory): - def __init__( - self, - hidden_shape: Sequence[int], - attenion_mask_shape=None, - position_id_shape=None, - past_key_shape=None, - past_value_shape=None, - input_layernorm_shape=None, - post_layernorm_shape=None, - *, - num_heads: int, - num_key_value_heads: int, - cached_cos, - cached_sin, - mode: str = "prefill", - dtype: np.dtype = np.int8, - max_seq_len: int = 128, - profile: bool = False, - device: str = "NPU", - rms_norm_eps, - intermediate_size, - **additional_args - ): - super().__init__(profile, device) - self.max_seq_len = max_seq_len - self.intermediate_size = intermediate_size - eps = self.constant(rms_norm_eps) - - self.batch_size, self.seq_len, self.hidden_size = hidden_shape - - if mode == "decode": - invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode") - self.num_heads = num_heads - self.num_key_value_heads = num_key_value_heads - - self.head_dim = self.hidden_size // self.num_heads - - # define input, the order self.parameter matters - input = self.parameter((self.batch_size, self.seq_len, self.hidden_size)) - - # Self Attention - if mode == "decode": - attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1)) - else: - attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len)) - - position_ids = self.parameter((self.batch_size, self.seq_len)) - - input_layernorm_weight = self.parameter((1, self.hidden_size,)) - post_attention_layernorm_weight = self.parameter((1, self.hidden_size,)) - - if mode == "decode": - past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) - past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) - - residual = input - - input_2d = self.reshape(input, (self.batch_size * self.seq_len, self.hidden_size)) - - # input_layernorm forward - input_2d = self.convert_to_fp32(input_2d) - variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True) - input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps))) - input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight) - 
input_2d = self.eltwise_mul(input_layernorm_weight, input_2d) - input_2d = self.convert_to_fp16(input_2d) - - query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype) - key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype) - value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=dtype) - - cos = self.constant(cached_cos) - cos = self.unsqueeze(cos, axis=0) - - sin = self.constant(cached_sin) - sin = self.unsqueeze(sin, axis=0) - - query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) - key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) - value_states = self.reshape(value_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) - - query_states = self.transpose(query_states, [0, 2, 1, 3]) - key_states = self.transpose(key_states, [0, 2, 1, 3]) - value_states = self.transpose(value_states, [0, 2, 1, 3]) - - query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - new_key_states = key_states - new_value_states = value_states - - invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads") - - if mode == "decode": - key_states = self.concat(past_key, key_states, axis=-2) - value_states = self.concat(past_value, value_states, axis=-2) - - attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) - attn_weight = self.eltwise_add(attn_weight, attention_mask) - attn_weight = self.convert_to_fp32(attn_weight) - attn_weight = self.softmax(attn_weight, -1) - attn_weight = self.convert_to_fp16(attn_weight) - attn_output = self.matmul(attn_weight, value_states, False, False) - - attn_output = self.transpose(attn_output, [0, 2, 1, 3]) - attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size]) - - attn_output = self.linear(attn_output, self.hidden_size, self.hidden_size, bias=False, wt_dtype=dtype) - - hidden_states = self.eltwise_add(residual, attn_output) - - # Fully Connected - residual = hidden_states - hidden_states = self.convert_to_fp32(hidden_states) - variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True) - hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) - post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight) - hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states) - hidden_states = self.convert_to_fp16(hidden_states) - - # mlp - mm1 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, - bias=False, wt_dtype=dtype) - mm2 = self.linear(hidden_states, self.intermediate_size, self.hidden_size, - bias=False, wt_dtype=dtype) # type: ignore[attr-defined] - mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined] - - hidden_states = self.linear(mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=dtype) - - hidden_states = self.eltwise_add(residual, hidden_states) - hidden_states = self.convert_to_fp16(hidden_states) - - # hacking to add key, value to outputs - new_key_states = self.convert_to_fp16(new_key_states) - new_value_states = self.convert_to_fp16(new_value_states) - - self.compile() - - def 
rotate_half(self, x): - x1 = self.slice(x, [0, 0, 0, 0], [self.batch_size, self.num_heads, self.seq_len, self.head_dim//2], ) - x2 = self.slice(x, [0, 0, 0, self.head_dim//2], [self.batch_size, self.num_heads, self.seq_len, self.head_dim]) - return self.concat(self.negative(x2), x1, axis=-1) - - def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids): - position_ids = self.squeeze(position_ids) - cos = self.gather(cos, self.convert_to_int32(position_ids), self.constant(1), 0) - sin = self.gather(sin, self.convert_to_int32(position_ids), self.constant(1), 0) - cos = self.unsqueeze(cos, [1]) - sin = self.unsqueeze(sin, [1]) - - q_embed = self.eltwise_add(self.eltwise_mul(q, cos), self.eltwise_mul(self.rotate_half(q), sin)) - k_embed = self.eltwise_add(self.eltwise_mul(k, cos), self.eltwise_mul(self.rotate_half(k), sin)) - - return q_embed, k_embed - - class LowBitLlamaMultiDecoderlayer(NNFactory): def __init__( self, + # batch_size: int, + # seq_len: int, + # hidden_size: int, hidden_shape: Sequence[int], *shapes, num_heads: int, @@ -281,16 +130,16 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): num_layers: int, cached_cos, cached_sin, - input_layernorm_weights, - post_attn_layernorm_weights, + input_layernorm_weights=None, + post_attn_layernorm_weights=None, mode: str = "prefill", dtype: np.dtype = np.int8, - max_seq_len: int = 128, + max_seq_len: int = 1024, + transpose_value: bool = False, profile: bool = False, device: str = "NPU", rms_norm_eps, intermediate_size, - **additional_args ): super().__init__(profile, device) self.max_seq_len = max_seq_len @@ -301,6 +150,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): self.batch_size, self.seq_len, self.hidden_size = hidden_shape self.mode = mode self.rms_norm_eps = rms_norm_eps + self.transpose_value = transpose_value cos = self.constant(self.cached_cos) self.cos = self.unsqueeze(cos, axis=0) @@ -309,11 +159,16 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): self.sin = self.unsqueeze(sin, axis=0) if mode == "decode": - invalidInputError(self.seq_len == 1, "seq_len must be 1 for decode mode") + assert self.seq_len == 1, "seq_len must be 1 for decode mode" + self.kv_seq_len = self.max_seq_len + 1 + else: + self.kv_seq_len = self.seq_len + self.num_heads = num_heads self.num_key_value_heads = num_key_value_heads self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads # define input, the order self.parameter matters input = self.parameter((self.batch_size, self.seq_len, self.hidden_size)) @@ -323,22 +178,35 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): attention_mask = self.parameter((self.batch_size, 1, 1, self.max_seq_len + 1)) else: attention_mask = self.parameter((self.batch_size, 1, self.seq_len, self.seq_len)) - + + position_ids = self.parameter((self.batch_size, self.seq_len)) past_keys = [] past_values = [] if mode == "decode": for i in range(num_layers): past_key = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) - past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) + if transpose_value: + past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)) + else: + past_value = self.parameter((self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)) past_keys.append(past_key) past_values.append(past_value) else: - past_key = None - past_value = None + past_keys = [None] * num_layers + past_values = 
[None] * num_layers + + if input_layernorm_weights is None: + assert post_attn_layernorm_weights is None + input_layernorm_weights = [] + post_attn_layernorm_weights = [] + for i in range(num_layers): + input_layernorm_weights.append(self.parameter((1, self.hidden_size,))) + post_attn_layernorm_weights.append(self.parameter((1, self.hidden_size,))) + else: + input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights] + post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights] - # input_layernorm_weight = self.parameter((1, self.hidden_size,)) - # post_attention_layernorm_weight = self.parameter((1, self.hidden_size,)) hidden_states = input curr_key_values = [] @@ -352,6 +220,7 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): past_value=past_values[i],) curr_key_values.append((new_key_states, new_value_states)) + # define outputs hidden_states = self.convert_to_fp16(hidden_states) @@ -359,7 +228,22 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): new_key_states = self.convert_to_fp16(curr_key_values[i][0]) new_value_states = self.convert_to_fp16(curr_key_values[i][1]) - self.compile() + with FileLock("decoder_compile.lock"): + print("start compiling") + self.compile() + + def repeat_kv(self, hidden_states, n_rep, transpose=False): + if n_rep == 1: + return hidden_states + if not transpose: + hidden_states = self.reshape(hidden_states, [self.batch_size, self.num_key_value_heads, 1, self.kv_seq_len, self.head_dim]) + hidden_states = self.broadcast(hidden_states, [self.batch_size, self.num_key_value_heads, n_rep, self.kv_seq_len, self.head_dim]) + hidden_states = self.reshape(hidden_states, [self.batch_size, n_rep*self.num_key_value_heads, self.kv_seq_len, self.head_dim]) + else: + hidden_states = self.reshape(hidden_states, [self.batch_size, self.num_key_value_heads, 1, self.head_dim, self.kv_seq_len]) + hidden_states = self.broadcast(hidden_states, [self.batch_size, self.num_key_value_heads, n_rep, self.head_dim, self.kv_seq_len]) + hidden_states = self.reshape(hidden_states, [self.batch_size, n_rep*self.num_key_value_heads, self.head_dim, self.kv_seq_len]) + return hidden_states def build_decoder(self, hidden_states, attention_mask, position_ids, input_layernorm_weight, post_attention_layernorm_weight, @@ -372,10 +256,11 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): # input layernorm input_2d = self.convert_to_fp32(input_2d) + # variance = self.reduce_mean(self.eltwise_mul(input_2d, input_2d), -1, keep_dims=True) variance = self.reduce_mean(self.power(input_2d, self.constant(np.array([[2]], dtype=np.float32))), -1, keep_dims=True) eps = self.constant(self.rms_norm_eps) input_2d = self.eltwise_div(input_2d, self.sqrt(self.eltwise_add(variance, eps))) - input_layernorm_weight = self.constant(input_layernorm_weight) + # input_layernorm_weight = self.constant(input_layernorm_weight) input_layernorm_weight = self.convert_to_fp32(input_layernorm_weight) input_2d = self.eltwise_mul(input_layernorm_weight, input_2d) input_2d = self.convert_to_fp16(input_2d) @@ -384,6 +269,12 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): query_states = self.linear(input_2d, self.num_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) key_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) value_states = self.linear(input_2d, self.num_key_value_heads*self.head_dim, self.hidden_size, bias=False, wt_dtype=self.dtype) + + # cos = self.constant(self.cached_cos) + # cos = 
self.unsqueeze(cos, axis=0) + + # sin = self.constant(self.cached_sin) + # sin = self.unsqueeze(sin, axis=0) query_states = self.reshape(query_states, [self.batch_size, self.seq_len, self.num_heads, self.head_dim]) key_states = self.reshape(key_states, [self.batch_size, self.seq_len, self.num_key_value_heads, self.head_dim]) @@ -391,27 +282,35 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): query_states = self.transpose(query_states, [0, 2, 1, 3]) key_states = self.transpose(key_states, [0, 2, 1, 3]) - value_states = self.transpose(value_states, [0, 2, 1, 3]) + if self.transpose_value: + value_states = self.transpose(value_states, [0, 2, 3, 1]) + else: + value_states = self.transpose(value_states, [0, 2, 1, 3]) query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, self.cos, self.sin, position_ids) new_key_states = key_states new_value_states = value_states - # repeat_kv cannot be implemented because Broadcast op is needed - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - invalidInputError(self.num_heads == self.num_key_value_heads, "num_heads must be equal to num_key_value_heads") + if self.mode == "decode": key_states = self.concat(past_key, key_states, axis=-2) - value_states = self.concat(past_value, value_states, axis=-2) + if self.transpose_value: + value_states = self.concat(past_value, value_states, axis=-1) + else: + value_states = self.concat(past_value, value_states, axis=-2) + + # repeat_kv cannot be implemented because Broadcast op is needed + key_states = self.repeat_kv(key_states, self.num_key_value_groups) + value_states = self.repeat_kv(value_states, self.num_key_value_groups, self.transpose_value) attn_weight = self.matmul(query_states, key_states, False, True) / (math.sqrt(self.head_dim)) attn_weight = self.eltwise_add(attn_weight, attention_mask) attn_weight = self.convert_to_fp32(attn_weight) attn_weight = self.softmax(attn_weight, -1) attn_weight = self.convert_to_fp16(attn_weight) - attn_output = self.matmul(attn_weight, value_states, False, False) + attn_output = self.matmul(attn_weight, value_states, False, self.transpose_value) + attn_output = self.transpose(attn_output, [0, 2, 1, 3]) attn_output = self.reshape(attn_output, [self.batch_size, self.seq_len, self.hidden_size]) @@ -422,10 +321,12 @@ class LowBitLlamaMultiDecoderlayer(NNFactory): # Fully Connected residual = hidden_states + # post_attention_layernorm forward + hidden_states = self.convert_to_fp32(hidden_states) variance = self.reduce_mean(self.power(hidden_states, self.constant(np.array([[[2]]], dtype=np.float32))), -1, keep_dims=True) hidden_states = self.eltwise_div(hidden_states, self.sqrt(self.eltwise_add(variance, eps))) - post_attention_layernorm_weight = self.constant(post_attention_layernorm_weight) + # post_attention_layernorm_weight = self.constant(post_attention_layernorm_weight) post_attention_layernorm_weight = self.convert_to_fp32(post_attention_layernorm_weight) hidden_states = self.eltwise_mul(post_attention_layernorm_weight, hidden_states) hidden_states = self.convert_to_fp16(hidden_states) @@ -472,12 +373,17 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): layer_indexes : List[int], cached_cos, cached_sin, + # rotary_emb, + # batch_size: int, + # seq_len: int, + # hidden_size: int, num_heads: int, head_dim: int, num_key_value_heads: int, rms_norm_eps, intermediate_size, - max_seq_len: int = 128, + max_seq_len: int = 1024, + transpose_value: bool = False, ): 
super().__init__() @@ -491,38 +397,74 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): self.op_id = str(uuid.uuid4()) # self.layer_idx = layer_idx self.max_seq_len = max_seq_len + self.transpose_value = transpose_value # self.rotary_emb = rotary_emb if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 + assert np_dtype == np.uint8 + assert parameters[0][1].dtype == torch.float16, parameters[0] else: # FP16 Linear - invalidInputError(False, "Please use int4 optimization") + assert False, "should not be here" + np_dtype = np.float16 self.layer_indexes = layer_indexes + self.num_layers_1 = len(self.layer_indexes) // 2 + self.num_layers_0 = len(self.layer_indexes) - self.num_layers_1 + + assert self.num_layers_1 + self.num_layers_0 == len(input_laynorm_weights) + assert self.num_layers_1 + self.num_layers_0 == len(post_attn_layernorm_weights) print("create dedcoder layer") - self.backend_cls_decode = LowBitLlamaMultiDecoderlayer([1, 1, num_heads*head_dim], - input_layernorm_weights=input_laynorm_weights, - post_attn_layernorm_weights=post_attn_layernorm_weights, + self.backend_cls_decode_0 = LowBitLlamaMultiDecoderlayer([1, 1, num_heads*head_dim], + input_layernorm_weights=input_laynorm_weights[:self.num_layers_0], + post_attn_layernorm_weights=post_attn_layernorm_weights[:self.num_layers_0], cached_cos=cached_cos, cached_sin=cached_sin, num_heads=num_heads, num_key_value_heads=num_key_value_heads, - num_layers=len(layer_indexes), + num_layers=self.num_layers_0, max_seq_len=max_seq_len, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, mode="decode", + transpose_value=self.transpose_value, + dtype=np_dtype) + self.backend_cls_decode_1 = LowBitLlamaMultiDecoderlayer([1, 1, num_heads*head_dim], + input_layernorm_weights=input_laynorm_weights[self.num_layers_0:], + post_attn_layernorm_weights=post_attn_layernorm_weights[self.num_layers_0:], + cached_cos=cached_cos, + cached_sin=cached_sin, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + num_layers=self.num_layers_1, + max_seq_len=max_seq_len, + rms_norm_eps=rms_norm_eps, + intermediate_size=intermediate_size, + mode="decode", + transpose_value=self.transpose_value, dtype=np_dtype) print("created dedcoder layer") + + assert (self.num_layers_0 + self.num_layers_1) * 7 == len(op_parameters) - self.backend_cls_decode.setWeights(3+len(layer_indexes)*2, self.op_id, *op_parameters) - print("weight setted") - backend_lib.run(self.backend_cls_decode._mm,) + self.backend_cls_decode_0.setWeights(3+self.num_layers_0*2, self.op_id, *op_parameters[:self.num_layers_0*7]) + backend_lib.run(self.backend_cls_decode_0._mm) + print("first inference done") - self.kv_cache_c_parameter_handel = None + + self.backend_cls_decode_1.setWeights(3+self.num_layers_1*2, self.op_id, *op_parameters[self.num_layers_0*7:]) + + + print("weight setted") + backend_lib.run(self.backend_cls_decode_1._mm) + + print("2nd inference done") + + self.kv_cache_c_parameter_handel = (None, None) self.kv_cache_parameters = None self.kv_cache_prefetched = False + def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -541,8 +483,6 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): torch.Tensor: result """ seq_len = hidden_states.shape[1] - backend_cls = self.backend_cls_decode - pad_len = self.max_seq_len + 1 - attention_mask.size(-1) pad_mask = (0, pad_len) @@ -550,8 +490,9 @@ class 
FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): value=torch.finfo(torch.float16).min) padded_attention_mask[:,:,:,-1] = 0.0 inputs = (hidden_states.to(torch.float16), - padded_attention_mask, - position_ids,) + padded_attention_mask, + position_ids, + ) if self.kv_cache_parameters is None: self.kv_cache_parameters = [] @@ -562,56 +503,76 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module): cached_prt = self.kv_cache_parameters[0].storage().data_ptr() current_ptr = past_key_value.key_cache[self.layer_indexes[0]].storage().data_ptr() if cached_prt != current_ptr: + # print("kv cache changed") self.kv_cache_parameters = [] - self.kv_cache_c_parameter_handel = None + self.kv_cache_c_parameter_handel = (None, None) self.kv_cache_prefetched = False if len(self.kv_cache_parameters) == 0: for idx in self.layer_indexes: past_key = past_key_value.key_cache[idx] past_value = past_key_value.value_cache[idx] + + assert past_key.dtype == torch.float16, f"past_key dtype is {past_key.dtype}" + new_size = (past_key.size(0), past_key.size(1), self.max_seq_len, past_key.size(3)) past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0) + assert past_key.is_contiguous() past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0) + if self.transpose_value: + past_value = past_value.transpose(-1, -2) + assert past_value.is_contiguous() self.kv_cache_parameters.append(past_key) self.kv_cache_parameters.append(past_value) - self.kv_cache_c_parameter_handel = self.backend_cls_decode.create_parameters([p.numpy() for p in self.kv_cache_parameters]) + handle_0 = self.backend_cls_decode_0.create_parameters([p.numpy() for p in self.kv_cache_parameters[:self.num_layers_0*2]]) + handle_1 = self.backend_cls_decode_1.create_parameters([p.numpy() for p in self.kv_cache_parameters[self.num_layers_0*2:]]) + assert len(self.kv_cache_parameters) == (self.num_layers_0 + self.num_layers_1) * 2 + self.kv_cache_c_parameter_handel = (handle_0, handle_1) x_np = [elem.to(torch.float16).numpy() for elem in inputs] + key_value_states = [] + with record_function(f"npu_factory"): if not self.kv_cache_prefetched: - self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handel) + self.backend_cls_decode_0.load_wt_fn(len(inputs), self.backend_cls_decode_0._mm, self.kv_cache_c_parameter_handel[0]) + self.backend_cls_decode_1.load_wt_fn(len(inputs), self.backend_cls_decode_1._mm, self.kv_cache_c_parameter_handel[1]) - for idx, elem in enumerate(x_np): - self.backend_cls_decode.set_input_tensor(elem, idx) + models_ptr = (ctypes.POINTER(ctypes.c_char) * 2)(self.backend_cls_decode_0._mm, self.backend_cls_decode_1._mm) + inputs_ptr = (ctypes.c_void_p * 3)(x_np[0].ctypes.data_as(ctypes.c_void_p), x_np[1].ctypes.data_as(ctypes.c_void_p), x_np[2].ctypes.data_as(ctypes.c_void_p)) - backend_lib.run(self.backend_cls_decode._mm,) - ret = self.backend_cls_decode.out - results = [adapt_output_tensor(r, r.shape, torch.float16) for r in ret] + backend_lib.run_decoders(models_ptr, inputs_ptr, 2, 3) - hidden_states = results[0] - key_value_states = results[1:] + for i in range(1, len(self.backend_cls_decode_0.torch_out)): + key_value_states.append(self.backend_cls_decode_0.torch_out[i]) + + for i in range(1, len(self.backend_cls_decode_1.torch_out)): + key_value_states.append(self.backend_cls_decode_1.torch_out[i]) - cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len} + hidden_states = self.backend_cls_decode_1.torch_out[0] + + 
cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len, "transpose": self.transpose_value} for i in range(len(self.layer_indexes)): key_states, value_states = past_key_value.update(key_value_states[2*i], key_value_states[2*i+1], self.layer_indexes[i], cache_kwargs) - self.backend_cls_decode.load_wt_fn(len(inputs), self.backend_cls_decode._mm, self.kv_cache_c_parameter_handel) + self.backend_cls_decode_0.load_wt_fn(len(inputs), self.backend_cls_decode_0._mm, self.kv_cache_c_parameter_handel[0]) + self.backend_cls_decode_1.load_wt_fn(len(inputs), self.backend_cls_decode_1._mm, self.kv_cache_c_parameter_handel[1]) self.kv_cache_prefetched = True + outputs = (hidden_states,) outputs += (past_key_value,) - return outputs class FusedLlamaLowBitDecoderlayer(torch.nn.Module): + """LLAMA MLP operation NPU backend.""" + def __init__( self, parameters: List[torch.Tensor], @@ -625,41 +586,36 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): rms_norm_eps, intermediate_size, max_seq_len: int = 128, + transpose_value: bool = False, ): super().__init__() self.op_parameters = parameters self.op_id = str(uuid.uuid4()) self.layer_idx = layer_idx self.max_seq_len = max_seq_len + self.transpose_value = transpose_value # self.rotary_emb = rotary_emb if isinstance(parameters[0], tuple): # weight, scale from QuantizedLinear np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8 else: # FP16 Linear np_dtype = np.float16 - self.backend_cls_prefill = partial(LowBitLlamaDecoderlayer, - cached_cos=cached_cos, - cached_sin=cached_sin, + self.backend_cls_prefill = partial(LowBitLlamaMultiDecoderlayer, num_heads=num_heads, num_key_value_heads=num_key_value_heads, + num_layers=1, + cached_cos=cached_cos, + cached_sin=cached_sin, + input_layernorm_weights=None, + post_attn_layernorm_weights=None, max_seq_len=max_seq_len, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, mode="prefill", + transpose_value=self.transpose_value, dtype=np_dtype) - self.backend_cls_decode = partial(LowBitLlamaDecoderlayer, - cached_cos=cached_cos, - cached_sin=cached_sin, - num_heads=num_heads, - num_key_value_heads=num_key_value_heads, - max_seq_len=max_seq_len, - rms_norm_eps=rms_norm_eps, - intermediate_size=intermediate_size, - mode="decode", - dtype=np_dtype) self.layer_norm_0 = layer_norm_0 self.layer_norm_1 = layer_norm_1 - def forward(self, hidden_states: torch.Tensor, @@ -670,43 +626,28 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module): use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs,) -> torch.Tensor: + """Torch module forward method. 
+ + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: result + """ + assert not output_attentions + # assert cache_position is None + # assert use_cache + seq_len = hidden_states.shape[1] - # cos, sin = self.rotary_emb(hidden_states, position_ids) - if seq_len == 1: - backend_cls = self.backend_cls_decode - past_key = past_key_value.key_cache[self.layer_idx] - past_value = past_key_value.value_cache[self.layer_idx] - new_size = (past_key.size(0), - past_key.size(1), - self.max_seq_len, - past_key.size(3)) - past_key = past_key.as_strided(new_size, past_key.stride(), storage_offset=0) - past_value = past_value.as_strided(new_size, past_value.stride(), storage_offset=0) - - pad_len = self.max_seq_len + 1 - attention_mask.size(-1) - - pad_mask = (0, pad_len) - padded_attention_mask = F.pad(attention_mask.to(torch.float16), pad_mask, - value=torch.finfo(torch.float16).min) - padded_attention_mask[:,:,:,-1] = 0.0 - inputs = (hidden_states.to(torch.float16), - padded_attention_mask, - position_ids,) - - inputs += (self.layer_norm_0, self.layer_norm_1) - - inputs += (past_key, past_value) - hidden_states, new_key, new_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=4) - cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len} - key_states, value_states = past_key_value.update(new_key, new_value, self.layer_idx, cache_kwargs) - else: - backend_cls = self.backend_cls_prefill - inputs = (hidden_states.to(torch.float16), attention_mask, position_ids) - inputs += (self.layer_norm_0, self.layer_norm_1) - hidden_states, past_key, past_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=1) - cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len} - key_states, value_states = past_key_value.update(past_key, past_value, self.layer_idx, cache_kwargs) + backend_cls = self.backend_cls_prefill + inputs = (hidden_states.to(torch.float16), attention_mask, position_ids) + inputs += (self.layer_norm_0, self.layer_norm_1) + # print("start run_model prefill") + hidden_states, past_key, past_value = run_model(inputs, self.op_parameters, backend_cls, self.op_id, replica=1) + # print("end run model prefill") + cache_kwargs = {"cache_position": cache_position, "max_seq_len":self.max_seq_len, "transpose": self.transpose_value} + key_states, value_states = past_key_value.update(past_key, past_value, self.layer_idx, cache_kwargs) outputs = (hidden_states,) outputs += (past_key_value,) @@ -722,44 +663,36 @@ if __name__ == "__main__": help='Prompt to infer') parser.add_argument('--n-predict', type=int, default=32, help='Max tokens to predict') + parser.add_argument('--max-seq-len', type=int, default=1024) + parser.add_argument('--transpose-value-cache', action="store_true", default=False) args = parser.parse_args() model_path = args.repo_id_or_model_path tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - pipeline = True # default - max_seq_len = 1024 # default - if pipeline: - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29501' + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29501' - dist.init_process_group() - my_rank = dist.get_rank() - my_size = dist.get_world_size() - logger.info(f"rank: {my_rank}, size: {my_size}") + dist.init_process_group() + my_rank = dist.get_rank() + my_size = dist.get_world_size() + logger.info(f"rank: {my_rank}, size: {my_size}") - model = AutoModelForCausalLM.from_pretrained(model_path, 
trust_remote_code=True, attn_implementation="eager", - load_in_low_bit="sym_int4", pipeline_parallel_stages=2) + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, + trust_remote_code=True, attn_implementation="eager", + load_in_low_bit="sym_int4", pipeline_parallel_stages=2) - if my_rank == 0: - print(model) - dist.barrier() + if my_rank == 0: + print(model) + dist.barrier() - if my_rank == 1: - print(model) - else: - model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, attn_implementation="eager", - load_in_low_bit="sym_int4") + if my_rank == 1: + print(model) - if pipeline: - layer_start = model.layer_start - layer_end = model.layer_end - num_layers = model.num_layers - else: - layer_start = 0 - layer_end = 32 - num_layers = 32 + layer_start = model.layer_start + layer_end = model.layer_end + num_layers = model.num_layers num_heads = model.model.layers[layer_start].self_attn.num_heads num_key_value_heads = model.model.layers[layer_start].self_attn.num_key_value_heads head_dim = model.model.layers[layer_start].self_attn.head_dim @@ -776,12 +709,10 @@ if __name__ == "__main__": mlp_layer = curr_layer.mlp weights = [ - # model.model.layers[i].input_layernorm.weight.to(torch.float16), (attn_layer.q_proj.weight, attn_layer.q_proj.scale), (attn_layer.k_proj.weight, attn_layer.k_proj.scale), (attn_layer.v_proj.weight, attn_layer.v_proj.scale), (attn_layer.o_proj.weight, attn_layer.o_proj.scale), - # model.model.layers[i].post_attention_layernorm.weight.to(torch.float16), (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale), (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale), (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale)] @@ -797,13 +728,13 @@ if __name__ == "__main__": num_key_value_heads=num_key_value_heads, cached_cos=cached_cos, cached_sin=cached_sin, - # rotary_emb=model.model.layers[i].self_attn.rotary_emb, layer_norm_0=layer_norm_0, layer_norm_1=layer_norm_1, layer_idx=layer_idx, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, - max_seq_len=max_seq_len) + max_seq_len=args.max_seq_len, + transpose_value=args.transpose_value_cache) layer_weights.extend(weights) input_layer_norm_weights.append(layer_norm_0) @@ -822,7 +753,8 @@ if __name__ == "__main__": num_key_value_heads=num_key_value_heads, rms_norm_eps=rms_norm_eps, intermediate_size=intermediate_size, - max_seq_len=max_seq_len, + max_seq_len=args.max_seq_len, + transpose_value=args.transpose_value_cache ) model.model.multi_decoder = multi_decoder diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 444d55ce..d0df25af 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -86,7 +86,7 @@ class _BaseAutoModelClass: if kwargs.get('torch_dtype', None) not in [None, 'auto', torch.float, torch.float16]: warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used") - kwargs['torch_dtype'] = torch.float + kwargs['torch_dtype'] = torch.float32 low_bit = kwargs.pop('load_in_low_bit', 'sym_int4') qtype_map = { diff --git a/python/llm/src/ipex_llm/transformers/npu_models/kv.py b/python/llm/src/ipex_llm/transformers/npu_models/kv.py index ce5b29ee..36260e8e 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/kv.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/kv.py @@ -18,48 +18,88 @@ import torch from typing import Optional, Dict, Tuple, Any from transformers.cache_utils import DynamicCache 
+import sys -def init_fused_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device): - key_cache_storage = torch.zeros(batch_size, num_heads, - max_length, head_dim, - dtype=dtype, device=device) - value_cache_storage = torch.zeros(batch_size, num_heads, - max_length, head_dim, - dtype=dtype, device=device) +def init_fused_kv_cache(batch_size, num_heads, head_dim, + current_length, max_length, dtype, + device, tranpose_value=False): + if not tranpose_value: + key_cache_storage = torch.zeros(batch_size, num_heads, + max_length, head_dim, + dtype=dtype, device=device) + value_cache_storage = torch.zeros(batch_size, num_heads, + max_length, head_dim, + dtype=dtype, device=device) - key_cache = key_cache_storage.as_strided((batch_size, num_heads, - current_length, head_dim), - key_cache_storage.stride(), - storage_offset=0) - value_cache = value_cache_storage.as_strided((batch_size, num_heads, - current_length, head_dim), - value_cache_storage.stride(), + key_cache = key_cache_storage.as_strided((batch_size, num_heads, + current_length, head_dim), + key_cache_storage.stride(), storage_offset=0) - return key_cache, value_cache + value_cache = value_cache_storage.as_strided((batch_size, num_heads, + current_length, head_dim), + value_cache_storage.stride(), + storage_offset=0) + return key_cache, value_cache + else: + key_cache_storage = torch.zeros(batch_size, num_heads, + max_length, head_dim, + dtype=dtype, device=device) + value_cache_storage = torch.zeros(batch_size, num_heads, + head_dim, max_length, + dtype=dtype, device=device) + + key_cache = key_cache_storage.as_strided((batch_size, num_heads, + current_length, head_dim), + key_cache_storage.stride(), + storage_offset=0) + value_cache = value_cache_storage.as_strided((batch_size, num_heads, + head_dim, current_length), + value_cache_storage.stride(), + storage_offset=0) + return key_cache, value_cache.transpose(-1, -2) -def append_fused_kv_cache(cache_k, cache_v, key_states, value_states): - new_size = (cache_k.size(0), - cache_k.size(1), - cache_k.size(2) + key_states.size(2), - cache_k.size(3)) - new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0) - new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states - new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0) - new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states - return new_cache_k, new_cache_v +def append_fused_kv_cache(cache_k, cache_v, key_states, value_states, transpose_value=False): + if not transpose_value: + new_size = (cache_k.size(0), + cache_k.size(1), + cache_k.size(2) + key_states.size(2), + cache_k.size(3)) + new_cache_k = cache_k.as_strided(new_size, cache_k.stride(), storage_offset=0) + new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states + new_cache_v = cache_v.as_strided(new_size, cache_v.stride(), storage_offset=0) + new_cache_v[:, :, cache_v.size(2):cache_v.size(2) + key_states.size(2), :] = value_states + return new_cache_k, new_cache_v + else: + new_size_key = (cache_k.size(0), + cache_k.size(1), + cache_k.size(2) + key_states.size(2), + cache_k.size(3)) + new_cache_k = cache_k.as_strided(new_size_key, cache_k.stride(), storage_offset=0) + new_cache_k[:, :, cache_k.size(2):cache_k.size(2) + key_states.size(2), :] = key_states + + new_size_value = (cache_v.size(0), + cache_v.size(1), + cache_v.size(3), + cache_v.size(2) + value_states.size(3), + ) + raw_cache_v = 
cache_v.transpose(-1, -2) + new_cache_v = raw_cache_v.as_strided(new_size_value, raw_cache_v.stride(), storage_offset=0) + start = raw_cache_v.size(3) + end = raw_cache_v.size(3) + value_states.size(3) + new_cache_v[:, :, :, start:end] = value_states + return new_cache_k, new_cache_v.transpose(-1, -2) class DynamicFusedNormalCache(DynamicCache): # Experimental support for fused decoderlayer implementation on NPU # Currently only for llama2 - KV_ALLOC_BLOCK_LENGTH = 256 def __init__(self) -> None: self.key_cache: Dict[int, torch.Tensor] = {} self.value_cache: Dict[int, torch.Tensor] = {} - self._seen_tokens = 0 # Used in `generate` to keep how many tokens the cache has seen + self.min_layer_idx = sys.maxsize def update( self, @@ -71,28 +111,21 @@ class DynamicFusedNormalCache(DynamicCache): batch_size, num_heads, seq_len, head_dim = key_states.shape - max_seq_length = cache_kwargs.pop("max_seq_len", None) - transpose_value = cache_kwargs.pop("transpose_value", None) - - if layer_idx == 0 or layer_idx == 16: - if hasattr(self, "_seen_tokens"): - # 4.39 uses `_seen_tokens` - self._seen_tokens += seq_len - else: - # 4.37 uses `seen_tokens` - self.seen_tokens += seq_len + max_seq_length = cache_kwargs["max_seq_len"] if "max_seq_len" in cache_kwargs else None + transpose_value = cache_kwargs["transpose"] if "transpose" in cache_kwargs else False # Update the cache # if len(self.key_cache) <= layer_idx: if layer_idx not in self.key_cache: - max_len = max_seq_length if max_seq_length is not None else key_states.size(2) + \ - self.KV_ALLOC_BLOCK_LENGTH + max_len = max_seq_length k_cache, v_cache = init_fused_kv_cache( batch_size, num_heads, head_dim, 0, max_len, key_states.dtype, key_states.device, + tranpose_value=transpose_value, ) - k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states) + k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states, + transpose_value=transpose_value) self.key_cache[layer_idx] = k_cache self.value_cache[layer_idx] = v_cache @@ -101,7 +134,8 @@ class DynamicFusedNormalCache(DynamicCache): v_cache = self.value_cache[layer_idx] kv_seq_len = k_cache.size(2) + key_states.size(2) - k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states) + k_cache, v_cache = append_fused_kv_cache(k_cache, v_cache, key_states, value_states, + transpose_value=transpose_value) self.key_cache[layer_idx] = k_cache self.value_cache[layer_idx] = v_cache @@ -113,3 +147,11 @@ class DynamicFusedNormalCache(DynamicCache): for idx, layer in self.key_cache.items(): return layer.shape[-2] + + @property + def _seen_tokens(self): + return self.get_seq_length() + + @property + def seen_tokens(self): + return self.get_seq_length()
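

For readers following the value-cache layout change in npu_models/kv.py above, the following is a minimal, self-contained sketch in plain PyTorch of the underlying technique: pre-allocating the value cache with the sequence dimension last, shape (batch, kv_heads, head_dim, max_seq_len), and growing it in place through as_strided views. It is not the ipex-llm helpers themselves; the names init_transposed_value_cache and append_transposed_value are illustrative only.

import torch

def init_transposed_value_cache(batch, num_kv_heads, head_dim, max_len,
                                dtype=torch.float16, device="cpu"):
    # Allocate the full buffer once, with the sequence dimension last so the
    # decoder kernel can consume V without a per-token transpose.
    storage = torch.zeros(batch, num_kv_heads, head_dim, max_len,
                          dtype=dtype, device=device)
    # Expose a zero-length view over the same storage; appends grow this view.
    return storage.as_strided((batch, num_kv_heads, head_dim, 0),
                              storage.stride(), storage_offset=0)

def append_transposed_value(cache_v, value_states):
    # cache_v:      (batch, kv_heads, head_dim, cur_len) view over the buffer
    # value_states: (batch, kv_heads, head_dim, new_len), already transposed
    cur_len = cache_v.size(-1)
    new_len = value_states.size(-1)
    # Re-view the same storage with a longer last dimension, then write the
    # new tokens into the freshly exposed slots.
    grown = cache_v.as_strided((cache_v.size(0), cache_v.size(1),
                                cache_v.size(2), cur_len + new_len),
                               cache_v.stride(), storage_offset=0)
    grown[..., cur_len:cur_len + new_len] = value_states
    return grown

if __name__ == "__main__":
    cache = init_transposed_value_cache(1, 32, 128, max_len=1024)
    step = torch.randn(1, 32, 128, 1).to(torch.float16)  # one decoded token
    cache = append_transposed_value(cache, step)
    print(cache.shape)  # torch.Size([1, 32, 128, 1]), backed by the 1024-slot buffer

In the patch itself the cache is handed back through .transpose(-1, -2), so callers still see the usual (batch, heads, seq_len, head_dim) layout, and the behavior is opt-in via the new --transpose-value-cache flag (alongside --max-seq-len) added to the llama2.py example.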
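
The other notable change in the decoder-layer graph is the new repeat_kv, built from reshape/broadcast/reshape, which lifts the earlier restriction that num_heads equal num_key_value_heads. A rough PyTorch equivalent of that expansion, for illustration only (the function name and shapes here are mine, not the NPU graph ops):

import torch

def repeat_kv(states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # states: (batch, kv_heads, seq_len, head_dim), or
    #         (batch, kv_heads, head_dim, seq_len) when the value cache is
    #         stored pre-transposed; the two trailing dims are treated
    #         symmetrically, so the same expansion works for both layouts.
    if n_rep == 1:
        return states
    batch, kv_heads, dim_a, dim_b = states.shape
    # Insert a repeat axis, broadcast it, then fold it into the head axis.
    states = states.reshape(batch, kv_heads, 1, dim_a, dim_b)
    states = states.expand(batch, kv_heads, n_rep, dim_a, dim_b)
    return states.reshape(batch, kv_heads * n_rep, dim_a, dim_b)

if __name__ == "__main__":
    v = torch.randn(1, 8, 128, 17)  # 8 KV heads, head_dim 128, 17 cached tokens (transposed layout)
    print(repeat_kv(v, n_rep=4).shape)  # torch.Size([1, 32, 128, 17])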