#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/config.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from transformers import AutoConfig, PretrainedConfig

from ipex_llm.vllm.logger import init_logger
from ipex_llm.utils.common import invalidInputError

logger = init_logger(__name__)


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16
            precision for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        device: The device to be used for the model. Defaults to 'cpu'.
        load_in_low_bit: The low-bit quantization format used when loading
            the model. Defaults to 'sym_int4'.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: str,
        seed: int,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        device: Optional[str] = 'cpu',
        load_in_low_bit: str = 'sym_int4',
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.device = device
        self.load_in_low_bit = load_in_low_bit

        # Load the HF config first: dtype and max length are derived from it.
        self.hf_config = get_config(model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()

    def _verify_load_format(self) -> None:
        """Validate ``load_format`` and normalize it to lower case."""
        load_format = self.load_format.lower()
        if load_format not in [
                "auto", "pt", "safetensors", "npcache", "dummy"
        ]:
            invalidInputError(
                False,
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        """Validate ``tokenizer_mode`` and normalize it to lower case."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            invalidInputError(
                False,
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        """Validate ``quantization`` (if set) and normalize it to lower case."""
        supported_quantization = ["awq"]
        if self.quantization is None:
            return
        quantization = self.quantization.lower()
        if quantization not in supported_quantization:
            invalidInputError(
                False,
                f"Unknown quantization: {self.quantization}. Must be one of "
                f"{supported_quantization}.")
        self.quantization = quantization

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the model can be split evenly by ``parallel_config``.

        Raises (via ``invalidInputError``) if the attention-head count is not
        divisible by the tensor-parallel size, or the hidden-layer count is
        not divisible by the pipeline-parallel size.
        """
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            invalidInputError(
                False,
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            invalidInputError(
                False,
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_hidden_size(self) -> int:
        """Return ``hidden_size`` from the HF config."""
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        """Return the per-head dimension (hidden size / attention heads)."""
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU worker."""
        # For GPTBigCode & Falcon:
        # Note: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1
        # For Falcon:
        if getattr(self.hf_config, "n_head_kv", None) is not None:
            return (self.hf_config.n_head_kv //
                    parallel_config.tensor_parallel_size)
        if getattr(self.hf_config, "num_kv_heads", None) is not None:
            return (self.hf_config.num_kv_heads //
                    parallel_config.tensor_parallel_size)
        # For LLaMA-2:
        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
            return (self.hf_config.num_key_value_heads //
                    parallel_config.tensor_parallel_size)
        # Plain multi-head attention: one KV head per attention head.
        # NOTE(review): floor division yields 0 when the KV-head count is
        # smaller than tensor_parallel_size; confirm callers never hit that.
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline-parallel stage."""
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


# Accepted user-facing dtype strings (lower case) and their torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
) -> torch.dtype:
    """Resolve the user-requested ``dtype`` string to a ``torch.dtype``.

    "auto" picks float16 for float32 (or dtype-less) checkpoints and the
    checkpoint's own dtype otherwise. Any other value must be a key of
    ``_STR_DTYPE_TO_TORCH_DTYPE``. Casting between float16 and bfloat16 is
    allowed but logged as a warning.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "auto":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            invalidInputError(False, f"Unknown dtype: {dtype}")
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    return torch_dtype


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length.

    The derived limit is the smallest value among the length-related keys
    present in ``hf_config``, multiplied by ``rope_scaling["factor"]`` when
    RoPE scaling is configured. If no such key exists, a user-supplied
    ``max_model_len`` is returned unchanged, otherwise 2048 is assumed.
    A user-supplied value larger than the derived limit is rejected.
    """
    derived_max_model_len = float("inf")
    # Name of the config key that produced ``derived_max_model_len``; used
    # so the error message below names the field that imposed the limit.
    max_len_key = None
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None and max_len < derived_max_model_len:
            derived_max_model_len = max_len
            max_len_key = key
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        invalidInputError("factor" in rope_scaling,
                          "invalid hf_config value for rope_scaling")
        scaling_factor = rope_scaling["factor"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        invalidInputError(
            False,
            f"User-specified max_model_len ({max_model_len}) is greater than "
            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
            " in model's config.json). This may lead to incorrect model "
            "outputs or CUDA errors. Make sure the value is correct and "
            "within the model context size.")
    return int(max_model_len)


def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None) -> PretrainedConfig:
    """Load the HuggingFace config for ``model`` via ``AutoConfig``.

    Raises (via ``invalidInputError``) when the config cannot be loaded,
    with a hint about ``trust_remote_code`` when that is the likely cause.
    """
    # NOTE: Upstream vLLM special-cased model ids containing "mistral" and
    # loaded them with `MistralConfig`, a name that was never imported here.
    # That branch therefore raised NameError for every such id (including
    # "mistralai/Mixtral-..." checkpoints). transformers >= 4.34 ships
    # MistralConfig and registers it with AutoConfig, so the generic path
    # below handles Mistral models.
    try:
        config = AutoConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code, revision=revision)
    except ValueError as e:
        if (not trust_remote_code and
                "requires you to execute the configuration file" in str(e)):
            err_msg = (
                "Failed to load the model config. If the model is a custom "
                "model not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            invalidInputError(False, err_msg)
        else:
            invalidInputError(False, str(e))
    return config


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        # Any multi-worker setup is forced onto Ray.
        if self.world_size > 1:
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Reject parallel settings that are not implemented."""
        if self.pipeline_parallel_size > 1:
            invalidInputError(
                False, "Pipeline parallelism is not supported yet.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            # If max_model_len is too short, use 2048 as the default value for
            # higher throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self._verify_args()

    def _verify_args(self) -> None:
        """Check the token budget can hold a full sequence and every seq."""
        if self.max_num_batched_tokens < self.max_model_len:
            invalidInputError(
                False,
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if self.max_num_batched_tokens < self.max_num_seqs:
            invalidInputError(
                False,
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")