update vllm patch (#13185)

Co-authored-by: gc-fu <guancheng.fu@intel.com>
Author: Shaojun Liu
Date:   2025-05-23 15:02:50 +08:00 (committed by GitHub)
parent 531bef2810
commit c5d919b151


@@ -12307,10 +12307,10 @@ index 1b1738f88..2c2ed67b9 100644
 +        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
 diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
 new file mode 100644
-index 000000000..8e6e2c11a
+index 000000000..332af5a68
 --- /dev/null
 +++ b/vllm/model_executor/models/glm4.py
-@@ -0,0 +1,318 @@
+@@ -0,0 +1,329 @@
 +# SPDX-License-Identifier: Apache-2.0
 +
 +# Copyright 2025 The Zhipu AI team.
@@ -12339,6 +12339,7 @@ index 000000000..8e6e2c11a
 +import torch
 +from torch import nn
 +from transformers import Glm4Config
++import os
 +
 +from vllm.attention import Attention, AttentionType
 +from vllm.compilation.decorators import support_torch_compile
@@ -12427,7 +12428,6 @@ index 000000000..8e6e2c11a
 +            rope_scaling=rope_scaling,
 +            partial_rotary_factor=partial_rotary_factor,
 +            is_neox_style=False,
-+            dtype=torch.float32,
 +        )
 +        self.attn = Attention(self.num_heads,
 +                              self.head_dim,
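Note: the hunk above drops the forced dtype=torch.float32 argument from the get_rope(...) call, so the rotary cos/sin cache now follows vLLM's default dtype resolution (the model dtype) instead of always being built in FP32. A minimal sketch of the resulting call; the head/position sizes are illustrative assumptions, not values from the patch:

    # Sketch only: with no explicit dtype, get_rope builds its cache in the
    # current default dtype, matching the patched GLM-4 attention above.
    from vllm.model_executor.layers.rotary_embedding import get_rope

    rope = get_rope(
        head_size=128,               # assumed example value
        rotary_dim=128,
        max_position=32768,          # assumed example value
        base=10000,
        is_neox_style=False,         # as in the GLM-4 attention code
        partial_rotary_factor=0.5,   # GLM-4 uses partial rotary; value assumed
    )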
@@ -12461,6 +12461,7 @@ index 000000000..8e6e2c11a
 +        prefix: str = "",
 +    ) -> None:
 +        super().__init__()
++        self.config = config
 +        self.hidden_size = config.hidden_size
 +        rope_theta = getattr(config, "rope_theta", 1000000)
 +        rope_scaling = getattr(config, "rope_scaling", None)
@@ -12519,8 +12520,9 @@ index 000000000..8e6e2c11a
 +        # Fully Connected
 +        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 +        hidden_states = self.mlp(hidden_states)
-+        hidden_states = hidden_states.to(torch.float32)
-+        hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36)
++        if self.config.num_hidden_layers == 61:  # GLM-4-32B-0414
++            hidden_states = hidden_states.to(torch.float32)
++            hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36)
 +        hidden_states = self.post_mlp_layernorm(hidden_states)
 +
 +        return hidden_states, residual
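Note: this hunk gates the FP32 upcast and clamp on the 61-layer config, so it now runs only for GLM-4-32B-0414 rather than for every GLM-4 variant, presumably because only that model's MLP output overflows in half precision. A standalone sketch of the gated stabilization; the function name and constant are illustrative, not from the patch:

    import torch

    GLM4_32B_0414_LAYERS = 61  # layer count the patch uses to detect the model

    def stabilize_mlp_output(hidden_states: torch.Tensor,
                             num_hidden_layers: int) -> torch.Tensor:
        # Upcast to FP32 and clamp to a finite range before the post-MLP
        # layernorm, mirroring the patched decoder-layer forward pass.
        if num_hidden_layers == GLM4_32B_0414_LAYERS:
            hidden_states = hidden_states.to(torch.float32)
            hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36)
        return hidden_states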
@@ -12562,6 +12564,15 @@ index 000000000..8e6e2c11a
 +    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 +        super().__init__()
 +        config = vllm_config.model_config.hf_config
++        if config.num_hidden_layers == 61:  # GLM-4-32B-0414
++            if vllm_config.model_config.dtype != torch.float32:
++                vllm_config.model_config.dtype = torch.float32
++                logger.warning_once("vLLM xpu: GLM-4-32B-0414 model is only supported with FP32. Converting dtype to torch.float32.")
++            sdp_disabled_flag = os.getenv("IPEX_LLM_DISABLE_SDP_CAUSAL", None)
++            if sdp_disabled_flag != "1":
++                logger.warning_once("vLLM xpu: GLM-4-32B-0414 model is currently not supported with sdp_causal(). Setting IPEX_LLM_DISABLE_SDP_CAUSAL=1.")
++                os.environ["IPEX_LLM_DISABLE_SDP_CAUSAL"] = "1"
++
 +        quant_config = vllm_config.quant_config
 +        lora_config = vllm_config.lora_config
 +
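Note: the added __init__ logic force-converts the model dtype to FP32 and sets IPEX_LLM_DISABLE_SDP_CAUSAL=1 when it detects GLM-4-32B-0414, warning in each case. A usage sketch (assumed workflow, not part of the patch) that applies the same settings up front so neither warning fires:

    import os
    # Flag name taken from the patch; export it before the engine is built.
    os.environ["IPEX_LLM_DISABLE_SDP_CAUSAL"] = "1"

    from vllm import LLM

    llm = LLM(
        model="THUDM/GLM-4-32B-0414",  # model id assumed for illustration
        dtype="float32",               # matches the FP32 requirement above
    )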
@@ -14681,7 +14692,7 @@ index 000000000..bda3de2d3
 +            loaded_params.add(name)
 +        return loaded_params
 diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
-index c0a3c59ba..8614c2273 100644
+index c0a3c59ba..b8f54e82e 100644
 --- a/vllm/model_executor/models/registry.py
 +++ b/vllm/model_executor/models/registry.py
 @@ -57,6 +57,7 @@ _TEXT_GENERATION_MODELS = {
@@ -14709,6 +14720,18 @@ index c0a3c59ba..8614c2273 100644
      "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
      "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
      # [Encoder-decoder]
+@@ -307,8 +311,9 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
+ 
+     # Performed in another process to avoid initializing CUDA
+     def inspect_model_cls(self) -> _ModelInfo:
+-        return _run_in_subprocess(
+-            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
++        # return _run_in_subprocess(
++        #     lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
++        return _ModelInfo.from_model_cls(self.load_model_cls())
+ 
+     def load_model_cls(self) -> Type[nn.Module]:
+         mod = importlib.import_module(self.module_name)
 diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
 index a09741a55..d989a12fa 100644
 --- a/vllm/model_executor/models/roberta.py
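Note on the registry.py hunk above: upstream vLLM runs _ModelInfo.from_model_cls in a subprocess so that importing the model class cannot initialize CUDA in the parent process; the patch comments that out and inspects the class in-process instead. A generic sketch of the subprocess-isolation pattern being bypassed (a simplified stand-in, not vLLM's actual helper; assumes a fork start method so the callable need not be picklable):

    import multiprocessing as mp

    def run_in_child(fn):
        # Evaluate fn() in a child process and ship the result back over a
        # pipe, keeping any import side effects out of the parent.
        ctx = mp.get_context("fork")
        parent_conn, child_conn = ctx.Pipe()

        def _target(conn):
            conn.send(fn())
            conn.close()

        proc = ctx.Process(target=_target, args=(child_conn,))
        proc.start()
        result = parent_conn.recv()
        proc.join()
        return result

    if __name__ == "__main__":
        print(run_in_child(lambda: 21 * 2))  # -> 42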