Add vLLM bf16 support (#10278)

* add argument load_in_low_bit

* add docs

* modify gpu doc

* done

---------

Co-authored-by: ivy-lv11 <lvzc@lamda.nju.edu.cn>
Authored by Guancheng Fu on 2024-02-29 16:33:42 +08:00, committed by GitHub
parent 13b0bc9075
commit 2d930bdca8
13 changed files with 83 additions and 37 deletions

View file

@@ -21,8 +21,6 @@ pip3 install numpy
pip3 install --pre --upgrade bigdl-llm[all]
pip3 install psutil
pip3 install sentencepiece # Required for LLaMA tokenizer.
pip3 install "torch==2.0.1"
pip3 install "transformers>=4.33.1" # Required for Code Llama.
pip3 install fastapi
pip3 install "uvicorn[standard]"
pip3 install "pydantic<2" # Required for OpenAI server.
@@ -44,6 +42,7 @@ To run offline inference using vLLM for a quick impression, use the following example
#!/bin/bash
# Please first modify the MODEL_PATH in offline_inference.py
# Modify load_in_low_bit to use different quantization dtype
numactl -C 48-95 -m 1 python offline_inference.py
@@ -60,6 +59,7 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send requests
numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf-bigdl/ --port 8000 \
--load-format 'auto' --device cpu --dtype bfloat16 \
--load-in-low-bit sym_int4 \
--max-num-batched-tokens 4096
```
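For a quick sanity check of the server started above, a minimal client sketch follows. It assumes the fork keeps vLLM's OpenAI-compatible `/v1/completions` route and payload; neither is confirmed by this change.

```python
# Minimal client sketch; the /v1/completions route and payload fields are
# assumptions carried over from upstream vLLM's OpenAI-compatible server.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/",  # must match --model
        "prompt": "San Francisco is a",
        "max_tokens": 64,
        "temperature": 0.8,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```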

View file

@@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16")
llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="sym_int4")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
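A short sketch of consuming the results follows; the `RequestOutput` fields (`prompt`, `outputs[0].text`) are assumed to match upstream vLLM.

```python
# Print each prompt with its generated continuation.
# Assumes RequestOutput exposes .prompt and .outputs[0].text as in upstream vLLM.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```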

View file

@@ -37,7 +37,6 @@ conda activate bigdl-vllm
pip3 install psutil
pip3 install sentencepiece # Required for LLaMA tokenizer.
pip3 install numpy
pip3 install "transformers>=4.33.1" # Required for Code Llama.
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade "bigdl-llm[xpu]" -f https://developer.intel.com/ipex-whl-stable-xpu
pip3 install fastapi
@@ -62,6 +61,7 @@ To run offline inference using vLLM for a quick impression, use the following example
#!/bin/bash
# Please first modify the MODEL_PATH in offline_inference.py
# Modify load_in_low_bit to use different quantization dtype
python offline_inference.py
```
@@ -76,6 +76,7 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send requests
python -m bigdl.llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \
--load-format 'auto' --device xpu --dtype bfloat16 \
--load-in-low-bit sym_int4 \
--max-num-batched-tokens 4096
```
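As on CPU, the GPU server should accept OpenAI-style requests. Below is a sketch using the legacy (pre-1.0) `openai` client; the `/v1` route layout and the dummy API key convention are assumptions carried over from upstream vLLM.

```python
# Sketch using the legacy openai (<1.0) client; route layout and the dummy
# API key convention are assumptions, not confirmed by this change.
import openai

openai.api_base = "http://localhost:8000/v1"
openai.api_key = "EMPTY"  # the server is assumed not to validate keys

completion = openai.Completion.create(
    model="/MODEL_PATH/Llama-2-7b-chat-hf/",  # must match --model passed to the server
    prompt="San Francisco is a",
    max_tokens=64,
    temperature=0.8,
)
print(completion["choices"][0]["text"])
```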

View file

@@ -46,7 +46,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16", device="xpu")
llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="sym_int4", dtype="bfloat16", device="xpu")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
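Since this commit is what wires bf16 (and fp16) through `load_in_low_bit`, a bf16 variant of the offline example is just a different argument value. A sketch, with import paths assumed from the `bigdl.llm.vllm` package layout:

```python
# Sketch: load the model in bf16 instead of sym_int4 on an Intel GPU.
# The import paths below are assumptions based on the bigdl.llm.vllm layout.
from bigdl.llm.vllm.entrypoints.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams

prompts = ["San Francisco is a"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="YOUR_MODEL_PATH", load_in_low_bit="bf16",
          dtype="bfloat16", device="xpu")
outputs = llm.generate(prompts, sampling_params)
```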

View file

@@ -78,6 +78,7 @@ class ModelConfig:
weights. If None, we assume the model weights are not quantized.
device: The device to be used for the model. If None, we will default
to use CPU as the device.
load_in_low_bit: The low-bit quantization dtype the model weights are loaded in. Defaults to 'sym_int4'.
"""
def __init__(
@@ -95,6 +96,7 @@ class ModelConfig:
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
device: Optional[str] = 'cpu',
load_in_low_bit: str = 'sym_int4',
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -107,6 +109,7 @@ class ModelConfig:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.device = device
self.load_in_low_bit = load_in_low_bit
self.hf_config = get_config(model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)

View file

@@ -71,6 +71,7 @@ class EngineArgs:
# bigdl-llm change start
# summary: add device option
device: Optional[str] = 'cpu'
load_in_low_bit: str = 'sym_int4'
# bigdl-llm change end
def __post_init__(self):
@@ -212,6 +213,10 @@ class EngineArgs:
choices=['gpu', 'cpu', 'xpu', None],
default=None,
help='Device to execute LLM model')
parser.add_argument('--load-in-low-bit',
type=str,
default='sym_int4',
help='Low-bit quantization dtype to load the model in (e.g. sym_int4, bf16, fp16)')
return parser
@@ -229,7 +234,7 @@ class EngineArgs:
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization, self.device)
self.quantization, self.device, self.load_in_low_bit)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len)
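For orientation, a hedged sketch of how the new flag travels from the CLI into `ModelConfig`; apart from the names visible in the hunk above, the helper names, module path, and return layout are assumptions based on upstream vLLM.

```python
# Sketch of the argument flow; helper names and module path are assumptions.
import argparse

from bigdl.llm.vllm.engine.arg_utils import EngineArgs  # assumed module path

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)        # assumed upstream-style helper
args = parser.parse_args(["--model", "/MODEL_PATH/Llama-2-7b-chat-hf/",
                          "--device", "cpu",
                          "--load-in-low-bit", "bf16"])
engine_args = EngineArgs.from_cli_args(args)    # assumed upstream-style helper
model_config, *_ = engine_args.create_engine_configs()  # ModelConfig assumed first
print(model_config.load_in_low_bit)             # -> 'bf16'
```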

View file

@@ -119,7 +119,9 @@ class LLMEngine:
f"load_format={model_config.load_format}, "
# f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"seed={model_config.seed})"
f"seed={model_config.seed}), "
f"device={model_config.device}, "
f"load_in_low_bit={model_config.load_in_low_bit}"
)
# TODO(woosuk): Print more configs in debug mode.

View file

@@ -93,6 +93,7 @@ class LLM:
Otherwise, too small values may cause out-of-memory (OOM) errors.
device: The device to be used for the model. If None, we will default
to use CPU as the device.
load_in_low_bit: The low-bit quantization dtype the model weights are loaded in. Defaults to 'sym_int4'.
"""
def __init__(
@@ -112,6 +113,7 @@ class LLM:
# bigdl-llm change start
# summary: add device option
device: Optional[str] = "cpu",
load_in_low_bit: str = "sym_int4",
# bigdl-llm change end
**kwargs,
) -> None:
@@ -134,6 +136,7 @@ class LLM:
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
device=device,
load_in_low_bit=load_in_low_bit,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)

View file

@@ -114,8 +114,10 @@ def get_model(model_config: ModelConfig) -> nn.Module:
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
model = model_class(model_config.hf_config, quant_config)
else:
# TODO: change for other models
model = model_class(model_config.hf_config, device=model_config.device,
max_model_len=model_config.max_model_len)
max_model_len=model_config.max_model_len,
load_in_low_bit=model_config.load_in_low_bit)
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.load_format, model_config.revision)
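The four model wrappers below (ChatGLM, Llama, Mistral, Mixtral) inline the same `load_in_low_bit`-to-`torch_dtype` mapping before calling `AutoModelForCausalLM.from_pretrained`. A hypothetical helper illustrating that shared pattern (it is not part of this commit):

```python
# Hypothetical helper; mirrors the mapping each BigDL*ForCausalLM wrapper inlines.
from typing import Union

import torch


def torch_dtype_for(load_in_low_bit: str) -> Union[str, torch.dtype]:
    # Map the load_in_low_bit string to the torch_dtype passed to from_pretrained;
    # anything other than bf16/fp16 (e.g. sym_int4) falls back to 'auto'.
    if load_in_low_bit == 'bf16':
        return torch.bfloat16
    if load_in_low_bit == 'fp16':
        return torch.float16
    return 'auto'
```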

View file

@@ -58,23 +58,28 @@ class BigDLChatGLMForCausalLM(BigDLModelForCausalLM):
config,
device: Optional[str] = None,
max_model_len: Optional[int] = None,
load_in_low_bit: str = 'sym_int4'
):
super().__init__(config, device, max_model_len)
self.config = config
# TODO(gc): later change this to a switch?
if True:
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
from bigdl.llm.transformers import AutoModelForCausalLM
torch_dtype = 'auto'
# low_bit = 'sym_int4'
if load_in_low_bit == 'bf16':
torch_dtype = torch.bfloat16
elif load_in_low_bit == 'fp16':
torch_dtype = torch.float16
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(
self.model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
use_cache=True,
)
self.model = optimize_model(model)
# self.model = optimize_model(model)
self.sampler = BigDLSampler(config.vocab_size, device)
elif device == 'xpu':
try:
@@ -83,10 +88,10 @@ class BigDLChatGLMForCausalLM(BigDLModelForCausalLM):
print("Intel Extension for PyTorch is not installed, \
but is required for xpu inference.")
low_bit = 'sym_int4'
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=low_bit,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
trust_remote_code=True,
optimize_model=True,
use_cache=True,

View file

@@ -63,20 +63,31 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
config: LlamaConfig,
device: Optional[str] = None,
max_model_len: Optional[int] = None,
load_in_low_bit: str = 'sym_int4'
):
super().__init__(config, device, max_model_len)
self.config = config
# Always enable bigdl-llm model
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
# TODO: we will need to pass this argument through a command-line argument
# from bigdl.llm import optimize_model
torch_dtype = 'auto'
if load_in_low_bit == 'bf16':
torch_dtype = torch.bfloat16
elif load_in_low_bit == 'fp16':
torch_dtype = torch.float16
# bf16 requires torch_dtype to be set to bf16
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(
self.model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
use_cache=True,
)
self.model = optimize_model(model)
# self.model = optimize_model(model)
self.sampler = BigDLSampler(config.vocab_size, device)
elif device == 'xpu':
try:
@@ -85,10 +96,10 @@ class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
print("Intel Extension for PyTorch is not installed, \
but is required for xpu inference.")
low_bit = 'sym_int4'
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=low_bit,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
trust_remote_code=True,
use_cache=True,
)

View file

@@ -58,23 +58,31 @@ class BigDLMistralForCausalLM(BigDLModelForCausalLM):
config,
device: Optional[str] = None,
max_model_len: Optional[int] = None,
load_in_low_bit: str = 'sym_int4'
):
super().__init__(config, device, max_model_len)
self.config = config
# TODO(gc): later change this to a switch?
if True:
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
from bigdl.llm.transformers import AutoModelForCausalLM
# from bigdl.llm import optimize_model
torch_dtype = 'auto'
if load_in_low_bit == 'bf16':
torch_dtype = torch.bfloat16
elif load_in_low_bit == 'fp16':
torch_dtype = torch.float16
# low_bit = 'sym_int4'
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(
self.model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
use_cache=True,
)
self.model = optimize_model(model)
# self.model = optimize_model(model)
self.sampler = BigDLSampler(config.vocab_size, device)
elif device == 'xpu':
try:
@@ -83,10 +91,10 @@ class BigDLMistralForCausalLM(BigDLModelForCausalLM):
print("Intel Extension for PyTorch is not installed, \
but is required for xpu inference.")
low_bit = 'sym_int4'
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=low_bit,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
trust_remote_code=True,
optimize_model=True,
use_cache=True,

View file

@@ -58,23 +58,29 @@ class BigDLMixtralForCausalLM(BigDLModelForCausalLM):
config,
device: Optional[str] = None,
max_model_len: Optional[int] = None,
load_in_low_bit: str = 'sym_int4'
):
super().__init__(config, device, max_model_len)
self.config = config
# TODO(gc): later change this to a switch?
if True:
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
from bigdl.llm.transformers import AutoModelForCausalLM
torch_dtype = 'auto'
if load_in_low_bit == 'bf16':
torch_dtype = torch.bfloat16
elif load_in_low_bit == 'fp16':
torch_dtype = torch.float16
# low_bit = 'sym_int4'
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(
self.model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
trust_remote_code=True,
use_cache=True,
)
self.model = optimize_model(model)
self.sampler = BigDLSampler(config.vocab_size, device)
elif device == 'xpu':
try:
@@ -83,10 +89,10 @@ class BigDLMixtralForCausalLM(BigDLModelForCausalLM):
print("Intel Extension for PyTorch is not installed, \
but is required for xpu inference.")
low_bit = 'sym_int4'
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=low_bit,
load_in_low_bit=load_in_low_bit,
torch_dtype=torch_dtype,
trust_remote_code=True,
optimize_model=True,
use_cache=True,