LLM: better FP16 support for Intel GPUs (#9791)
* initial support * fix * fix style * fix * limi esimd usage condition * refactor code * fix style * small fix * meet code review * small fix
This commit is contained in:
		
							parent
							
								
									7d9f6c6efc
								
							
						
					
					
						commit
						99bddd3ab4
					
				
					 4 changed files with 153 additions and 75 deletions
				
			
		| 
						 | 
				
			
			@ -200,8 +200,8 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                            bias=has_bias,
 | 
			
		||||
                            mp_group=mp_group,
 | 
			
		||||
                        )
 | 
			
		||||
                        device_type = module.qweight.data.device.type
 | 
			
		||||
                        invalidInputError(device_type != "meta",
 | 
			
		||||
                        device = module.qweight.data.device
 | 
			
		||||
                        invalidInputError(device.type != "meta",
 | 
			
		||||
                                          "converting from meta device is not supported")
 | 
			
		||||
                        # Copy the weights
 | 
			
		||||
                        paramsLowBit = FP4Params(data=convert_gptq(module, awq=is_awq),
 | 
			
		||||
| 
						 | 
				
			
			@ -209,11 +209,11 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                                                 quantized=True,
 | 
			
		||||
                                                 _shape=(out_features, in_features),
 | 
			
		||||
                                                 convert_shape_only=convert_shape_only,
 | 
			
		||||
                                                 qtype=qtype).to(device_type)
 | 
			
		||||
                                                 qtype=qtype).to(device)
 | 
			
		||||
                        new_linear._parameters['weight'] = paramsLowBit
 | 
			
		||||
                        if has_bias:
 | 
			
		||||
                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
 | 
			
		||||
                                .to(device_type)
 | 
			
		||||
                                .to(device)
 | 
			
		||||
                    elif qtype not in [ggml_tensor_qtype["fp16"], ggml_tensor_qtype["bf16"]]:
 | 
			
		||||
                        new_linear = LowBitLinear(
 | 
			
		||||
                            in_features,
 | 
			
		||||
| 
						 | 
				
			
			@ -223,44 +223,39 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                            mp_group=mp_group,
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                        device_type = module.weight.data.device.type
 | 
			
		||||
                        device = module.weight.data.device
 | 
			
		||||
                        # Copy the weights
 | 
			
		||||
                        paramsLowBit = FP4Params(data=module.weight.data,
 | 
			
		||||
                                                 requires_grad=False,
 | 
			
		||||
                                                 quantized=False,
 | 
			
		||||
                                                 _shape=None,
 | 
			
		||||
                                                 convert_shape_only=convert_shape_only,
 | 
			
		||||
                                                 qtype=qtype).to(device_type)
 | 
			
		||||
                                                 qtype=qtype).to(device)
 | 
			
		||||
                        new_linear._parameters['weight'] = paramsLowBit
 | 
			
		||||
                        if module.bias is not None:
 | 
			
		||||
                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
 | 
			
		||||
                                .to(device_type)
 | 
			
		||||
                                .to(device)
 | 
			
		||||
                    elif qtype == ggml_tensor_qtype["fp16"]:
 | 
			
		||||
                        #  only support two size now
 | 
			
		||||
                        #  may generalize to other sizes
 | 
			
		||||
                        if module.in_features in [4096, 11008]:
 | 
			
		||||
                            # esimd fp16 path
 | 
			
		||||
                            new_linear = FP16Linear(
 | 
			
		||||
                                in_features,
 | 
			
		||||
                                out_features,
 | 
			
		||||
                                qtype,
 | 
			
		||||
                                module.bias is not None,
 | 
			
		||||
                                mp_group=mp_group,
 | 
			
		||||
                            )
 | 
			
		||||
                            device_type = module.weight.data.device.type
 | 
			
		||||
 | 
			
		||||
                            # convert here
 | 
			
		||||
                            m, n = module.weight.data.shape
 | 
			
		||||
                            if module.in_features == 11008:
 | 
			
		||||
                                trans_weight = module.weight.data.reshape(m//8, 8, n)
 | 
			
		||||
                                trans_weight = trans_weight.transpose(1, 2).contiguous()
 | 
			
		||||
                            elif module.in_features == 4096:
 | 
			
		||||
                                trans_weight = module.weight.data.reshape(m//16, 16, n)
 | 
			
		||||
                                trans_weight = trans_weight.transpose(1, 2).contiguous()
 | 
			
		||||
                            new_linear._parameters['weight'] = nn.Parameter(trans_weight)
 | 
			
		||||
                            if module.bias is not None:
 | 
			
		||||
                                new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
 | 
			
		||||
                                    .to(device_type)
 | 
			
		||||
                        module.to(torch.float16)
 | 
			
		||||
                        new_linear = FP16Linear(
 | 
			
		||||
                            in_features,
 | 
			
		||||
                            out_features,
 | 
			
		||||
                            module.bias is not None,
 | 
			
		||||
                            mp_group=mp_group,
 | 
			
		||||
                        )
 | 
			
		||||
                        device = module.weight.data.device
 | 
			
		||||
                        from bigdl.llm.transformers.utils import get_ipex_version
 | 
			
		||||
                        if get_ipex_version() < "2.1.10+xpu":
 | 
			
		||||
                            new_linear._parameters['weight'] = nn.Parameter(module.weight)
 | 
			
		||||
                        else:
 | 
			
		||||
                            # only from 2.1, ipex provides matmul_bias_out
 | 
			
		||||
                            # so we need to transpose weight
 | 
			
		||||
                            new_weight = module.weight.transpose(0, 1).contiguous()
 | 
			
		||||
                            new_linear._parameters['weight'] = nn.Parameter(new_weight)
 | 
			
		||||
                            new_linear.weight_type = 2
 | 
			
		||||
                        if module.bias is not None:
 | 
			
		||||
                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
 | 
			
		||||
                                .to(device)
 | 
			
		||||
                    elif qtype == ggml_tensor_qtype["bf16"]:
 | 
			
		||||
                        module.to(torch.bfloat16)
 | 
			
		||||
                        new_linear = BF16Linear(
 | 
			
		||||
| 
						 | 
				
			
			@ -269,12 +264,12 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
 | 
			
		|||
                            module.bias is not None,
 | 
			
		||||
                            mp_group=mp_group,
 | 
			
		||||
                        )
 | 
			
		||||
                        device_type = module.weight.data.device.type
 | 
			
		||||
                        device = module.weight.data.device
 | 
			
		||||
                        # convert here
 | 
			
		||||
                        new_linear._parameters['weight'] = nn.Parameter(module.weight)
 | 
			
		||||
                        if module.bias is not None:
 | 
			
		||||
                            new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
 | 
			
		||||
                                .to(device_type)
 | 
			
		||||
                                .to(device)
 | 
			
		||||
 | 
			
		||||
                    if new_linear is not None:
 | 
			
		||||
                        if not module.training:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,7 +50,8 @@ from torch import Tensor, device, dtype, nn
 | 
			
		|||
from operator import mul
 | 
			
		||||
from functools import reduce
 | 
			
		||||
from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
 | 
			
		||||
from bigdl.llm.transformers.utils import get_autocast_dtype
 | 
			
		||||
from bigdl.llm.transformers.utils import get_autocast_dtype, get_xpu_device_type, \
 | 
			
		||||
    get_ipex_version
 | 
			
		||||
 | 
			
		||||
T = TypeVar("T", bound="torch.nn.Module")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -538,57 +539,111 @@ class LowBitLinear(nn.Linear):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class FP16Linear(nn.Linear):
 | 
			
		||||
    def __init__(self, input_features, output_features, qtype, bias=True,
 | 
			
		||||
                 conver_to_half=True, mp_group=None):
 | 
			
		||||
    def __init__(self, input_features, output_features, bias=True,
 | 
			
		||||
                 mp_group=None, weight_type=1):
 | 
			
		||||
        super().__init__(input_features, output_features, bias)
 | 
			
		||||
        self.in_len = input_features
 | 
			
		||||
        self.out_len = output_features
 | 
			
		||||
        self.weight_shape = (self.out_len, self.in_len)
 | 
			
		||||
        self.weight_length = self.out_len * self.in_len
 | 
			
		||||
        self.qtype = qtype
 | 
			
		||||
        self.conver_to_half = conver_to_half
 | 
			
		||||
        self.qtype = ggml_tensor_qtype["fp16"]
 | 
			
		||||
        self.mp_group = mp_group
 | 
			
		||||
        # weigh_type = 1 means original weight
 | 
			
		||||
        # weigh_type = 2 means weight has been transposed
 | 
			
		||||
        # weigh_type = 3 means weight has been transposed by esimd method
 | 
			
		||||
        self.weight_type = 1
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        if self.bias is not None and self.bias.dtype != x.dtype:
 | 
			
		||||
            self.bias.data = self.bias.data.to(x.dtype)
 | 
			
		||||
 | 
			
		||||
        x_shape = x.shape
 | 
			
		||||
        x_2d = x.view(-1, x_shape[-1])
 | 
			
		||||
 | 
			
		||||
        x0 = self.weight.data
 | 
			
		||||
        # only work for GPU
 | 
			
		||||
        invalidInputError(x0.device.type == "xpu",
 | 
			
		||||
                          "FP16 only works for GPU")
 | 
			
		||||
        try:
 | 
			
		||||
            import intel_extension_for_pytorch
 | 
			
		||||
            import linear_fp16_esimd
 | 
			
		||||
        except ModuleNotFoundError:
 | 
			
		||||
            invalidInputError(False,
 | 
			
		||||
                              "Please `pip install bigdl_core_xe` first.")
 | 
			
		||||
        invalidInputError(x.device.type == "xpu",
 | 
			
		||||
                          "FP16Linear only works for Intel GPUs")
 | 
			
		||||
        x = x.to(torch.float16)
 | 
			
		||||
        if self.bias is not None and self.bias.dtype != x.dtype:
 | 
			
		||||
                self.bias.data = self.bias.data.to(x.dtype)
 | 
			
		||||
        if self.weight is not None and self.weight.dtype != x.dtype:
 | 
			
		||||
            self.weight.data = self.weight.data.to(x.dtype)
 | 
			
		||||
 | 
			
		||||
        if x_2d.is_contiguous() is False:
 | 
			
		||||
            x_2d = x_2d.contiguous()
 | 
			
		||||
 | 
			
		||||
        if x_2d.shape[0] > 1:
 | 
			
		||||
            # first token or batch size > 1, re-convert weight
 | 
			
		||||
            original_weight = self.weight.data.transpose(1, 2)
 | 
			
		||||
            original_weight = original_weight.reshape(self.out_len, self.in_len)
 | 
			
		||||
            result = F.linear(x_2d, original_weight.contiguous())
 | 
			
		||||
            del original_weight
 | 
			
		||||
        if not self.use_esimd_kernel(x):
 | 
			
		||||
            if get_ipex_version() < "2.1.10+xpu":
 | 
			
		||||
                if self.weight_type == 2:
 | 
			
		||||
                    self.weight = self.weight.transpose(0, 1).contiguous()
 | 
			
		||||
                    self.weight_type = 1
 | 
			
		||||
                return F.linear(x, self.weight, self.bias)
 | 
			
		||||
            else:
 | 
			
		||||
                if self.weight_type == 1:
 | 
			
		||||
                    self.weight = self.weight.transpose(0, 1).contiguous()
 | 
			
		||||
                    self.weight_type = 2
 | 
			
		||||
                return torch.ops.torch_ipex.matmul_bias_out(x, self.weight, self.bias)
 | 
			
		||||
        else:
 | 
			
		||||
            # rest token, use esimd optimization
 | 
			
		||||
            result = linear_fp16_esimd.forward(x_2d, self.weight.data)
 | 
			
		||||
            if self.weight_type != 3:
 | 
			
		||||
                # convert weight first to use esimd fp16 kernel
 | 
			
		||||
                self.convert_weight_for_esimd_kernel()
 | 
			
		||||
            # esimd fp16 kernel for inference
 | 
			
		||||
            x_shape = x.shape
 | 
			
		||||
            x_2d = x.view(-1, x_shape[-1])
 | 
			
		||||
            if x_2d.is_contiguous() is False:
 | 
			
		||||
                x_2d = x_2d.contiguous()
 | 
			
		||||
 | 
			
		||||
        new_shape = x_shape[:-1] + (self.out_len,)
 | 
			
		||||
        result = result.view(new_shape)
 | 
			
		||||
        if self.mp_group is not None:
 | 
			
		||||
            from deepspeed import comm as dist
 | 
			
		||||
            dist.inference_all_reduce(result, group=self.mp_group)
 | 
			
		||||
        if self.bias is not None:
 | 
			
		||||
            result += self.bias
 | 
			
		||||
            try:
 | 
			
		||||
                import intel_extension_for_pytorch
 | 
			
		||||
                import linear_fp16_esimd
 | 
			
		||||
            except ModuleNotFoundError:
 | 
			
		||||
                invalidInputError(False,
 | 
			
		||||
                                  "Please `pip install bigdl_core_xe_esimd` first.")
 | 
			
		||||
 | 
			
		||||
        return result.to(x.dtype)
 | 
			
		||||
            if x_2d.shape[0] > 1:
 | 
			
		||||
                # first token or batch size > 1, re-convert weight
 | 
			
		||||
                original_weight = self.weight.data.transpose(1, 2)
 | 
			
		||||
                original_weight = original_weight.reshape(self.out_len, self.in_len)
 | 
			
		||||
                result = F.linear(x_2d, original_weight.contiguous())
 | 
			
		||||
                del original_weight
 | 
			
		||||
            else:
 | 
			
		||||
                # rest token, use esimd optimization
 | 
			
		||||
                result = linear_fp16_esimd.forward(x_2d, self.weight.data)
 | 
			
		||||
 | 
			
		||||
            new_shape = x_shape[:-1] + (self.out_len,)
 | 
			
		||||
            result = result.view(new_shape)
 | 
			
		||||
            if self.mp_group is not None:
 | 
			
		||||
                from deepspeed import comm as dist
 | 
			
		||||
                dist.inference_all_reduce(result, group=self.mp_group)
 | 
			
		||||
            if self.bias is not None:
 | 
			
		||||
                result += self.bias
 | 
			
		||||
 | 
			
		||||
            return result.to(x.dtype)
 | 
			
		||||
 | 
			
		||||
    def use_esimd_kernel(self, x):
 | 
			
		||||
        gpu_type = get_xpu_device_type(x)
 | 
			
		||||
        # esimd kernel can only be used for Arc and Flex
 | 
			
		||||
        if gpu_type not in ["arc", "flex"]:
 | 
			
		||||
            return False
 | 
			
		||||
        # now esimd kernel can only be used for specific cases (llama2-7b shape)
 | 
			
		||||
        if self.in_len == 11008 and self.out_features == 4096:
 | 
			
		||||
            return True
 | 
			
		||||
        if self.in_len == 4096 and self.out_features in [4096, 11008]:
 | 
			
		||||
            # seems has some issue with Mistral,
 | 
			
		||||
            # need a further look to check whether can be used for other out features
 | 
			
		||||
            return True
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def convert_weight_for_esimd_kernel(self):
 | 
			
		||||
        m, n = self.out_len, self.in_len
 | 
			
		||||
        if self.in_len == 11008:
 | 
			
		||||
            if self.weight_type == 2:
 | 
			
		||||
                trans_weight = self.weight.data.transpose(0, 1)
 | 
			
		||||
            else:
 | 
			
		||||
                trans_weight = self.weight.data
 | 
			
		||||
            trans_weight = trans_weight.data.reshape(m//8, 8, n)
 | 
			
		||||
            trans_weight = trans_weight.transpose(1, 2).contiguous()
 | 
			
		||||
            self.weight.data = trans_weight
 | 
			
		||||
        elif self.in_len == 4096:
 | 
			
		||||
            if self.weight_type == 2:
 | 
			
		||||
                trans_weight = self.weight.data.transpose(0, 1)
 | 
			
		||||
            else:
 | 
			
		||||
                trans_weight = self.weight.data
 | 
			
		||||
            trans_weight = trans_weight.data.reshape(m//16, 16, n)
 | 
			
		||||
            trans_weight = trans_weight.transpose(1, 2).contiguous()
 | 
			
		||||
            self.weight.data = trans_weight
 | 
			
		||||
        self.weight_type = 3
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BF16Linear(nn.Linear):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -100,8 +100,9 @@ def llama_mlp_forward(
 | 
			
		|||
    x: torch.Tensor,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    x_2d = x.view(-1, x.shape[-1])
 | 
			
		||||
    qtype = getattr(self.gate_proj, "qtype", None)
 | 
			
		||||
    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
 | 
			
		||||
            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
 | 
			
		||||
            and qtype == ggml_tensor_qtype["sym_int4"] \
 | 
			
		||||
            and not (self.training and x.requires_grad):
 | 
			
		||||
        import linear_q4_0
 | 
			
		||||
        if not x_2d.is_contiguous():
 | 
			
		||||
| 
						 | 
				
			
			@ -147,7 +148,8 @@ def llama_attention_forward_4_31(
 | 
			
		|||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
 | 
			
		||||
    is_q4_0 = self.q_proj.qtype == SYM_INT4
 | 
			
		||||
    qtype = getattr(self.q_proj, "qtype", None)
 | 
			
		||||
    is_q4_0 = qtype == SYM_INT4
 | 
			
		||||
    no_tp = not self.config.pretraining_tp > 1
 | 
			
		||||
    decoding_fast_path = (no_tp and is_q4_0 and use_fuse_rope and
 | 
			
		||||
                          enough_kv_room and bsz * q_len == 1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -149,3 +149,29 @@ def get_autocast_dtype(x):
 | 
			
		|||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"Device {x.device} is not supported.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_ipex_version = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_ipex_version():
 | 
			
		||||
 | 
			
		||||
    global _ipex_version
 | 
			
		||||
    if _ipex_version is not None:
 | 
			
		||||
        return _ipex_version
 | 
			
		||||
 | 
			
		||||
    import intel_extension_for_pytorch as ipex
 | 
			
		||||
    _ipex_version = ipex.__version__
 | 
			
		||||
    return _ipex_version
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_xpu_device_type(x):
 | 
			
		||||
    name = torch.xpu.get_device_name(x.device.index)
 | 
			
		||||
    if name.startswith("Intel(R) Arc(TM) A"):
 | 
			
		||||
        return "arc"
 | 
			
		||||
    elif name.startswith("Intel(R) Data Center GPU Flex"):
 | 
			
		||||
        return "flex"
 | 
			
		||||
    elif name.startswith("Intel(R) Data Center GPU Max"):
 | 
			
		||||
        return "pvc"
 | 
			
		||||
    else:
 | 
			
		||||
        return "others"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue