Update vllm patch to fix telechat2 and baichuan2 errors (#13150)
parent 9da1c56fa8
commit 5df03ced2c

1 changed file with 409 additions and 141 deletions
@@ -8692,7 +8692,7 @@ index 000000000..e98db9b65
+ tensor_parallel_size=1,
+ )
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
-index c3d210c27..c3b6ca7eb 100644
+index c3d210c27..8dd101608 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -1,6 +1,4 @@
@@ -8929,7 +8929,7 @@ index c3d210c27..c3b6ca7eb 100644

@staticmethod
def varlen_attention(
-@@ -220,22 +262,233 @@ class ipex_ops:
+@@ -220,22 +262,250 @@ class ipex_ops:
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
@@ -9044,30 +9044,47 @@ index c3d210c27..c3b6ca7eb 100644
+ p_dropout: float,
+ softmax_scale: float,
+ zero_tensors: bool,
-+ is_caual: bool,
++ is_casual: bool,
+ return_softmax: bool,
+ gen_: Optional[torch.Generator],
+ ):
-+ return torch.ops.torch_ipex.chunked_prefill(
++ return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
++ output,
+ query.contiguous(),
+ key_cache,
+ value_cache,
-+ output,
+ cu_seqlens_q,
+ cu_seqlens_k,
-+ seq_used_k,
-+ block_table,
-+ alibi_slopes,
+ max_seqlen_q,
+ max_seqlen_k,
-+ p_dropout,
+ softmax_scale,
-+ zero_tensors,
-+ is_caual,
-+ return_softmax,
-+ gen_,
++ is_casual,
++ block_table,
++ alibi_slopes,
++ k_scale=1.0,
++ v_scale=1.0,
)
++ # return torch.ops.torch_ipex.chunked_prefill(
++ # query.contiguous(),
++ # key_cache,
++ # value_cache,
++ # output,
++ # cu_seqlens_q,
++ # cu_seqlens_k,
++ # seq_used_k,
++ # block_table,
++ # alibi_slopes,
++ # max_seqlen_q,
++ # max_seqlen_k,
++ # p_dropout,
++ # softmax_scale,
++ # zero_tensors,
++ # is_caual,
++ # return_softmax,
++ # gen_,
++ # )
++
+
+ @staticmethod
+ def copy_blocks(key_caches: List[torch.Tensor],
+ value_caches: List[torch.Tensor],
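For readability, the reworked wrapper from the hunk above can be assembled as plain Python. This is a sketch only: argument names and order are taken from the new patch text, indentation is approximate, and it assumes an XPU build of intel_extension_for_pytorch (imported as ipex) that provides ipex.llm.modules.PagedAttention.flash_attn_varlen_func.

from typing import Optional

import torch
import intel_extension_for_pytorch as ipex  # assumed to be available on an XPU build


def chunked_prefill_sketch(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    is_casual: bool,
    block_table: torch.Tensor,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    # Delegate to IPEX's paged-attention varlen kernel, as the patched helper now does.
    return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
        output,
        query.contiguous(),
        key_cache,
        value_cache,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_casual,
        block_table,
        alibi_slopes,
        k_scale=1.0,
        v_scale=1.0,
    )

Compared with the old torch.ops.torch_ipex.chunked_prefill call, output moves to the front, the seq_used_k, p_dropout, zero_tensors, return_softmax and gen_ arguments are no longer forwarded, and k_scale and v_scale are passed as keyword arguments fixed at 1.0.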
|
@@ -9078,7 +9095,7 @@ index c3d210c27..c3b6ca7eb 100644
+ # block_mapping,
+ # )
+ vllm._C.cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
-+
+
@staticmethod
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
@@ -11666,6 +11683,143 @@ index 5649cf2dd..66e30984e 100644
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)

+diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
+index 6a3112b5f..7e2b7c862 100644
+--- a/vllm/model_executor/models/baichuan.py
++++ b/vllm/model_executor/models/baichuan.py
+@@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
+-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
++from .utils import (is_pp_missing_parameter,
+make_empty_intermediate_tensors_factory, make_layers)
+
+
+@@ -321,45 +321,6 @@ class BaiChuanModel(nn.Module):
+hidden_states, _ = self.norm(hidden_states, residual)
+return hidden_states
+
+- def load_weights(self, weights: Iterable[Tuple[str,
+- torch.Tensor]]) -> Set[str]:
+- stacked_params_mapping = [
+- # (param_name, shard_name, shard_id)
+- ("gate_up_proj", "gate_proj", 0),
+- ("gate_up_proj", "up_proj", 1),
+- ]
+- params_dict = dict(self.named_parameters())
+- loaded_params: Set[str] = set()
+- for name, loaded_weight in weights:
+- if "rotary_emb.inv_freq" in name:
+- continue
+-
+- for (param_name, weight_name, shard_id) in stacked_params_mapping:
+- if weight_name not in name:
+- continue
+- name = name.replace(weight_name, param_name)
+- # Skip loading extra bias for GPTQ models.
+- if name.endswith(".bias") and name not in params_dict:
+- continue
+- if is_pp_missing_parameter(name, self):
+- continue
+- param = params_dict[name]
+- weight_loader = param.weight_loader
+- weight_loader(param, loaded_weight, shard_id)
+- break
+- else:
+- # Skip loading extra bias for GPTQ models.
+- if name.endswith(".bias") and name not in params_dict:
+- continue
+- if is_pp_missing_parameter(name, self):
+- continue
+- param = params_dict[name]
+- weight_loader = getattr(param, "weight_loader",
+- default_weight_loader)
+- weight_loader(param, loaded_weight)
+- loaded_params.add(name)
+- return loaded_params
+-
+
+class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+SupportsQuant):
+@@ -392,7 +353,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+self.lm_head = ParallelLMHead(config.vocab_size,
+config.hidden_size,
+quant_config=quant_config)
+- self.lm_head.weight.weight_loader = self.lm_head_weight_loader
+if self.config.tie_word_embeddings:
+self.lm_head.weight = self.model.embed_tokens.weight
+self.logits_processor = LogitsProcessor(config.vocab_size)
+@@ -433,22 +393,53 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+
+def load_weights(self, weights: Iterable[Tuple[str,
+torch.Tensor]]) -> Set[str]:
+- loader = AutoWeightsLoader(self)
+- return loader.load_weights(weights)
+-
+- def lm_head_weight_loader(self, param: nn.Parameter,
+- loaded_weight: torch.Tensor):
+- # Unlike Baichuan, Baichuan2 normalizes the head weights.
+- # Refer to:
+- # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+- # Distinguish between Baichuan and Baichuan2 by checking the
+- # vocab size. This is suggested by
+- # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+- is_baichuan2 = self.config.vocab_size == 125696
+- if is_baichuan2:
+- loaded_weight = torch.nn.functional.normalize(loaded_weight)
+-
+- default_weight_loader(param, loaded_weight)
++ stacked_params_mapping = [
++ # (param_name, shard_name, shard_id)
++ ("gate_up_proj", "gate_proj", 0),
++ ("gate_up_proj", "up_proj", 1),
++ ]
++ params_dict = dict(self.named_parameters())
++ loaded_params: Set[str] = set()
++ for name, loaded_weight in weights:
++ if "rotary_emb.inv_freq" in name:
++ continue
++ if name == "lm_head.weight":
++ # Unlike Baichuan, Baichuan2 normalizes the head weights.
++ # Refer to:
++ # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
++ # Distinguish between Baichuan and Baichuan2 by checking the
++ # vocab size. This is suggested by
++ # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
++ is_baichuan2 = self.config.vocab_size == 125696
++ if is_baichuan2:
++ loaded_weight = torch.nn.functional.normalize(
++ loaded_weight)
++
++ for (param_name, weight_name, shard_id) in stacked_params_mapping:
++ if weight_name not in name:
++ continue
++ name = name.replace(weight_name, param_name)
++ # Skip loading extra bias for GPTQ models.
++ if name.endswith(".bias") and name not in params_dict:
++ continue
++ if is_pp_missing_parameter(name, self):
++ continue
++ param = params_dict[name]
++ weight_loader = param.weight_loader
++ weight_loader(param, loaded_weight, shard_id)
++ break
++ else:
++ # Skip loading extra bias for GPTQ models.
++ if name.endswith(".bias") and name not in params_dict:
++ continue
++ if is_pp_missing_parameter(name, self):
++ continue
++ param = params_dict[name]
++ weight_loader = getattr(param, "weight_loader",
++ default_weight_loader)
++ weight_loader(param, loaded_weight)
++ loaded_params.add(name)
++ return loaded_params
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 1b1738f88..2c2ed67b9 100644
--- a/vllm/model_executor/models/chatglm.py
@@ -14147,7 +14301,7 @@ index c0a3c59ba..8614c2273 100644
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
# [Encoder-decoder]
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
-index cecad9e89..df4cf4776 100644
+index cecad9e89..7eaabd1db 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -140,6 +140,74 @@ class SiglipVisionEmbeddings(nn.Module):
@@ -14195,9 +14349,9 @@ index cecad9e89..df4cf4776 100644
+
+ query, key, value = (x.transpose(1, 2)
+ for x in (query, key, value))
-+ from ipex_llm.transformers.models.utils import use_sdp_causal
+ from vllm.attention.backends.ipex_attn import use_sdp_causal
+ import xe_addons, math
++ from vllm.attention.backends.abstract import AttentionType
+ mask = None
+ scale = 1 / math.sqrt(self.head_size) if self.scale is None else self.scale
+ from ipex_llm.transformers.models.common import padding_qkv_hd
@@ -14209,7 +14363,7 @@ index cecad9e89..df4cf4776 100644
+ query, key, value,
+ self.head_size, num
+ )
-+ if use_sdp_causal(query.shape[-1], query, 0):
++ if use_sdp_causal(query.shape[-1], query, 0, AttentionType.DECODER):
+ out = xe_addons.sdp_non_causal(query.contiguous(), key.contiguous(), value.contiguous(), mask, scale)[:, :, :, :self.head_size].transpose(1, 2)
+ # import torch.nn.functional as F
+ # out = F.scaled_dot_product_attention(query,
@@ -14239,10 +14393,23 @@ index cecad9e89..df4cf4776 100644
def forward(
self,
diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py
-index a38035e37..9631fbd83 100644
+index a38035e37..570f2bcdd 100644
--- a/vllm/model_executor/models/telechat2.py
+++ b/vllm/model_executor/models/telechat2.py
-@@ -44,9 +44,9 @@ class TeleChat2Model(LlamaModel):
+@@ -22,10 +22,12 @@
+from typing import Iterable, Set, Tuple
+
+import torch
++import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel
++from .llama import LlamaDecoderLayer
+
+from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
+is_pp_missing_parameter)
+@@ -44,9 +46,9 @@ class TeleChat2Model(LlamaModel):
for layer in self.layers:
if not isinstance(layer, PPMissingLayer):
layer.self_attn.qkv_proj.bias = None
@@ -14254,6 +14421,18 @@ index a38035e37..9631fbd83 100644

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
+@@ -120,7 +122,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
+},
+)
+
+- def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
++ def _init_model(self,
++ vllm_config: VllmConfig,
++ prefix: str = "",
++ layer_type: type[nn.Module] = LlamaDecoderLayer):
+return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
+
+def load_weights(self, weights: Iterable[Tuple[str,
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index fc0fb8929..6454e7006 100644
--- a/vllm/multimodal/utils.py
@@ -14319,7 +14498,7 @@ index b6f6029de..b90fea9fd 100644
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
-index 225e756cd..4fd7fe220 100644
+index 225e756cd..25b83549a 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional
@@ -14330,7 +14509,17 @@ index 225e756cd..4fd7fe220 100644
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
-@@ -33,8 +34,13 @@ class XPUPlatform(Platform):
+@@ -25,6 +26,9 @@ class XPUPlatform(Platform):
+# see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
+ray_device_key: str = "GPU"
+device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"
++ additional_env_vars: list[str] = [
++ "IPEX_LLM_LOWBIT",
++ ]
+
+@classmethod
+def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+@@ -33,8 +37,13 @@ class XPUPlatform(Platform):
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
@@ -14346,7 +14535,7 @@ index 225e756cd..4fd7fe220 100644

@staticmethod
def get_device_capability(
-@@ -63,6 +69,8 @@ class XPUPlatform(Platform):
+@@ -63,6 +72,8 @@ class XPUPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
@@ -14355,7 +14544,7 @@ index 225e756cd..4fd7fe220 100644
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16

-@@ -87,31 +95,46 @@ class XPUPlatform(Platform):
+@@ -87,31 +98,46 @@ class XPUPlatform(Platform):
raise NotImplementedError(
"XPU does not support speculative decoding")

@@ -14412,6 +14601,15 @@ index 225e756cd..4fd7fe220 100644

@classmethod
def is_pin_memory_available(cls):
+@@ -140,3 +166,7 @@ class XPUPlatform(Platform):
+@classmethod
+def get_device_communicator_cls(cls) -> str:
+return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa
++
++ @classmethod
++ def use_all_gather(cls) -> bool:
++ return False
+\ No newline at end of file
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 53699341b..6bc039068 100644
--- a/vllm/transformers_utils/configs/__init__.py
@@ -14432,6 +14630,35 @@ index 53699341b..6bc039068 100644
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",
+diff --git a/vllm/utils.py b/vllm/utils.py
+index 5f32f8cb6..2ee0c1906 100644
+--- a/vllm/utils.py
++++ b/vllm/utils.py
+@@ -128,6 +128,8 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
+"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
+}
+
++BMG_TARGET_IDS = ["0xe20b", "0xe210"]
++
+# Constants related to forcing the attention backend selection
+
+# String name of register which may be set in order to
+@@ -2564,3 +2566,14 @@ def sha256(input) -> int:
+input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
+return int.from_bytes(hashlib.sha256(input_bytes).digest(),
+byteorder="big")
++
++@cache
++def is_bmg_platform():
++ if not torch.xpu.is_available():
++ raise ValueError("Cannot detect the usage of XPU!")
++ device_index = torch.xpu.current_device()
++ device_name = torch.xpu.get_device_name(device_index)
++ for target_id in BMG_TARGET_IDS:
++ if target_id in device_name:
++ return True
++ return False
+\ No newline at end of file
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index c271f438e..cf7180606 100755
--- a/vllm/v1/attention/backends/flash_attn.py
@ -14457,10 +14684,10 @@ index c271f438e..cf7180606 100755
|
||||||
assert sliding_window == (-1, -1), (
|
assert sliding_window == (-1, -1), (
|
||||||
diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py
|
diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py
|
||||||
new file mode 100644
|
new file mode 100644
|
||||||
index 000000000..29cde02f3
|
index 000000000..f4a435eaa
|
||||||
--- /dev/null
|
--- /dev/null
|
||||||
+++ b/vllm/v1/attention/backends/ipex_attn.py
|
+++ b/vllm/v1/attention/backends/ipex_attn.py
|
||||||
@@ -0,0 +1,358 @@
|
@@ -0,0 +1,392 @@
|
||||||
+from dataclasses import dataclass
|
+from dataclasses import dataclass
|
||||||
+from typing import Any, Dict, List, Optional, Tuple, Type
|
+from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
+
|
+
|
||||||
|
|
@@ -14474,6 +14701,7 @@ index 000000000..29cde02f3
+from vllm.attention.ops.paged_attn import (PagedAttention,
+ PagedAttentionMetadata)
+from vllm.attention.backends.ipex_attn import use_gqa_kernel
++from vllm.utils import is_bmg_platform
+import os
+
+@dataclass
@@ -14509,9 +14737,9 @@ index 000000000..29cde02f3
+ # if block_size % 16 != 0:
+ # raise ValueError("Block size must be a multiple of 16.")
+ # This needs to be changed...
-+ # return (2, num_blocks, block_size, num_kv_heads, head_size)
-+ return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-+ num_kv_heads, head_size)
++ return (2, num_blocks, block_size, num_kv_heads, head_size)
++ # return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
++ # num_kv_heads, head_size)
+
+
+
@@ -14557,6 +14785,8 @@ index 000000000..29cde02f3
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+ support_head_sizes = IPEXAttentionBackend.get_supported_head_sizes()
++ self.using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap)
++ self.is_bmg_platform = is_bmg_platform()
+ if head_size not in support_head_sizes:
+ raise ValueError(
+ f"Head size {head_size} is not supported by FlashAttention. "
@@ -14567,7 +14797,6 @@ index 000000000..29cde02f3
+ "are not implemented for "
+ "IpexAttnBackendImpl")
+
-+ # TODO(gc): Refine this logic..., because of bad performance...
+ def forward(
+ self,
+ layer: AttentionLayer,
@@ -14610,6 +14839,8 @@ index 000000000..29cde02f3
+ k_scale,
+ v_scale,
+ self.scale,
++ self.using_gqa_kernel,
++ self.is_bmg_platform,
+ self.sliding_window,
+ self.alibi_slopes,
+ self.logits_soft_cap,
@@ -14682,6 +14913,8 @@ index 000000000..29cde02f3
+ k_scale: float,
+ v_scale: float,
+ scale: float,
++ using_gqa_kernel: bool,
++ is_bmg_platform: bool,
+ sliding_window: Optional[List[int]] = None,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ logits_soft_cap: Optional[float] = None,
|
@ -14700,9 +14933,38 @@ index 000000000..29cde02f3
|
||||||
+ key = key.view(-1, num_kv_heads, head_size)
|
+ key = key.view(-1, num_kv_heads, head_size)
|
||||||
+ value = value.view(-1, num_kv_heads, head_size)
|
+ value = value.view(-1, num_kv_heads, head_size)
|
||||||
+
|
+
|
||||||
+ using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap)
|
+ if is_bmg_platform:
|
||||||
+
|
+ key_cache, value_cache = kv_cache.unbind(0)
|
||||||
+
|
+ ipex_ops.reshape_and_cache_flash(
|
||||||
|
+ key[:num_actual_tokens],
|
||||||
|
+ value[:num_actual_tokens],
|
||||||
|
+ key_cache,
|
||||||
|
+ value_cache,
|
||||||
|
+ attn_metadata.slot_mapping,
|
||||||
|
+ kv_cache_dtype,
|
||||||
|
+ k_scale,
|
||||||
|
+ v_scale,
|
||||||
|
+ )
|
||||||
|
+ ipex_ops.chunked_prefill(
|
||||||
|
+ query[:num_actual_tokens].contiguous(),
|
||||||
|
+ key_cache,
|
||||||
|
+ value_cache,
|
||||||
|
+ output[:num_actual_tokens],
|
||||||
|
+ attn_metadata.query_start_loc,
|
||||||
|
+ attn_metadata.seq_start_loc,
|
||||||
|
+ None,
|
||||||
|
+ attn_metadata.block_table,
|
||||||
|
+ alibi_slopes,
|
||||||
|
+ attn_metadata.max_query_len,
|
||||||
|
+ attn_metadata.max_seq_len,
|
||||||
|
+ 0.0,
|
||||||
|
+ scale,
|
||||||
|
+ False,
|
||||||
|
+ True,
|
||||||
|
+ False,
|
||||||
|
+ None,
|
||||||
|
+ )
|
||||||
|
+ else:
|
||||||
+ if using_gqa_kernel:
|
+ if using_gqa_kernel:
|
||||||
+ key_cache, value_cache = split_kv_cache_ipexllm(
|
+ key_cache, value_cache = split_kv_cache_ipexllm(
|
||||||
+ kv_cache, num_kv_heads, head_size)
|
+ kv_cache, num_kv_heads, head_size)
|
||||||
|
|
@@ -14750,7 +15012,6 @@ index 000000000..29cde02f3
+
+
+
-+
+@torch.library.custom_op("vllm::ipex_attn_chunked_prefill",
+ mutates_args=["output", "kv_cache"])
+def ipex_attn_chunked_prefill(
@@ -15648,10 +15909,10 @@ index 000000000..8612d3d77
+ self.kv_caches)
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
new file mode 100644
-index 000000000..1bc531e28
+index 000000000..1fb0dca87
--- /dev/null
+++ b/vllm/v1/worker/xpu_worker.py
-@@ -0,0 +1,168 @@
+@@ -0,0 +1,175 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Optional
@@ -15685,8 +15946,15 @@ index 000000000..1bc531e28
+ assert device_config.device_type == "xpu"
+ assert current_platform.is_xpu()
+
-+ def load_model(self) -> None:
-+ self.model_runner.load_model()
++ import os
++ lowbit = os.getenv("IPEX_LLM_LOWBIT", None)
++ if lowbit is not None:
++ from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
++ _ipex_llm_convert(lowbit)
++
++
++ def compile_or_warm_up_model(self) -> None:
++ pass
+
+ # we provide this function due to `torch.xpu.mem_get_info()` doesn't
+ # return correct free_gpu_memory on intel client GPU. We need to
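The xpu_worker.py hunk above wires the IPEX_LLM_LOWBIT environment variable into model loading. A hedged sketch of that flow, using only the names that appear in the diff (the surrounding worker class and its model runner are omitted):

import os


def maybe_convert_to_lowbit():
    # If IPEX_LLM_LOWBIT is set, e.g. exported before launching vLLM, rewrite the
    # model with ipex-llm's low-bit kernels before weights are loaded; otherwise
    # leave the model untouched.
    lowbit = os.getenv("IPEX_LLM_LOWBIT", None)
    if lowbit is not None:
        from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
        _ipex_llm_convert(lowbit)

This pairs with the earlier vllm/platforms/xpu.py hunk, which adds IPEX_LLM_LOWBIT to XPUPlatform.additional_env_vars, presumably so that the setting also reaches spawned worker processes.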
|
@@ -15838,7 +16106,7 @@ index 86e6d9752..ad80bf54e 100644

@dataclass(frozen=True)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
-index 9d49b4385..67f07f5b1 100644
+index 9d49b4385..7396b0c89 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -5,8 +5,8 @@ import time
@@ -16163,7 +16431,7 @@ index 9d49b4385..67f07f5b1 100644
+ slot_mapping_tensor = torch.tensor(slot_mapping,
+ dtype=torch.long,
+ device=self.device)
-+ if need_block_table:
++ if need_block_table or "bge" in self.runner.model_config.model.lower():
+ seq_lens_tensor = torch.tensor(seq_lens,
+ dtype=torch.int,
+ device=self.device)