update vllm patch (#13185)
Co-authored-by: gc-fu <guancheng.fu@intel.com>
This commit is contained in:
parent
531bef2810
commit
c5d919b151
1 changed files with 29 additions and 6 deletions
|
|
@ -12307,10 +12307,10 @@ index 1b1738f88..2c2ed67b9 100644
|
|||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
|
||||
new file mode 100644
|
||||
index 000000000..8e6e2c11a
|
||||
index 000000000..332af5a68
|
||||
--- /dev/null
|
||||
+++ b/vllm/model_executor/models/glm4.py
|
||||
@@ -0,0 +1,318 @@
|
||||
@@ -0,0 +1,329 @@
|
||||
+# SPDX-License-Identifier: Apache-2.0
|
||||
+
|
||||
+# Copyright 2025 The Zhipu AI team.
|
||||
|
|
@ -12339,6 +12339,7 @@ index 000000000..8e6e2c11a
|
|||
+import torch
|
||||
+from torch import nn
|
||||
+from transformers import Glm4Config
|
||||
+import os
|
||||
+
|
||||
+from vllm.attention import Attention, AttentionType
|
||||
+from vllm.compilation.decorators import support_torch_compile
|
||||
|
|
@ -12427,7 +12428,6 @@ index 000000000..8e6e2c11a
|
|||
+ rope_scaling=rope_scaling,
|
||||
+ partial_rotary_factor=partial_rotary_factor,
|
||||
+ is_neox_style=False,
|
||||
+ dtype=torch.float32,
|
||||
+ )
|
||||
+ self.attn = Attention(self.num_heads,
|
||||
+ self.head_dim,
|
||||
|
|
@ -12461,6 +12461,7 @@ index 000000000..8e6e2c11a
|
|||
+ prefix: str = "",
|
||||
+ ) -> None:
|
||||
+ super().__init__()
|
||||
+ self.config = config
|
||||
+ self.hidden_size = config.hidden_size
|
||||
+ rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
+ rope_scaling = getattr(config, "rope_scaling", None)
|
||||
|
|
@ -12519,8 +12520,9 @@ index 000000000..8e6e2c11a
|
|||
+ # Fully Connected
|
||||
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
+ hidden_states = self.mlp(hidden_states)
|
||||
+ hidden_states = hidden_states.to(torch.float32)
|
||||
+ hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36)
|
||||
+ if self.config.num_hidden_layers == 61: # GLM-4-32B-0414
|
||||
+ hidden_states = hidden_states.to(torch.float32)
|
||||
+ hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36)
|
||||
+ hidden_states = self.post_mlp_layernorm(hidden_states)
|
||||
+
|
||||
+ return hidden_states, residual
|
||||
|
|
@ -12562,6 +12564,15 @@ index 000000000..8e6e2c11a
|
|||
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
+ super().__init__()
|
||||
+ config = vllm_config.model_config.hf_config
|
||||
+ if config.num_hidden_layers == 61: # GLM-4-32B-0414
|
||||
+ if vllm_config.model_config.dtype != torch.float32:
|
||||
+ vllm_config.model_config.dtype = torch.float32
|
||||
+ logger.warning_once("vLLM xpu: GLM-4-32B-0414 model is only supported with FP32. Converting dtype to torch.float32.")
|
||||
+ sdp_disabled_flag = os.getenv("IPEX_LLM_DISABLE_SDP_CAUSAL", None)
|
||||
+ if sdp_disabled_flag != "1":
|
||||
+ logger.warning_once("vLLM xpu: GLM-4-32B-0414 model is currently not supported with sdp_causal(). Setting IPEX_LLM_DISABLE_SDP_CAUSAL=1.")
|
||||
+ os.environ["IPEX_LLM_DISABLE_SDP_CAUSAL"] = "1"
|
||||
+
|
||||
+ quant_config = vllm_config.quant_config
|
||||
+ lora_config = vllm_config.lora_config
|
||||
+
|
||||
|
|
@ -14681,7 +14692,7 @@ index 000000000..bda3de2d3
|
|||
+ loaded_params.add(name)
|
||||
+ return loaded_params
|
||||
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
|
||||
index c0a3c59ba..8614c2273 100644
|
||||
index c0a3c59ba..b8f54e82e 100644
|
||||
--- a/vllm/model_executor/models/registry.py
|
||||
+++ b/vllm/model_executor/models/registry.py
|
||||
@@ -57,6 +57,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
|
|
@ -14709,6 +14720,18 @@ index c0a3c59ba..8614c2273 100644
|
|||
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
|
||||
# [Encoder-decoder]
|
||||
@@ -307,8 +311,9 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
|
||||
# Performed in another process to avoid initializing CUDA
|
||||
def inspect_model_cls(self) -> _ModelInfo:
|
||||
- return _run_in_subprocess(
|
||||
- lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
|
||||
+ # return _run_in_subprocess(
|
||||
+ # lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
|
||||
+ return _ModelInfo.from_model_cls(self.load_model_cls())
|
||||
|
||||
def load_model_cls(self) -> Type[nn.Module]:
|
||||
mod = importlib.import_module(self.module_name)
|
||||
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
|
||||
index a09741a55..d989a12fa 100644
|
||||
--- a/vllm/model_executor/models/roberta.py
|
||||
|
|
|
|||
Loading…
Reference in a new issue