Add vLLM bf16 support (#10278)
* add argument load_in_low_bit * add docs * modify gpu doc * done --------- Co-authored-by: ivy-lv11 <lvzc@lamda.nju.edu.cn>
This commit is contained in:
parent
13b0bc9075
commit
2d930bdca8
13 changed files with 83 additions and 37 deletions
|
|
@ -21,8 +21,6 @@ pip3 install numpy
|
|||
pip3 install --pre --upgrade bigdl-llm[all]
|
||||
pip3 install psutil
|
||||
pip3 install sentencepiece # Required for LLaMA tokenizer.
|
||||
pip3 install "torch==2.0.1"
|
||||
pip3 install "transformers>=4.33.1" # Required for Code Llama.
|
||||
pip3 install fastapi
|
||||
pip3 install "uvicorn[standard]"
|
||||
pip3 install "pydantic<2" # Required for OpenAI server.
|
||||
|
|
@ -44,6 +42,7 @@ To run offline inference using vLLM for a quick impression, use the following ex
|
|||
#!/bin/bash
|
||||
|
||||
# Please first modify the MODEL_PATH in offline_inference.py
|
||||
# Modify load_in_low_bit to use different quantization dtype
|
||||
|
||||
numactl -C 48-95 -m 1 python offline_inference.py
|
||||
|
||||
|
|
@ -60,6 +59,7 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req
|
|||
numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.entrypoints.openai.api_server \
|
||||
--model /MODEL_PATH/Llama-2-7b-chat-hf-bigdl/ --port 8000 \
|
||||
--load-format 'auto' --device cpu --dtype bfloat16 \
|
||||
--load-in-low-bit sym_int4 \
|
||||
--max-num-batched-tokens 4096
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
|||
|
||||
# Create an LLM.
|
||||
# llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16")
|
||||
llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="sym_int4")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ conda activate bigdl-vllm
|
|||
pip3 install psutil
|
||||
pip3 install sentencepiece # Required for LLaMA tokenizer.
|
||||
pip3 install numpy
|
||||
pip3 install "transformers>=4.33.1" # Required for Code Llama.
|
||||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
|
||||
pip install --pre --upgrade "bigdl-llm[xpu]" -f https://developer.intel.com/ipex-whl-stable-xpu
|
||||
pip3 install fastapi
|
||||
|
|
@ -62,6 +61,7 @@ To run offline inference using vLLM for a quick impression, use the following ex
|
|||
#!/bin/bash
|
||||
|
||||
# Please first modify the MODEL_PATH in offline_inference.py
|
||||
# Modify load_in_low_bit to use different quantization dtype
|
||||
python offline_inference.py
|
||||
```
|
||||
|
||||
|
|
@ -76,6 +76,7 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req
|
|||
python -m bigdl.llm.vllm.entrypoints.openai.api_server \
|
||||
--model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \
|
||||
--load-format 'auto' --device xpu --dtype bfloat16 \
|
||||
--load-in-low-bit sym_int4 \
|
||||
--max-num-batched-tokens 4096
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
|||
|
||||
# Create an LLM.
|
||||
# llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16", device="xpu")
|
||||
llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="sym_int4", dtype="bfloat16", device="xpu")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ class ModelConfig:
|
|||
weights. If None, we assume the model weights are not quantized.
|
||||
device: The device to be used for the model. If None, we will default
|
||||
to use CPU as the device.
|
||||
load_in_low_bit: The low-bit quantization for model to be loaded. Default int4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -95,6 +96,7 @@ class ModelConfig:
|
|||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
device: Optional[str] = 'cpu',
|
||||
load_in_low_bit: str = 'sym_int4',
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
|
@ -107,6 +109,7 @@ class ModelConfig:
|
|||
self.tokenizer_revision = tokenizer_revision
|
||||
self.quantization = quantization
|
||||
self.device = device
|
||||
self.load_in_low_bit = load_in_low_bit
|
||||
|
||||
self.hf_config = get_config(model, trust_remote_code, revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ class EngineArgs:
|
|||
# bigdl-llm change start
|
||||
# summary: add device option
|
||||
device: Optional[str] = 'cpu'
|
||||
load_in_low_bit: str = 'sym_int4'
|
||||
# bigdl-llm change end
|
||||
|
||||
def __post_init__(self):
|
||||
|
|
@ -212,6 +213,10 @@ class EngineArgs:
|
|||
choices=['gpu', 'cpu', 'xpu', None],
|
||||
default=None,
|
||||
help='Device to execute LLM model')
|
||||
parser.add_argument('--load-in-low-bit',
|
||||
type=str,
|
||||
default='sym_int4',
|
||||
help='low_bit_quantization')
|
||||
|
||||
return parser
|
||||
|
||||
|
|
@ -229,7 +234,7 @@ class EngineArgs:
|
|||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed, self.revision,
|
||||
self.tokenizer_revision, self.max_model_len,
|
||||
self.quantization, self.device)
|
||||
self.quantization, self.device, self.load_in_low_bit)
|
||||
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
|
||||
self.max_num_seqs,
|
||||
model_config.max_model_len)
|
||||
|
|
|
|||
|
|
@ -119,7 +119,9 @@ class LLMEngine:
|
|||
f"load_format={model_config.load_format}, "
|
||||
# f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"seed={model_config.seed})"
|
||||
f"seed={model_config.seed}), "
|
||||
f"device={model_config.device}, "
|
||||
f"load_in_low_bit={model_config.load_in_low_bit}"
|
||||
)
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ class LLM:
|
|||
Otherwise, too small values may cause out-of-memory (OOM) errors.
|
||||
device: The device to be used for the model. If None, we will default
|
||||
to use CPU as the device.
|
||||
load_in_low_bit: The low-bit quantization for model to be loaded. Default int4.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -112,6 +113,7 @@ class LLM:
|
|||
# bigdl-llm change start
|
||||
# summary: add device option
|
||||
device: Optional[str] = "cpu",
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
# bigdl-llm change end
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
|
@ -134,6 +136,7 @@ class LLM:
|
|||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
swap_space=swap_space,
|
||||
device=device,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
|
|
|||
|
|
@ -114,8 +114,10 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
|||
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||
model = model_class(model_config.hf_config, quant_config)
|
||||
else:
|
||||
# TODO: change for other models
|
||||
model = model_class(model_config.hf_config, device=model_config.device,
|
||||
max_model_len=model_config.max_model_len)
|
||||
max_model_len=model_config.max_model_len,
|
||||
load_in_low_bit=model_config.load_in_low_bit)
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(model_config.model, model_config.download_dir,
|
||||
model_config.load_format, model_config.revision)
|
||||
|
|
|
|||
|
|
@ -58,23 +58,28 @@ class BigDLChatGLMForCausalLM(BigDLModelForCausalLM):
|
|||
config,
|
||||
device: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
load_in_low_bit: str = 'sym_int4'
|
||||
):
|
||||
super().__init__(config, device, max_model_len)
|
||||
self.config = config
|
||||
# TODO(gc): later change this to a switch?
|
||||
if True:
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from bigdl.llm import optimize_model
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
torch_dtype = 'auto'
|
||||
|
||||
# low_bit = 'sym_int4'
|
||||
if load_in_low_bit == 'bf16':
|
||||
torch_dtype = torch.bfloat16
|
||||
elif load_in_low_bit == 'fp16':
|
||||
torch_dtype = torch.float16
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
use_cache=True,
|
||||
)
|
||||
self.model = optimize_model(model)
|
||||
# self.model = optimize_model(model)
|
||||
self.sampler = BigDLSampler(config.vocab_size, device)
|
||||
elif device == 'xpu':
|
||||
try:
|
||||
|
|
@ -83,10 +88,10 @@ class BigDLChatGLMForCausalLM(BigDLModelForCausalLM):
|
|||
print("Intel Extension for PyTorch is not installed, \
|
||||
but is required for xpu inference.")
|
||||
|
||||
low_bit = 'sym_int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=low_bit,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
optimize_model=True,
|
||||
use_cache=True,
|
||||
|
|
|
|||
|
|
@ -63,20 +63,31 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
config: LlamaConfig,
|
||||
device: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
load_in_low_bit: str = 'sym_int4'
|
||||
):
|
||||
super().__init__(config, device, max_model_len)
|
||||
self.config = config
|
||||
# Always enable bigdl-llm model
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from bigdl.llm import optimize_model
|
||||
# TODO: we will need to pass the argument through command line argument
|
||||
# from bigdl.llm import optimize_model
|
||||
torch_dtype = 'auto'
|
||||
|
||||
if load_in_low_bit == 'bf16':
|
||||
torch_dtype = torch.bfloat16
|
||||
elif load_in_low_bit == 'fp16':
|
||||
torch_dtype = torch.float16
|
||||
# bf16 will require to set torch_dtype to bf16
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
use_cache=True,
|
||||
)
|
||||
self.model = optimize_model(model)
|
||||
# self.model = optimize_model(model)
|
||||
self.sampler = BigDLSampler(config.vocab_size, device)
|
||||
elif device == 'xpu':
|
||||
try:
|
||||
|
|
@ -85,10 +96,10 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
|
|||
print("Intel Extension for PyTorch is not installed, \
|
||||
but is required for xpu inference.")
|
||||
|
||||
low_bit = 'sym_int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=low_bit,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
use_cache=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -58,23 +58,31 @@ class BigDLMistralForCausalLM(BigDLModelForCausalLM):
|
|||
config,
|
||||
device: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
load_in_low_bit: str = 'sym_int4'
|
||||
):
|
||||
super().__init__(config, device, max_model_len)
|
||||
self.config = config
|
||||
# TODO(gc): later change this to a switch?
|
||||
if True:
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from bigdl.llm import optimize_model
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
# from bigdl.llm import optimize_model
|
||||
|
||||
torch_dtype = 'auto'
|
||||
|
||||
if load_in_low_bit == 'bf16':
|
||||
torch_dtype = torch.bfloat16
|
||||
elif load_in_low_bit == 'fp16':
|
||||
torch_dtype = torch.float16
|
||||
|
||||
# low_bit = 'sym_int4'
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
use_cache=True,
|
||||
)
|
||||
self.model = optimize_model(model)
|
||||
# self.model = optimize_model(model)
|
||||
self.sampler = BigDLSampler(config.vocab_size, device)
|
||||
elif device == 'xpu':
|
||||
try:
|
||||
|
|
@ -83,10 +91,10 @@ class BigDLMistralForCausalLM(BigDLModelForCausalLM):
|
|||
print("Intel Extension for PyTorch is not installed, \
|
||||
but is required for xpu inference.")
|
||||
|
||||
low_bit = 'sym_int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=low_bit,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
optimize_model=True,
|
||||
use_cache=True,
|
||||
|
|
|
|||
|
|
@ -58,23 +58,29 @@ class BigDLMixtralForCausalLM(BigDLModelForCausalLM):
|
|||
config,
|
||||
device: Optional[str] = None,
|
||||
max_model_len: Optional[int] = None,
|
||||
load_in_low_bit: str = 'sym_int4'
|
||||
):
|
||||
super().__init__(config, device, max_model_len)
|
||||
self.config = config
|
||||
# TODO(gc): later change this to a switch?
|
||||
if True:
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
from bigdl.llm import optimize_model
|
||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||
|
||||
torch_dtype = 'auto'
|
||||
|
||||
if load_in_low_bit == 'bf16':
|
||||
torch_dtype = torch.bfloat16
|
||||
elif load_in_low_bit == 'fp16':
|
||||
torch_dtype = torch.float16
|
||||
|
||||
# low_bit = 'sym_int4'
|
||||
if device == 'cpu':
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
use_cache=True,
|
||||
)
|
||||
self.model = optimize_model(model)
|
||||
self.sampler = BigDLSampler(config.vocab_size, device)
|
||||
elif device == 'xpu':
|
||||
try:
|
||||
|
|
@ -83,10 +89,10 @@ class BigDLMixtralForCausalLM(BigDLModelForCausalLM):
|
|||
print("Intel Extension for PyTorch is not installed, \
|
||||
but is required for xpu inference.")
|
||||
|
||||
low_bit = 'sym_int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
config._name_or_path,
|
||||
load_in_low_bit=low_bit,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
optimize_model=True,
|
||||
use_cache=True,
|
||||
|
|
|
|||
Loading…
Reference in a new issue