diff --git a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
index 1888f041..ffcad63d 100644
--- a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
+++ b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
@@ -12307,10 +12307,10 @@ index 1b1738f88..2c2ed67b9 100644
          layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
 diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py
 new file mode 100644
-index 000000000..8e6e2c11a
+index 000000000..332af5a68
 --- /dev/null
 +++ b/vllm/model_executor/models/glm4.py
-@@ -0,0 +1,318 @@
+@@ -0,0 +1,329 @@
 +# SPDX-License-Identifier: Apache-2.0
 +
 +# Copyright 2025 The Zhipu AI team.
@@ -12339,6 +12339,7 @@ index 000000000..8e6e2c11a
 +import torch
 +from torch import nn
 +from transformers import Glm4Config
++import os
 +
 +from vllm.attention import Attention, AttentionType
 +from vllm.compilation.decorators import support_torch_compile
@@ -12427,7 +12428,6 @@ index 000000000..8e6e2c11a
 +            rope_scaling=rope_scaling,
 +            partial_rotary_factor=partial_rotary_factor,
 +            is_neox_style=False,
-+            dtype=torch.float32,
 +        )
 +        self.attn = Attention(self.num_heads,
 +                              self.head_dim,
@@ -12461,6 +12461,7 @@ index 000000000..8e6e2c11a
 +        prefix: str = "",
 +    ) -> None:
 +        super().__init__()
++        self.config = config
 +        self.hidden_size = config.hidden_size
 +        rope_theta = getattr(config, "rope_theta", 1000000)
 +        rope_scaling = getattr(config, "rope_scaling", None)
@@ -12519,8 +12520,9 @@ index 000000000..8e6e2c11a
 +        # Fully Connected
 +        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
 +        hidden_states = self.mlp(hidden_states)
-+        hidden_states = hidden_states.to(torch.float32)
-+        hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36)
++        if self.config.num_hidden_layers == 61: # GLM-4-32B-0414
++            hidden_states = hidden_states.to(torch.float32)
++            hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36)
 +        hidden_states = self.post_mlp_layernorm(hidden_states)
 +
 +        return hidden_states, residual
@@ -12562,6 +12564,15 @@ index 000000000..8e6e2c11a
 +    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 +        super().__init__()
 +        config = vllm_config.model_config.hf_config
++        if config.num_hidden_layers == 61: # GLM-4-32B-0414
++            if vllm_config.model_config.dtype != torch.float32:
++                vllm_config.model_config.dtype = torch.float32
++                logger.warning_once("vLLM xpu: GLM-4-32B-0414 model is only supported with FP32. Converting dtype to torch.float32.")
++            sdp_disabled_flag = os.getenv("IPEX_LLM_DISABLE_SDP_CAUSAL", None)
++            if sdp_disabled_flag != "1":
++                logger.warning_once("vLLM xpu: GLM-4-32B-0414 model is currently not supported with sdp_causal(). Setting IPEX_LLM_DISABLE_SDP_CAUSAL=1.")
++                os.environ["IPEX_LLM_DISABLE_SDP_CAUSAL"] = "1"
++
 +        quant_config = vllm_config.quant_config
 +        lora_config = vllm_config.lora_config
 +
@@ -14681,7 +14692,7 @@ index 000000000..bda3de2d3
 +        loaded_params.add(name)
 +        return loaded_params
 diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
-index c0a3c59ba..8614c2273 100644
+index c0a3c59ba..b8f54e82e 100644
 --- a/vllm/model_executor/models/registry.py
 +++ b/vllm/model_executor/models/registry.py
 @@ -57,6 +57,7 @@ _TEXT_GENERATION_MODELS = {
@@ -14709,6 +14720,18 @@ index c0a3c59ba..8614c2273 100644
      "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
      "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
      # [Encoder-decoder]
+@@ -307,8 +311,9 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
+ 
+     # Performed in another process to avoid initializing CUDA
+     def inspect_model_cls(self) -> _ModelInfo:
+-        return _run_in_subprocess(
+-            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
++        # return _run_in_subprocess(
++        #     lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
++        return _ModelInfo.from_model_cls(self.load_model_cls())
+ 
+     def load_model_cls(self) -> Type[nn.Module]:
+         mod = importlib.import_module(self.module_name)
 diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
 index a09741a55..d989a12fa 100644
 --- a/vllm/model_executor/models/roberta.py