LLM: support fp16 embedding & add mlp fusion for iq2_xxs (#10219)
* add fp16 embed
* small fixes
* fix style
* fix style
* fix comment
parent eeecd9fc08
commit 28513f3978

6 changed files with 38 additions and 12 deletions
@@ -192,7 +192,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  current_key_name=None, convert_shape_only=False,
                                  cpu_embedding=False, prefix_name='',
                                  imatrix_data=None, embedding_qtype=None,
-                                 model_type=None):
+                                 model_type=None, torch_dtype=torch.float32):
     from bigdl.llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
     from bigdl.llm.transformers.embedding import LLMEmbedding, LowBitEmbedding
@@ -326,6 +326,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 _weight=module.weight.data,
             )
         elif type(module) == nn.Embedding and embedding_qtype is not None:
+            if torch_dtype == "auto":
+                torch_dtype = torch.float32
             q_embedding = LowBitEmbedding(
                 num_embeddings=module.num_embeddings,
                 embedding_dim=module.embedding_dim,
@@ -336,6 +338,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 sparse=module.sparse,
                 _weight=module.weight.data,
                 qtype=embedding_qtype,
+                torch_dtype=torch_dtype
             )
             device = module.weight.data.device
             # Copy the weights
@@ -364,7 +367,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 prefix_name=prefix_name + '.' + name if prefix_name != '' else name,
                 imatrix_data=imatrix_data,
                 embedding_qtype=embedding_qtype,
-                model_type=model_type
+                model_type=model_type,
+                torch_dtype=torch_dtype
             )
             has_been_replaced = _flag or has_been_replaced
     return model, has_been_replaced
@@ -571,7 +575,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         None, convert_shape_only, cpu_embedding,
         imatrix_data=imatrix_data,
         embedding_qtype=embedding_qtype,
-        model_type=model_type
+        model_type=model_type,
+        torch_dtype=torch_dtype
     )
     if not has_been_replaced:
         warnings.warn(
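These hunks thread a new `torch_dtype` argument from `ggml_convert_low_bit` through the recursive `_replace_with_low_bit_linear` pass and into `LowBitEmbedding`, normalizing the transformers-style `"auto"` value to `torch.float32`. A minimal sketch of that plumbing pattern, using simplified stand-in classes rather than the actual bigdl-llm internals:

    import torch
    import torch.nn as nn

    class TinyLowBitEmbedding(nn.Embedding):
        """Illustrative stand-in for LowBitEmbedding: remembers the output dtype."""
        def __init__(self, num_embeddings, embedding_dim, torch_dtype=torch.float32):
            super().__init__(num_embeddings, embedding_dim)
            self.torch_dtype = torch_dtype

    def replace_embeddings(model, torch_dtype=torch.float32):
        # Normalize the transformers-style "auto" value, as the patch does.
        if torch_dtype == "auto":
            torch_dtype = torch.float32
        for name, module in model.named_children():
            if type(module) is nn.Embedding:
                new = TinyLowBitEmbedding(module.num_embeddings,
                                          module.embedding_dim,
                                          torch_dtype=torch_dtype)
                setattr(model, name, new)
            else:
                # Thread the same dtype through the recursive call, mirroring
                # how _replace_with_low_bit_linear forwards it to child modules.
                replace_embeddings(module, torch_dtype=torch_dtype)
        return model

    model = replace_embeddings(nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16)),
                               torch_dtype=torch.float16)
    print(type(model[0]).__name__, model[0].torch_dtype)  # TinyLowBitEmbedding torch.float16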
@@ -88,7 +88,8 @@ class LowBitEmbedding(torch.nn.Embedding):
                  _weight: Optional[Tensor] = None,
                  _freeze: bool = False,
                  device=None, dtype=None,
-                 qtype=None) -> None:
+                 qtype=None,
+                 torch_dtype=torch.float32) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
                          max_norm, norm_type, scale_grad_by_freq, sparse,
                          _weight, device, dtype)
@@ -96,6 +97,7 @@ class LowBitEmbedding(torch.nn.Embedding):
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype)
         self.embedding_dim = embedding_dim
+        self.torch_dtype = torch_dtype
 
     def forward(self, x: Tensor):
         invalidInputError(x.device.type == "xpu",
@@ -109,4 +111,4 @@ class LowBitEmbedding(torch.nn.Embedding):
 
         result = linear_q4_0.dequantize_rows(x.contiguous(), self.weight.data,
                                              self.weight.qtype, self.embedding_dim)
-        return result
+        return result.to(self.torch_dtype)
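The `forward` change above dequantizes the requested rows and then casts them to the configured `torch_dtype`, so an fp16 model now receives fp16 embedding outputs instead of fp32 ones. A minimal sketch of that dequantize-then-cast pattern in plain PyTorch (toy int8 quantization standing in for the real `linear_q4_0.dequantize_rows` kernel):

    import torch

    def dequantize_rows_sketch(indices, qweight_int8, scales, out_dtype=torch.float16):
        """Toy int8 row dequantization followed by a cast to the target dtype.

        qweight_int8: (vocab, dim) int8 codes, scales: (vocab, 1) fp32 scales.
        The real kernel works on ggml-packed low-bit weights; this only shows why
        the final .to(out_dtype) matters for matching downstream fp16 layers.
        """
        rows = qweight_int8[indices].to(torch.float32) * scales[indices]
        return rows.to(out_dtype)

    vocab, dim = 32, 8
    scales = torch.rand(vocab, 1)
    qweight = torch.randint(-128, 127, (vocab, dim), dtype=torch.int8)
    ids = torch.tensor([[1, 5, 7]])
    emb = dequantize_rows_sketch(ids, qweight, scales, out_dtype=torch.float16)
    print(emb.dtype, emb.shape)  # torch.float16 torch.Size([1, 3, 8])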
@@ -339,8 +339,8 @@ class _BaseAutoModelClass:
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
                           f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                          f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16,  bf16, iq2_xxs, iq2_xs, "
-                          f"mixed_fp4 or mixed_fp8.")
+                          f"fp4, fp8, fp8_e4m3, fp8_e5m2, fp16,  bf16, gguf_iq2_xxs, "
+                          f"gguf_iq2_xs, mixed_fp4 or mixed_fp8.")
         qtype = ggml_tensor_qtype[q_k]
 
         # In case it needs a second try,
@@ -510,7 +510,8 @@ class _BaseAutoModelClass:
         optimize_model = kwargs.pop("optimize_model", True)
 
         qtype = ggml_tensor_qtype[bigdl_transformers_low_bit]
-        if bigdl_transformers_low_bit in ["iq2_xxs", "iq2_xs", "q2_k"] and not cpu_embedding:
+        if bigdl_transformers_low_bit in ["gguf_iq2_xxs", "gguf_iq2_xs", "q2_k"] and \
+                not cpu_embedding:
             embedding_qtype = "q2_k"
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
@@ -595,7 +596,7 @@ class _BaseAutoModelClass:
         model = ggml_convert_low_bit(model, qtype, optimize_model, device=quant_device,
                                      modules_to_not_convert=modules_to_not_convert,
                                      cpu_embedding=cpu_embedding, lightweight_bmm=lightweight_bmm,
-                                     embedding_qtype=embedding_qtype)
+                                     embedding_qtype=embedding_qtype, torch_dtype=torch_dtype)
 
         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
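With these changes the accepted low-bit names are reported as `gguf_iq2_xxs`/`gguf_iq2_xs`, 2-bit loads default the embedding to `q2_k` unless `cpu_embedding=True`, and the `torch_dtype` passed to `from_pretrained` reaches the embedding conversion. A usage sketch (the model path is a placeholder, and depending on the bigdl-llm version 2-bit quantization may also require an importance-matrix file passed via an `imatrix` argument):

    import torch
    from bigdl.llm.transformers import AutoModelForCausalLM
    from transformers import AutoTokenizer
    import intel_extension_for_pytorch as ipex  # noqa: F401  registers the 'xpu' device

    model_path = "path/to/your/model"  # placeholder

    # Load with 2-bit gguf_iq2_xxs weights; the embedding table is quantized to q2_k
    # (unless cpu_embedding=True) and its output follows torch_dtype (fp16 here).
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit="gguf_iq2_xxs",
                                                 torch_dtype=torch.float16,
                                                 trust_remote_code=True)
    model = model.to("xpu")  # LowBitEmbedding.forward asserts an XPU input

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))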
@@ -48,7 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, \
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4, FP8E5
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
@@ -242,6 +242,15 @@ def mistral_attention_forward(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                  key_states,
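The new `elif` branch routes eligible shapes to the fused ESIMD SDP kernel (`linear_fp16_esimd.sdp_forward`) before falling back to the generic attention math. A simplified, runnable sketch of that three-way dispatch, with `torch.nn.functional.scaled_dot_product_attention` standing in for the flash path and an optional callable standing in for the XPU-only ESIMD kernel:

    import torch
    import torch.nn.functional as F

    def attention_dispatch(query, key, value, use_flash, esimd_ok, esimd_sdp=None):
        """Sketch of the dispatch order in the patched attention forward: flash
        attention first, then the ESIMD fp16 SDP kernel, then the default path.
        query/key/value: (bsz, n_heads, seq, head_dim)."""
        if use_flash:
            attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
            attn_weights = None
        elif esimd_ok and esimd_sdp is not None:
            # Real code: attn_output = linear_fp16_esimd.sdp_forward(q, k, v)
            attn_output = esimd_sdp(query, key, value).view(query.shape)
            attn_weights = None
        else:
            attn_weights = torch.matmul(query, key.transpose(2, 3)) / (query.shape[-1] ** 0.5)
            attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
            attn_output = torch.matmul(attn_weights, value)
        bsz, _, q_len, _ = query.shape
        return attn_output.transpose(1, 2).reshape(bsz, q_len, -1), attn_weights

    q = k = v = torch.randn(1, 8, 16, 64)
    out, _ = attention_dispatch(q, k, v, use_flash=False, esimd_ok=False)
    print(out.shape)  # torch.Size([1, 16, 512])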
@@ -49,7 +49,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
-from bigdl.llm.transformers.models.utils import use_flash_attention
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 
 
@@ -272,6 +272,14 @@ def mixtral_attention_forward(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
+    elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
+                       self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states,
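The Mixtral branch issues the same kernel call as the Mistral one; only the query-length argument of the gate differs (`query_states.shape[2]` instead of the local `q_len`). As a refactoring sketch only (the commit keeps both branches inline), the shared part could be factored like this:

    def esimd_sdp_branch(query_states, key_states, value_states):
        """Hypothetical wrapper around the XPU-only fused SDP kernel used by both
        the Mistral and Mixtral attention forwards. Returns an output shaped like
        query_states plus attn_weights=None, matching the other branches."""
        import linear_fp16_esimd  # XPU-only extension; the import fails elsewhere
        attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                    key_states,
                                                    value_states)
        return attn_output.view(query_states.shape), None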
@@ -310,7 +310,8 @@ def mlp_fusion_check(x, qtype, training):
         return False
     if x.device.type != 'xpu':
         return False
-    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]:
+    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"],
+                     ggml_tensor_qtype["gguf_iq2_xxs"]]:
         return False
     if training or x.requires_grad:
         return False
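Adding `gguf_iq2_xxs` to `mlp_fusion_check` lets the fused MLP path fire for 2-bit models on XPU. A self-contained sketch of how such a gate is typically consumed; the allowed-qtype set mirrors the updated check, while the module and the fused-kernel comment are illustrative, not the actual bigdl-llm code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Allowed qtypes mirror the updated check: sym_int4, fp8_e5m2 and now gguf_iq2_xxs.
    _FUSIBLE_QTYPES = {"sym_int4", "fp8_e5m2", "gguf_iq2_xxs"}

    def mlp_fusion_check_sketch(x, qtype, training):
        """Simplified version of the gate in the diff: XPU only, whitelisted qtypes,
        inference only (no autograd)."""
        if x.device.type != "xpu":
            return False
        if qtype not in _FUSIBLE_QTYPES:
            return False
        if training or x.requires_grad:
            return False
        return True

    class SketchMLP(nn.Module):
        def __init__(self, hidden, inter, qtype="gguf_iq2_xxs"):
            super().__init__()
            self.gate_proj = nn.Linear(hidden, inter, bias=False)
            self.up_proj = nn.Linear(hidden, inter, bias=False)
            self.down_proj = nn.Linear(inter, hidden, bias=False)
            self.qtype = qtype

        def forward(self, x):
            if mlp_fusion_check_sketch(x, self.qtype, self.training):
                # The real code launches a single fused low-bit XPU kernel here
                # (gate, SiLU, up and the elementwise product in one pass); this
                # sketch has no such kernel, so it falls through to the eager math.
                pass
            return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

    mlp = SketchMLP(hidden=64, inter=128)
    print(mlp(torch.randn(2, 4, 64)).shape)  # torch.Size([2, 4, 64])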