Add vLLM bf16 support (#10278)
* Add argument load_in_low_bit
* Add docs
* Modify GPU doc

Co-authored-by: ivy-lv11 <lvzc@lamda.nju.edu.cn>
Parent: 13b0bc9075
Commit: 2d930bdca8

13 changed files with 83 additions and 37 deletions

@@ -21,8 +21,6 @@ pip3 install numpy
 pip3 install --pre --upgrade bigdl-llm[all]
 pip3 install psutil
 pip3 install sentencepiece  # Required for LLaMA tokenizer.
-pip3 install "torch==2.0.1"
-pip3 install "transformers>=4.33.1"  # Required for Code Llama.
 pip3 install fastapi
 pip3 install "uvicorn[standard]"
 pip3 install "pydantic<2"  # Required for OpenAI server.
@@ -44,6 +42,7 @@ To run offline inference using vLLM for a quick impression, use the following ex
 #!/bin/bash
 
 # Please first modify the MODEL_PATH in offline_inference.py
+# Modify load_in_low_bit to use different quantization dtype
 
 numactl -C 48-95 -m 1 python offline_inference.py
 
@@ -60,6 +59,7 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req
 numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.entrypoints.openai.api_server \
         --model /MODEL_PATH/Llama-2-7b-chat-hf-bigdl/ --port 8000  \
         --load-format 'auto' --device cpu --dtype bfloat16 \
+        --load-in-low-bit sym_int4 \
         --max-num-batched-tokens 4096
 ```
 

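Once the OpenAI-compatible server from the hunk above is running, it can be exercised from Python as well as from curl. A minimal client sketch, assuming the server is reachable on localhost:8000 and exposes vLLM's /v1/completions route; the model path is the doc's placeholder, not a real file:

```python
# Minimal client sketch (assumptions: the server started above is up on
# localhost:8000 and serves vLLM's OpenAI-compatible /v1/completions route).
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/",  # placeholder path from the doc
        "prompt": "San Francisco is a",
        "max_tokens": 32,
        "temperature": 0.8,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```
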
@@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
 # llm = LLM(model="facebook/opt-125m")
-llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16")
+llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="sym_int4")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)

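Since this commit is about bf16 support, the same offline example can request bf16 instead of sym_int4. A minimal sketch, assuming the example script's imports come from the BigDL vLLM entrypoints and that the outputs are read as in the upstream vLLM example; neither detail is shown in this diff:

```python
# Sketch of the CPU offline example with bf16 instead of sym_int4.
# Assumption: LLM and SamplingParams are imported from the BigDL vLLM
# entrypoints as in the repo's offline_inference.py; "YOUR_MODEL_PATH"
# is a placeholder.
from bigdl.llm.vllm.entrypoints.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# load_in_low_bit="bf16" loads the weights in torch.bfloat16 instead of sym_int4.
llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="bf16")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    # RequestOutput fields as in the upstream vLLM example (assumed here).
    print(output.prompt, output.outputs[0].text)
```
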
@@ -37,7 +37,6 @@ conda activate bigdl-vllm
 pip3 install psutil
 pip3 install sentencepiece  # Required for LLaMA tokenizer.
 pip3 install numpy
-pip3 install "transformers>=4.33.1"  # Required for Code Llama.
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade "bigdl-llm[xpu]" -f https://developer.intel.com/ipex-whl-stable-xpu
 pip3 install fastapi
@@ -62,6 +61,7 @@ To run offline inference using vLLM for a quick impression, use the following ex
 #!/bin/bash
 
 # Please first modify the MODEL_PATH in offline_inference.py
+# Modify load_in_low_bit to use different quantization dtype
 python offline_inference.py
 ```
 
@@ -76,6 +76,7 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req
 python -m bigdl.llm.vllm.entrypoints.openai.api_server \
         --model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000  \
         --load-format 'auto' --device xpu --dtype bfloat16 \
+        --load-in-low-bit sym_int4 \
         --max-num-batched-tokens 4096
 ```
 

@@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
 # llm = LLM(model="facebook/opt-125m")
-llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16", device="xpu")
+llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="sym_int4", dtype="bfloat16", device="xpu")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)

@@ -78,6 +78,7 @@ class ModelConfig:
             weights. If None, we assume the model weights are not quantized.
         device: The device to be used for the model. If None, we will default
             to use CPU as the device.
+        load_in_low_bit: The low-bit quantization for model to be loaded. Default int4.
     """
 
     def __init__(
@@ -95,6 +96,7 @@ class ModelConfig:
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
         device: Optional[str] = 'cpu',
+        load_in_low_bit: str = 'sym_int4',
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -107,6 +109,7 @@ class ModelConfig:
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.device = device
+        self.load_in_low_bit = load_in_low_bit
 
         self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)

@@ -71,6 +71,7 @@ class EngineArgs:
     # bigdl-llm change start
     # summary: add device option
     device: Optional[str] = 'cpu'
+    load_in_low_bit: str = 'sym_int4'
     # bigdl-llm change end
 
     def __post_init__(self):
@@ -212,6 +213,10 @@ class EngineArgs:
                             choices=['gpu', 'cpu', 'xpu', None],
                             default=None,
                             help='Device to execute LLM model')
+        parser.add_argument('--load-in-low-bit',
+                            type=str,
+                            default='sym_int4',
+                            help='low_bit_quantization')
 
         return parser
 
@@ -229,7 +234,7 @@ class EngineArgs:
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
                                    self.tokenizer_revision, self.max_model_len,
-                                   self.quantization, self.device)
+                                   self.quantization, self.device, self.load_in_low_bit)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                            self.max_num_seqs,
                                            model_config.max_model_len)

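To make the wiring concrete: the new EngineArgs field and the --load-in-low-bit flag simply carry a string from the CLI into the model config, which the model loader reads later. A simplified, self-contained illustration with mock classes (EngineArgsMock and ModelConfigMock are stand-ins for this sketch, not the real bigdl.llm.vllm classes):

```python
# Simplified illustration of how --load-in-low-bit flows from the CLI into the
# model config. These mock classes only mirror the shape of the change; they
# are not the real EngineArgs / ModelConfig.
import argparse
from dataclasses import dataclass


@dataclass
class ModelConfigMock:
    model: str
    device: str = 'cpu'
    load_in_low_bit: str = 'sym_int4'


@dataclass
class EngineArgsMock:
    model: str
    device: str = 'cpu'
    load_in_low_bit: str = 'sym_int4'

    def create_model_config(self) -> ModelConfigMock:
        # Mirrors create_engine_configs() forwarding the new field.
        return ModelConfigMock(self.model, self.device, self.load_in_low_bit)


parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--load-in-low-bit', type=str, default='sym_int4')
args = parser.parse_args(['--model', '/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/',
                          '--load-in-low-bit', 'bf16'])

engine_args = EngineArgsMock(args.model, args.device, args.load_in_low_bit)
print(engine_args.create_model_config())  # load_in_low_bit='bf16'
```
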
@@ -119,7 +119,9 @@ class LLMEngine:
             f"load_format={model_config.load_format}, "
             # f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
             f"quantization={model_config.quantization}, "
-            f"seed={model_config.seed})"
+            f"seed={model_config.seed}), "
+            f"device={model_config.device}, "
+            f"load_in_low_bit={model_config.load_in_low_bit}"
         )
         # TODO(woosuk): Print more configs in debug mode.
 

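Assembled from the f-string fragments above, the extended engine log line comes out roughly as below (the values are illustrative only; the closing parenthesis after seed=... is reproduced exactly as the code writes it):

```python
# What the extended engine log line looks like with example values, built from
# the same f-string fragments shown in the hunk above. _Cfg is a stand-in for
# model_config in this sketch.
class _Cfg:
    load_format = 'auto'
    quantization = None
    seed = 0
    device = 'cpu'
    load_in_low_bit = 'bf16'

model_config = _Cfg()
msg = (
    f"load_format={model_config.load_format}, "
    f"quantization={model_config.quantization}, "
    f"seed={model_config.seed}), "
    f"device={model_config.device}, "
    f"load_in_low_bit={model_config.load_in_low_bit}"
)
print(msg)
# load_format=auto, quantization=None, seed=0), device=cpu, load_in_low_bit=bf16
```
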
@@ -93,6 +93,7 @@ class LLM:
             Otherwise, too small values may cause out-of-memory (OOM) errors.
         device: The device to be used for the model. If None, we will default
             to use CPU as the device.
+        load_in_low_bit: The low-bit quantization for model to be loaded. Default int4.
     """
 
     def __init__(
@@ -112,6 +113,7 @@ class LLM:
         # bigdl-llm change start
         # summary: add device option
         device: Optional[str] = "cpu",
+        load_in_low_bit: str = "sym_int4",
         # bigdl-llm change end
         **kwargs,
     ) -> None:
@@ -134,6 +136,7 @@ class LLM:
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
             device=device,
+            load_in_low_bit=load_in_low_bit,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(engine_args)

@@ -114,8 +114,10 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
             model = model_class(model_config.hf_config, quant_config)
         else:
+            # TODO: change for other models
             model = model_class(model_config.hf_config, device=model_config.device,
-                                max_model_len=model_config.max_model_len)
+                                max_model_len=model_config.max_model_len,
+                                load_in_low_bit=model_config.load_in_low_bit)
         # Load the weights from the cached or downloaded files.
         model.load_weights(model_config.model, model_config.download_dir,
                            model_config.load_format, model_config.revision)

@@ -58,23 +58,28 @@ class BigDLChatGLMForCausalLM(BigDLModelForCausalLM):
         config,
         device: Optional[str] = None,
         max_model_len: Optional[int] = None,
+        load_in_low_bit: str = 'sym_int4'
     ):
         super().__init__(config, device, max_model_len)
         self.config = config
         # TODO(gc): later change this to a switch?
-        if True:
-            from bigdl.llm.transformers import AutoModelForCausalLM
-            from bigdl.llm import optimize_model
+        from bigdl.llm.transformers import AutoModelForCausalLM
+        torch_dtype = 'auto'
 
-        # low_bit = 'sym_int4'
+        if load_in_low_bit == 'bf16':
+            torch_dtype = torch.bfloat16
+        elif load_in_low_bit == 'fp16':
+            torch_dtype = torch.float16
         if device == 'cpu':
-            model = AutoModelForCausalLM.from_pretrained(
+            self.model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 low_cpu_mem_usage=True,
                 trust_remote_code=True,
                 use_cache=True,
             )
-            self.model = optimize_model(model)
+            # self.model = optimize_model(model)
             self.sampler = BigDLSampler(config.vocab_size, device)
         elif device == 'xpu':
             try:
@@ -83,10 +88,10 @@ class BigDLChatGLMForCausalLM(BigDLModelForCausalLM):
                 print("Intel Extension for PyTorch is not installed, \
                        but is required for xpu inference.")
 
-            low_bit = 'sym_int4'
             model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
-                load_in_low_bit=low_bit,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 trust_remote_code=True,
                 optimize_model=True,
                 use_cache=True,

@@ -63,20 +63,31 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
         config: LlamaConfig,
         device: Optional[str] = None,
         max_model_len: Optional[int] = None,
+        load_in_low_bit: str = 'sym_int4'
     ):
         super().__init__(config, device, max_model_len)
         self.config = config
         # Always enable bigdl-llm model
         from bigdl.llm.transformers import AutoModelForCausalLM
-        from bigdl.llm import optimize_model
+        # TODO: we will need to pass the argument through command line argument
+        # from bigdl.llm import optimize_model
+        torch_dtype = 'auto'
+
+        if load_in_low_bit == 'bf16':
+            torch_dtype = torch.bfloat16
+        elif load_in_low_bit == 'fp16':
+            torch_dtype = torch.float16
+        # bf16 will require to set torch_dtype to bf16
         if device == 'cpu':
-            model = AutoModelForCausalLM.from_pretrained(
+            self.model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 low_cpu_mem_usage=True,
                 trust_remote_code=True,
                 use_cache=True,
             )
-            self.model = optimize_model(model)
+            # self.model = optimize_model(model)
             self.sampler = BigDLSampler(config.vocab_size, device)
         elif device == 'xpu':
             try:
@@ -85,10 +96,10 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
                 print("Intel Extension for PyTorch is not installed, \
                     but is required for xpu inference.")
 
-            low_bit = 'sym_int4'
             model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
-                load_in_low_bit=low_bit,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 trust_remote_code=True,
                 use_cache=True,
             )

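Each model class in this commit repeats the same load_in_low_bit to torch_dtype mapping. A standalone sketch of that selection logic; the helper name pick_torch_dtype is ours, for illustration only:

```python
# Standalone sketch of the dtype selection used by the model classes above:
# 'bf16' and 'fp16' map to real torch dtypes, anything else (e.g. 'sym_int4')
# keeps torch_dtype='auto' and relies on low-bit quantized loading instead.
# The helper name pick_torch_dtype is not part of the commit.
import torch


def pick_torch_dtype(load_in_low_bit: str):
    if load_in_low_bit == 'bf16':
        return torch.bfloat16
    elif load_in_low_bit == 'fp16':
        return torch.float16
    return 'auto'


assert pick_torch_dtype('bf16') is torch.bfloat16
assert pick_torch_dtype('fp16') is torch.float16
assert pick_torch_dtype('sym_int4') == 'auto'
```
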
@@ -58,23 +58,31 @@ class BigDLMistralForCausalLM(BigDLModelForCausalLM):
         config,
         device: Optional[str] = None,
         max_model_len: Optional[int] = None,
+        load_in_low_bit: str = 'sym_int4'
     ):
         super().__init__(config, device, max_model_len)
         self.config = config
         # TODO(gc): later change this to a switch?
-        if True:
-            from bigdl.llm.transformers import AutoModelForCausalLM
-            from bigdl.llm import optimize_model
+        from bigdl.llm.transformers import AutoModelForCausalLM
+        # from bigdl.llm import optimize_model
 
-        # low_bit = 'sym_int4'
+        torch_dtype = 'auto'
+
+        if load_in_low_bit == 'bf16':
+            torch_dtype = torch.bfloat16
+        elif load_in_low_bit == 'fp16':
+            torch_dtype = torch.float16
+
         if device == 'cpu':
-            model = AutoModelForCausalLM.from_pretrained(
+            self.model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 low_cpu_mem_usage=True,
                 trust_remote_code=True,
                 use_cache=True,
             )
-            self.model = optimize_model(model)
+            # self.model = optimize_model(model)
             self.sampler = BigDLSampler(config.vocab_size, device)
         elif device == 'xpu':
             try:
@@ -83,10 +91,10 @@ class BigDLMistralForCausalLM(BigDLModelForCausalLM):
                 print("Intel Extension for PyTorch is not installed, \
                        but is required for xpu inference.")
 
-            low_bit = 'sym_int4'
             model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
-                load_in_low_bit=low_bit,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 trust_remote_code=True,
                 optimize_model=True,
                 use_cache=True,

@@ -58,23 +58,29 @@ class BigDLMixtralForCausalLM(BigDLModelForCausalLM):
         config,
         device: Optional[str] = None,
         max_model_len: Optional[int] = None,
+        load_in_low_bit: str = 'sym_int4'
     ):
         super().__init__(config, device, max_model_len)
         self.config = config
         # TODO(gc): later change this to a switch?
-        if True:
-            from bigdl.llm.transformers import AutoModelForCausalLM
-            from bigdl.llm import optimize_model
+        from bigdl.llm.transformers import AutoModelForCausalLM
 
-        # low_bit = 'sym_int4'
+        torch_dtype = 'auto'
+
+        if load_in_low_bit == 'bf16':
+            torch_dtype = torch.bfloat16
+        elif load_in_low_bit == 'fp16':
+            torch_dtype = torch.float16
+
         if device == 'cpu':
-            model = AutoModelForCausalLM.from_pretrained(
+            self.model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 low_cpu_mem_usage=True,
                 trust_remote_code=True,
                 use_cache=True,
             )
-            self.model = optimize_model(model)
             self.sampler = BigDLSampler(config.vocab_size, device)
         elif device == 'xpu':
             try:
@@ -83,10 +89,10 @@ class BigDLMixtralForCausalLM(BigDLModelForCausalLM):
                 print("Intel Extension for PyTorch is not installed, \
                        but is required for xpu inference.")
 
-            low_bit = 'sym_int4'
             model = AutoModelForCausalLM.from_pretrained(
                 config._name_or_path,
-                load_in_low_bit=low_bit,
+                load_in_low_bit=load_in_low_bit,
+                torch_dtype=torch_dtype,
                 trust_remote_code=True,
                 optimize_model=True,
                 use_cache=True,