LLM: accelerate sample of gptneox and update quantize (#8262)
* update quantize & accelerate sample * fix style check * fix style error
This commit is contained in:
parent
2bc0e7abbb
commit
8bd2992a8d
6 changed files with 70 additions and 29 deletions
|
|
@ -1,3 +1,19 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
from bigdl.llm.ggml.convert import _convert_to_ggml
|
from bigdl.llm.ggml.convert import _convert_to_ggml
|
||||||
from bigdl.llm.ggml.quantize import quantize
|
from bigdl.llm.ggml.quantize import quantize
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@
|
||||||
# only search the first bigdl package and end up finding only one sub-package.
|
# only search the first bigdl package and end up finding only one sub-package.
|
||||||
|
|
||||||
from .bloom_cpp import bloom_load, bloom_free, bloom_run
|
from .bloom_cpp import bloom_load, bloom_free, bloom_run
|
||||||
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
|
||||||
|
|
||||||
class Bloom:
|
class Bloom:
|
||||||
|
|
@ -81,8 +82,7 @@ class Bloom:
|
||||||
A Bloom instance.
|
A Bloom instance.
|
||||||
"""
|
"""
|
||||||
self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
|
self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
|
||||||
if not self.ctx:
|
invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
|
||||||
raise RuntimeError(f"Failed to load model from {model_path}")
|
|
||||||
self.n_ctx = n_ctx
|
self.n_ctx = n_ctx
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.logits_all = logits_all
|
self.logits_all = logits_all
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,8 @@ from ctypes import (
|
||||||
)
|
)
|
||||||
import pathlib
|
import pathlib
|
||||||
from bigdl.llm.utils import get_avx_flags
|
from bigdl.llm.utils import get_avx_flags
|
||||||
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
|
||||||
|
|
||||||
# Load the library
|
# Load the library
|
||||||
def _load_shared_library(lib_base_name: str):
|
def _load_shared_library(lib_base_name: str):
|
||||||
|
|
@ -71,7 +73,7 @@ def _load_shared_library(lib_base_name: str):
|
||||||
elif sys.platform == "win32":
|
elif sys.platform == "win32":
|
||||||
lib_ext = ".dll"
|
lib_ext = ".dll"
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unsupported platform")
|
invalidInputError(False, "Unsupported platform")
|
||||||
|
|
||||||
avx = get_avx_flags()
|
avx = get_avx_flags()
|
||||||
|
|
||||||
|
|
@ -101,9 +103,10 @@ def _load_shared_library(lib_base_name: str):
|
||||||
try:
|
try:
|
||||||
return ctypes.CDLL(str(_lib_path))
|
return ctypes.CDLL(str(_lib_path))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
invalidInputError(False,
|
||||||
|
f"Failed to load shared library '{_lib_path}': {e}")
|
||||||
|
|
||||||
raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
|
invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found")
|
||||||
|
|
||||||
|
|
||||||
# Specify the base name of the shared library to load
|
# Specify the base name of the shared library to load
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ import uuid
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import ctypes
|
||||||
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
|
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
|
||||||
from collections import deque, OrderedDict
|
from collections import deque, OrderedDict
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
|
@ -342,22 +343,29 @@ class Gptneox:
|
||||||
"The attribute `eval_logits` of `Gptneox` object is None.")
|
"The attribute `eval_logits` of `Gptneox` object is None.")
|
||||||
n_vocab = int(gptneox_cpp.gptneox_n_vocab(self.ctx))
|
n_vocab = int(gptneox_cpp.gptneox_n_vocab(self.ctx))
|
||||||
logits = self.eval_logits[-1]
|
logits = self.eval_logits[-1]
|
||||||
data = (gptneox_cpp.gptneox_token_data * n_vocab)(
|
# accelerate below code by moving to cpp
|
||||||
*[
|
# data = (gptneox_cpp.gptneox_token_data * n_vocab)(
|
||||||
gptneox_cpp.gptneox_token_data(
|
# *[
|
||||||
id=gptneox_cpp.gptneox_token(i),
|
# gptneox_cpp.gptneox_token_data(
|
||||||
logit=logits[i],
|
# id=gptneox_cpp.gptneox_token(i),
|
||||||
p=gptneox_cpp.c_float(0.0),
|
# logit=logits[i],
|
||||||
)
|
# p=gptneox_cpp.c_float(0.0),
|
||||||
for i in range(n_vocab)
|
# )
|
||||||
]
|
# for i in range(n_vocab)
|
||||||
)
|
# ]
|
||||||
size = gptneox_cpp.c_size_t(n_vocab)
|
# )
|
||||||
sorted = False
|
# size = gptneox_cpp.c_size_t(n_vocab)
|
||||||
candidates = gptneox_cpp.gptneox_token_data_array(
|
# sorted = False
|
||||||
data=data,
|
# candidates = gptneox_cpp.gptneox_token_data_array(
|
||||||
size=size,
|
# data=data,
|
||||||
sorted=sorted,
|
# size=size,
|
||||||
|
# sorted=sorted,
|
||||||
|
# )
|
||||||
|
logits = (ctypes.c_float * n_vocab)(*logits)
|
||||||
|
candidates = gptneox_cpp.gptneox_get_candidates(
|
||||||
|
ctx=self.ctx,
|
||||||
|
n_vocab=n_vocab,
|
||||||
|
logits=logits
|
||||||
)
|
)
|
||||||
gptneox_cpp.gptneox_sample_repetition_penalty(
|
gptneox_cpp.gptneox_sample_repetition_penalty(
|
||||||
ctx=self.ctx,
|
ctx=self.ctx,
|
||||||
|
|
|
||||||
|
|
@ -529,6 +529,24 @@ _lib.gptneox_token_eos.restype = gptneox_token
|
||||||
# Sampling functions
|
# Sampling functions
|
||||||
|
|
||||||
|
|
||||||
|
def gptneox_get_candidates(
|
||||||
|
ctx: gptneox_context_p,
|
||||||
|
n_vocab: c_int,
|
||||||
|
logits: c_float_p,
|
||||||
|
):
|
||||||
|
return _lib.gptneox_get_candidates(
|
||||||
|
ctx, n_vocab, logits
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_lib.gptneox_get_candidates.argtypes = [
|
||||||
|
gptneox_context_p,
|
||||||
|
c_int,
|
||||||
|
c_float_p
|
||||||
|
]
|
||||||
|
_lib.gptneox_get_candidates.restype = gptneox_token_data_array
|
||||||
|
|
||||||
|
|
||||||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
|
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
|
||||||
# with negative logit fix.
|
# with negative logit fix.
|
||||||
def gptneox_sample_repetition_penalty(
|
def gptneox_sample_repetition_penalty(
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,6 @@ _bloom_quantize_type = {"q4_0": 2,
|
||||||
"q4_1": 3}
|
"q4_1": 3}
|
||||||
_gptneox_quantize_type = {"q4_0": 2,
|
_gptneox_quantize_type = {"q4_0": 2,
|
||||||
"q4_1": 3,
|
"q4_1": 3,
|
||||||
"q4_2": 5,
|
|
||||||
"q5_0": 8,
|
"q5_0": 8,
|
||||||
"q5_1": 9,
|
"q5_1": 9,
|
||||||
"q8_0": 7}
|
"q8_0": 7}
|
||||||
|
|
@ -42,9 +41,6 @@ _quantize_type = {"llama": _llama_quantize_type,
|
||||||
"bloom": _bloom_quantize_type,
|
"bloom": _bloom_quantize_type,
|
||||||
"gptneox": _gptneox_quantize_type}
|
"gptneox": _gptneox_quantize_type}
|
||||||
|
|
||||||
_valid_types = set(list(_llama_quantize_type.keys()) + list(_bloomz_quantize_type.keys()) +
|
|
||||||
list(_gptneox_quantize_type.keys()))
|
|
||||||
|
|
||||||
|
|
||||||
def quantize(input_path: str, output_path: str=None,
|
def quantize(input_path: str, output_path: str=None,
|
||||||
model_family: str = 'llama', dtype: str='q4_0'):
|
model_family: str = 'llama', dtype: str='q4_0'):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue