diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index d969ac05..ca26f416 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -567,6 +567,7 @@ def _optimize_post(model, lightweight_bmm=False):
     from bigdl.llm.transformers.models.llama import llama_model_selective_batching_forward_4_31
     from bigdl.llm.transformers.models.llama import llama_rms_norm_forward
     from bigdl.llm.transformers.models.llama import llama_mlp_forward
+    from bigdl.llm.transformers.models.llama import llama_decoder_forward
     from transformers.modeling_utils import PreTrainedModel
 
     # All huggingface format models are inherited from `PreTrainedModel`
@@ -588,6 +589,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         transformers.models.llama.modeling_llama.LlamaMLP,
                         llama_mlp_forward)
+        convert_forward(model,
+                        transformers.models.llama.modeling_llama.LlamaDecoderLayer,
+                        llama_decoder_forward)
         if version.parse(trans_version) >= version.parse("4.36.0"):
             # transformers version >= 4.36.0
             from bigdl.llm.transformers.models.llama import llama_attention_forward_4_36
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index df1deff4..b773c424 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -44,7 +44,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
-from bigdl.llm.transformers.models.utils import mlp_fusion_check
+from bigdl.llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -104,19 +104,47 @@ def llama_rms_norm_forward(self, hidden_states):
 
 def llama_mlp_forward(
     self,
     x: torch.Tensor,
+    residual=None
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
+    bsz, hidden_size = x_2d.shape
     qtype = getattr(self.gate_proj, "qtype", None)
     if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_xpu(
+        out = self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, qtype
         ))
-    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
+    elif fp16_fusion_check(self.gate_proj, x, self.training) and \
+            hidden_size == 4096 and bsz == 1:
+        hidden_states1 = torch.ops.torch_ipex.mm_silu(x, self.gate_proj.weight)
+        hidden_states = torch.ops.torch_ipex.mm_resmul(
+            x, self.up_proj.weight, hidden_states1
+        )
+        if residual is None:
+            hidden_states = torch.matmul(hidden_states, self.down_proj.weight)
+        else:
+            attn_output = torch.addmm(
+                residual.flatten(0, -2),
+                hidden_states.flatten(0, -2),
+                self.down_proj.weight,
+                beta=1,
+            )
+            hidden_states = attn_output.view(x.shape)
+        return hidden_states
+    else:
+        out = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        if residual is not None:
+            return out + residual
+        else:
+            return out
 
 
 def should_use_fuse_rope(self, query_states, position_ids):
@@ -136,6 +164,56 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
+def llama_decoder_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: Optional[bool] = False,
+    use_cache: Optional[bool] = False,
+    **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37."
+            "Please make sure use `attention_mask` instead.`"
+        )
+
+    residual = hidden_states
+
+    hidden_states = self.input_layernorm(hidden_states)
+
+    # Self Attention
+    hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        **kwargs,
+    )
+    hidden_states = residual + hidden_states
+
+    # Fully Connected
+    residual = hidden_states
+    hidden_states = self.post_attention_layernorm(hidden_states)
+
+    # add residual into mlp
+    hidden_states = self.mlp(hidden_states, residual)
+
+    outputs = (hidden_states,)
+
+    if output_attentions:
+        outputs += (self_attn_weights,)
+
+    if use_cache:
+        outputs += (present_key_value,)
+
+    return outputs
+
+
 def llama_attention_forward_4_31(
     self,
     hidden_states: torch.Tensor,
@@ -147,7 +225,7 @@ def llama_attention_forward_4_31(
     padding_mask: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
@@ -202,9 +280,31 @@
                         for i in range(self.config.pretraining_tp)]
         value_states = torch.cat(value_states, dim=-1)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+                hidden_size == 4096:
+            # only use mm_qkv_out on pvc for llama-7b
+            if not hasattr(self, "qkv_proj_weight"):
+                self.qkv_proj_weight = torch.stack([self.q_proj.weight,
+                                                    self.k_proj.weight,
+                                                    self.v_proj.weight]).contiguous()
+                self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
+                self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
+                self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
+                torch.xpu.empty_cache()
+            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            torch.ops.torch_ipex.mm_qkv_out(
+                hidden_states, self.qkv_proj_weight, None,
+                query_states, key_states, value_states
+            )
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len,
                                      self.num_heads, self.head_dim).transpose(1, 2)
@@ -598,9 +698,31 @@ def llama_attention_forward_4_36(
                         for i in range(self.config.pretraining_tp)]
         value_states = torch.cat(value_states, dim=-1)
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+                hidden_size == 4096:
+            # only use mm_qkv_out on pvc for llama-7b
+            if not hasattr(self, "qkv_proj_weight"):
+                self.qkv_proj_weight = torch.stack([self.q_proj.weight,
+                                                    self.k_proj.weight,
+                                                    self.v_proj.weight]).contiguous()
+                self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
+                self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
+                self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
+                torch.xpu.empty_cache()
+            query_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            key_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                     device=hidden_states.device)
+            value_states = torch.empty(bsz, q_len, hidden_size, dtype=hidden_states.dtype,
+                                       device=hidden_states.device)
+            torch.ops.torch_ipex.mm_qkv_out(
+                hidden_states, self.qkv_proj_weight, None,
+                query_states, key_states, value_states
+            )
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len,
                                      self.num_heads, self.head_dim).transpose(1, 2)
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index bb0682c8..b0aefa84 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -312,3 +312,19 @@ def use_fused_layer_norm(x: torch.Tensor, training: bool):
             or x.numel() // x.size(-1) == 1
         )
     )
+
+
+def fp16_fusion_check(proj, x, training):
+    # only use fp16 fusion on PVC inference
+    if proj.qtype != ggml_tensor_qtype["fp16"]:
+        return False
+    if proj.weight_type != 2:
+        return False
+    if training:
+        return False
+    if x.requires_grad:
+        return False
+    device_type = get_xpu_device_type(x)
+    if device_type != "pvc":
+        return False
+    return True
diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py
index 3361b9a7..89c93b55 100644
--- a/python/llm/src/bigdl/llm/transformers/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/utils.py
@@ -166,6 +166,8 @@ def get_ipex_version():
 
 
 def get_xpu_device_type(x):
+    if x.device.type != "xpu":
+        return x.device.type
     name = torch.xpu.get_device_name(x.device.index)
     if name.startswith("Intel(R) Arc(TM) A"):
         return "arc"
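
Note on the new `residual` argument: `llama_decoder_forward` now hands the post-attention residual to `self.mlp`, and every branch of the patched `llama_mlp_forward` (the int4 fused path, the fp16 `torch_ipex` path, and the plain fallback) is expected to return the same value as the unfused SwiGLU MLP followed by that residual add. A minimal sketch of this contract, assuming the usual SiLU activation for Llama; `mlp_reference` is an illustrative name, not part of the patch:

    import torch.nn.functional as F

    def mlp_reference(x, gate_proj, up_proj, down_proj, residual=None):
        # Plain SwiGLU MLP (the fallback branch in the diff), followed by the
        # residual add that the decoder layer now delegates to the MLP itself.
        out = down_proj(F.silu(gate_proj(x)) * up_proj(x))
        return out if residual is None else out + residual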
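
A rough way to exercise the new path end to end, as a sketch under assumptions rather than part of the patch: it assumes a bigdl-llm build where `load_in_low_bit="fp16"` is supported and an Intel GPU is available; the fused fp16 kernels themselves only kick in on PVC, as gated by `fp16_fusion_check`, and the model id and prompt below are placeholders:

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401, registers the XPU backend and torch_ipex ops
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit="fp16",
                                                 optimize_model=True,
                                                 trust_remote_code=True).to("xpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    with torch.inference_mode():
        input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
        output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))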