LLM: accelerate sample of gptneox and update quantize (#8262)

* update quantize & accelerate sample

* fix style check

* fix style error
This commit is contained in:
Ruonan Wang 2023-06-05 15:36:00 +08:00 committed by GitHub
parent 2bc0e7abbb
commit 8bd2992a8d
6 changed files with 70 additions and 29 deletions

View file

@ -1,3 +1,19 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.llm.ggml.convert import _convert_to_ggml from bigdl.llm.ggml.convert import _convert_to_ggml
from bigdl.llm.ggml.quantize import quantize from bigdl.llm.ggml.quantize import quantize
from pathlib import Path from pathlib import Path

View file

@ -46,6 +46,7 @@
# only search the first bigdl package and end up finding only one sub-package. # only search the first bigdl package and end up finding only one sub-package.
from .bloom_cpp import bloom_load, bloom_free, bloom_run from .bloom_cpp import bloom_load, bloom_free, bloom_run
from bigdl.llm.utils.common import invalidInputError
class Bloom: class Bloom:
@ -81,8 +82,7 @@ class Bloom:
A Bloom instance. A Bloom instance.
""" """
self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads) self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
if not self.ctx: invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
raise RuntimeError(f"Failed to load model from {model_path}")
self.n_ctx = n_ctx self.n_ctx = n_ctx
self.seed = seed self.seed = seed
self.logits_all = logits_all self.logits_all = logits_all

View file

@ -62,6 +62,8 @@ from ctypes import (
) )
import pathlib import pathlib
from bigdl.llm.utils import get_avx_flags from bigdl.llm.utils import get_avx_flags
from bigdl.llm.utils.common import invalidInputError
# Load the library # Load the library
def _load_shared_library(lib_base_name: str): def _load_shared_library(lib_base_name: str):
@ -71,7 +73,7 @@ def _load_shared_library(lib_base_name: str):
elif sys.platform == "win32": elif sys.platform == "win32":
lib_ext = ".dll" lib_ext = ".dll"
else: else:
raise RuntimeError("Unsupported platform") invalidInputError(False, "Unsupported platform")
avx = get_avx_flags() avx = get_avx_flags()
@ -101,9 +103,10 @@ def _load_shared_library(lib_base_name: str):
try: try:
return ctypes.CDLL(str(_lib_path)) return ctypes.CDLL(str(_lib_path))
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") invalidInputError(False,
f"Failed to load shared library '{_lib_path}': {e}")
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found") invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found")
# Specify the base name of the shared library to load # Specify the base name of the shared library to load

View file

@ -51,6 +51,7 @@ import uuid
import time import time
import math import math
import multiprocessing import multiprocessing
import ctypes
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
from collections import deque, OrderedDict from collections import deque, OrderedDict
from bigdl.llm.utils.common import invalidInputError from bigdl.llm.utils.common import invalidInputError
@ -342,22 +343,29 @@ class Gptneox:
"The attribute `eval_logits` of `Gptneox` object is None.") "The attribute `eval_logits` of `Gptneox` object is None.")
n_vocab = int(gptneox_cpp.gptneox_n_vocab(self.ctx)) n_vocab = int(gptneox_cpp.gptneox_n_vocab(self.ctx))
logits = self.eval_logits[-1] logits = self.eval_logits[-1]
data = (gptneox_cpp.gptneox_token_data * n_vocab)( # accelerate below code by moving to cpp
*[ # data = (gptneox_cpp.gptneox_token_data * n_vocab)(
gptneox_cpp.gptneox_token_data( # *[
id=gptneox_cpp.gptneox_token(i), # gptneox_cpp.gptneox_token_data(
logit=logits[i], # id=gptneox_cpp.gptneox_token(i),
p=gptneox_cpp.c_float(0.0), # logit=logits[i],
) # p=gptneox_cpp.c_float(0.0),
for i in range(n_vocab) # )
] # for i in range(n_vocab)
) # ]
size = gptneox_cpp.c_size_t(n_vocab) # )
sorted = False # size = gptneox_cpp.c_size_t(n_vocab)
candidates = gptneox_cpp.gptneox_token_data_array( # sorted = False
data=data, # candidates = gptneox_cpp.gptneox_token_data_array(
size=size, # data=data,
sorted=sorted, # size=size,
# sorted=sorted,
# )
logits = (ctypes.c_float * n_vocab)(*logits)
candidates = gptneox_cpp.gptneox_get_candidates(
ctx=self.ctx,
n_vocab=n_vocab,
logits=logits
) )
gptneox_cpp.gptneox_sample_repetition_penalty( gptneox_cpp.gptneox_sample_repetition_penalty(
ctx=self.ctx, ctx=self.ctx,

View file

@ -529,6 +529,24 @@ _lib.gptneox_token_eos.restype = gptneox_token
# Sampling functions # Sampling functions
def gptneox_get_candidates(
ctx: gptneox_context_p,
n_vocab: c_int,
logits: c_float_p,
):
return _lib.gptneox_get_candidates(
ctx, n_vocab, logits
)
_lib.gptneox_get_candidates.argtypes = [
gptneox_context_p,
c_int,
c_float_p
]
_lib.gptneox_get_candidates.restype = gptneox_token_data_array
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, # @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
# with negative logit fix. # with negative logit fix.
def gptneox_sample_repetition_penalty( def gptneox_sample_repetition_penalty(

View file

@ -33,7 +33,6 @@ _bloom_quantize_type = {"q4_0": 2,
"q4_1": 3} "q4_1": 3}
_gptneox_quantize_type = {"q4_0": 2, _gptneox_quantize_type = {"q4_0": 2,
"q4_1": 3, "q4_1": 3,
"q4_2": 5,
"q5_0": 8, "q5_0": 8,
"q5_1": 9, "q5_1": 9,
"q8_0": 7} "q8_0": 7}
@ -42,9 +41,6 @@ _quantize_type = {"llama": _llama_quantize_type,
"bloom": _bloom_quantize_type, "bloom": _bloom_quantize_type,
"gptneox": _gptneox_quantize_type} "gptneox": _gptneox_quantize_type}
_valid_types = set(list(_llama_quantize_type.keys()) + list(_bloomz_quantize_type.keys()) +
list(_gptneox_quantize_type.keys()))
def quantize(input_path: str, output_path: str=None, def quantize(input_path: str, output_path: str=None,
model_family: str = 'llama', dtype: str='q4_0'): model_family: str = 'llama', dtype: str='q4_0'):