parent 3ef4aa98d1
commit ac3d53ff5d

5 changed files with 94 additions and 17 deletions
@@ -11,7 +11,7 @@ python -m ipex_llm.vllm.cpu.entrypoints.openai.api_server \
   --device cpu \
   --dtype bfloat16 \
   --enforce-eager \
-  --load-in-low-bit sym_int4 \
+  --load-in-low-bit bf16 \
   --max-model-len 4096 \
   --max-num-batched-tokens 10240 \
   --max-num-seqs 12 \
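This hunk switches the documented CPU serving example from sym_int4 weight quantization to bf16. For orientation, the full launch command in the shape the quickstart uses would look roughly as follows; the model path and port are placeholders and are not taken from this diff:

python -m ipex_llm.vllm.cpu.entrypoints.openai.api_server \
  --model /path/to/your/model \
  --port 8000 \
  --device cpu \
  --dtype bfloat16 \
  --enforce-eager \
  --load-in-low-bit bf16 \
  --max-model-len 4096 \
  --max-num-batched-tokens 10240 \
  --max-num-seqs 12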
@@ -49,7 +49,7 @@ llm = LLM(model="YOUR_MODEL",
           device="cpu",
           dtype="bfloat16",
           enforce_eager=True,
-          load_in_low_bit="sym_int4",
+          load_in_low_bit="bf16",
           tensor_parallel_size=1)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
@@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
 # llm = LLM(model="facebook/opt-125m")
-llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="sym_int4")
+llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="bf16")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
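Together with the previous two hunks, this moves the offline-inference example to bf16 as well. Read end to end, the example would run roughly like the sketch below; the prompts and the import of the patched LLM class are assumptions added for illustration (the hunk only shows the LLM(...) line and its surrounding comments):

from vllm import SamplingParams
# assumption: ipex-llm's patched LLM entry point; the import is not shown in this hunk
from ipex_llm.vllm.cpu.engine import IPEXLLMClass as LLM

prompts = ["Hello, my name is", "The capital of France is"]  # illustrative prompts
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="YOUR_MODEL_PATH", device="cpu", load_in_low_bit="bf16")
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    # each RequestOutput carries the prompt and the generated completions
    print(output.prompt, output.outputs[0].text)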
@@ -54,6 +54,7 @@ import sys
 
 _IS_VLLM_AVAILABLE = None
 _USE_VLLM = False
+_VLLM_VERSION = None
 
 
 def is_auto_gptq_available():
@@ -77,6 +78,14 @@ def is_vllm_available():
     return _IS_VLLM_AVAILABLE
 
 
+def get_package_version(package_name):
+    result = subprocess.run(['pip', 'list'], capture_output=True, text=True)
+    for line in result.stdout.splitlines():
+        if line.startswith(package_name):
+            return line.split()[1]
+    return None
+
+
 def get_use_vllm():
     return _USE_VLLM
 
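The new get_package_version helper shells out to pip and returns the version column for the first matching package, or None when nothing matches; the next hunk uses that string to tell a CPU vLLM build from an XPU one. Purely as an illustration of the design trade-off (not part of this commit), the same kind of string can be obtained from the standard library without spawning a subprocess; the version string in the comment is hypothetical:

from importlib.metadata import PackageNotFoundError, version

def get_package_version(package_name):
    try:
        return version(package_name)  # e.g. "0.4.2+cpu" (illustrative value)
    except PackageNotFoundError:
        return None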
@@ -133,13 +142,24 @@ def is_linear_module(module):
     is_awq = is_auto_awq_available() and isinstance(module, WQLinear_GEMM)
     if is_vllm_available():
         # Only convert vllm modules
+        global _VLLM_VERSION
+        if _VLLM_VERSION is None:
+            _VLLM_VERSION = get_package_version('vllm')
+        if 'xpu' in _VLLM_VERSION:
+            # For vllm xpu
+            from vllm.model_executor.parallel_utils.parallel_state import (
+                get_tensor_model_parallel_group,
+                get_tensor_model_parallel_world_size
+            )
+            tp_size = get_tensor_model_parallel_world_size()
+        else:
+            # For vllm cpu
+            tp_size = 1
+
         from vllm.model_executor.layers.linear import (
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         )
-        from vllm.model_executor.parallel_utils.parallel_state import (
-            get_tensor_model_parallel_group,
-            get_tensor_model_parallel_world_size
-        )
         VLLM_LINEAR_LIST = [
             ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
         ]
@@ -148,7 +168,6 @@ def is_linear_module(module):
             out_features = module.output_size
             result = True
             mp_group = None
-            tp_size = get_tensor_model_parallel_world_size()
             if isinstance(module, RowParallelLinear) and tp_size >= 2:
                 mp_group = get_tensor_model_parallel_group()
                 in_features = module.input_size_per_partition
@@ -24,8 +24,10 @@ from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
 from vllm.attention import Attention, AttentionMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.config import DeviceConfig
-from typing import Tuple
+
+from vllm._C import ops
 from ipex_llm.utils.common import invalidInputError
+from typing import List, Optional, Tuple, Union
 
 
 def _MLP_forward(self, x):
@@ -42,7 +44,7 @@ def _Attention_forward(
     kv_cache: torch.Tensor,
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.qkv_proj(hidden_states)
+    qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
     q, k = self.rotary_emb(positions, q, k)
     attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
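The only functional change in this hunk is the cast of the fused qkv projection to the KV-cache dtype before splitting, which presumably guards against a dtype mismatch when the CPU KV cache is allocated in bfloat16 while the projection output comes back in another dtype. A minimal, self-contained illustration of the mismatch being avoided (shapes and dtypes are illustrative only):

import torch

kv_cache = torch.zeros(4, 8, dtype=torch.bfloat16)  # cache allocated in bf16
qkv = torch.randn(4, 8, dtype=torch.float32)         # projection output in fp32
qkv = qkv.to(dtype=kv_cache.dtype)                    # same cast as in the hunk
assert qkv.dtype == kv_cache.dtype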
@@ -145,9 +147,61 @@ def _model_attention_convert():
 
 
 def _ipex_llm_convert(load_in_low_bit):
-    from vllm.worker.model_runner import ModelRunner
+    from vllm.worker.cpu_model_runner import CPUModelRunner
     import vllm.model_executor.model_loader as model_loader
-    setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))
+    setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
+
+    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+    setattr(RotaryEmbedding, "forward", _ipex_llm_rotary_embedding_forward)
+    from vllm.model_executor.layers.layernorm import RMSNorm
+    setattr(RMSNorm, "forward", _ipex_llm_rmsnorm_forward)
+
+
+def _ipex_llm_rotary_embedding_forward(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
+
+    # ops.rotary_embedding()/batched_rotary_embedding()
+    # are in-place operations that update the query and key tensors.
+    if offsets is not None:
+        ops.batched_rotary_embedding(positions, query, key, self.head_size,
+                                     self.cos_sin_cache,
+                                     self.is_neox_style, self.rotary_dim,
+                                     offsets)
+    else:
+        ops.rotary_embedding(positions, query, key, self.head_size,
+                             self.cos_sin_cache, self.is_neox_style)
+    return query, key
+
+
+def _ipex_llm_rmsnorm_forward(
+    self,
+    x: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    x = x.to(dtype=self.weight.data.dtype)
+    if residual is not None:
+        residual = residual.to(dtype=self.weight.data.dtype)
+        ops.fused_add_rms_norm(
+            x,
+            residual,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+        return x, residual
+    out = torch.empty_like(x)
+    ops.rms_norm(
+        out,
+        x,
+        self.weight.data,
+        self.variance_epsilon,
+    )
+    return out
+
+
 def get_load_function(low_bit):
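Besides redirecting the load_model patch from ModelRunner to the CPU-specific CPUModelRunner, this hunk also monkey-patches RotaryEmbedding.forward and RMSNorm.forward with dtype-aware versions that call the vllm._C kernels directly. Note the calling convention of the rotary kernels: they mutate query and key in place, so the patched forward simply hands the same tensors back. A toy, vLLM-free illustration of that pattern (the helper name is made up for this sketch):

import torch

def rotate_inplace(q: torch.Tensor, k: torch.Tensor) -> None:
    # stand-in for ops.rotary_embedding(): mutate the arguments, return nothing
    q.mul_(0.5)
    k.mul_(0.5)

q = torch.ones(2, 4)
k = torch.ones(2, 4)
rotate_inplace(q, k)
# the caller keeps using the same tensors, just as the patched forward returns query, key
print(q[0, 0].item(), k[0, 0].item())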
@@ -155,11 +209,15 @@ def get_load_function(low_bit):
         _model_mlp_convert()
         _model_attention_convert()
 
-        self.model = get_model(self.model_config,
-                               self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)
 
         from ipex_llm import optimize_model
         optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
 