LLM: accelerate sample of gptneox and update quantize (#8262)
* update quantize & accelerate sample * fix style check * fix style error
This commit is contained in:
		
							parent
							
								
									2bc0e7abbb
								
							
						
					
					
						commit
						8bd2992a8d
					
				
					 6 changed files with 70 additions and 29 deletions
				
			
		| 
						 | 
				
			
			@ -1,3 +1,19 @@
 | 
			
		|||
#
 | 
			
		||||
# Copyright 2016 The BigDL Authors.
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
#
 | 
			
		||||
 | 
			
		||||
from bigdl.llm.ggml.convert import _convert_to_ggml
 | 
			
		||||
from bigdl.llm.ggml.quantize import quantize
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -46,6 +46,7 @@
 | 
			
		|||
# only search the first bigdl package and end up finding only one sub-package.
 | 
			
		||||
 | 
			
		||||
from .bloom_cpp import bloom_load, bloom_free, bloom_run
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Bloom:
 | 
			
		||||
| 
						 | 
				
			
			@ -81,8 +82,7 @@ class Bloom:
 | 
			
		|||
            A Bloom instance.
 | 
			
		||||
        """
 | 
			
		||||
        self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
 | 
			
		||||
        if not self.ctx:
 | 
			
		||||
            raise RuntimeError(f"Failed to load model from {model_path}")
 | 
			
		||||
        invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
 | 
			
		||||
        self.n_ctx = n_ctx
 | 
			
		||||
        self.seed = seed
 | 
			
		||||
        self.logits_all = logits_all
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -62,6 +62,8 @@ from ctypes import (
 | 
			
		|||
)
 | 
			
		||||
import pathlib
 | 
			
		||||
from bigdl.llm.utils import get_avx_flags
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Load the library
 | 
			
		||||
def _load_shared_library(lib_base_name: str):
 | 
			
		||||
| 
						 | 
				
			
			@ -71,7 +73,7 @@ def _load_shared_library(lib_base_name: str):
 | 
			
		|||
    elif sys.platform == "win32":
 | 
			
		||||
        lib_ext = ".dll"
 | 
			
		||||
    else:
 | 
			
		||||
        raise RuntimeError("Unsupported platform")
 | 
			
		||||
        invalidInputError(False, "Unsupported platform")
 | 
			
		||||
 | 
			
		||||
    avx = get_avx_flags()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -101,9 +103,10 @@ def _load_shared_library(lib_base_name: str):
 | 
			
		|||
            try:
 | 
			
		||||
                return ctypes.CDLL(str(_lib_path))
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
 | 
			
		||||
                invalidInputError(False,
 | 
			
		||||
                                  f"Failed to load shared library '{_lib_path}': {e}")
 | 
			
		||||
 | 
			
		||||
    raise FileNotFoundError(f"Shared library with base name '{lib_base_name}' not found")
 | 
			
		||||
    invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Specify the base name of the shared library to load
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -51,6 +51,7 @@ import uuid
 | 
			
		|||
import time
 | 
			
		||||
import math
 | 
			
		||||
import multiprocessing
 | 
			
		||||
import ctypes
 | 
			
		||||
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 | 
			
		||||
from collections import deque, OrderedDict
 | 
			
		||||
from bigdl.llm.utils.common import invalidInputError
 | 
			
		||||
| 
						 | 
				
			
			@ -342,22 +343,29 @@ class Gptneox:
 | 
			
		|||
                          "The attribute `eval_logits` of `Gptneox` object is None.")
 | 
			
		||||
        n_vocab = int(gptneox_cpp.gptneox_n_vocab(self.ctx))
 | 
			
		||||
        logits = self.eval_logits[-1]
 | 
			
		||||
        data = (gptneox_cpp.gptneox_token_data * n_vocab)(
 | 
			
		||||
            *[
 | 
			
		||||
                gptneox_cpp.gptneox_token_data(
 | 
			
		||||
                    id=gptneox_cpp.gptneox_token(i),
 | 
			
		||||
                    logit=logits[i],
 | 
			
		||||
                    p=gptneox_cpp.c_float(0.0),
 | 
			
		||||
                )
 | 
			
		||||
                for i in range(n_vocab)
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        size = gptneox_cpp.c_size_t(n_vocab)
 | 
			
		||||
        sorted = False
 | 
			
		||||
        candidates = gptneox_cpp.gptneox_token_data_array(
 | 
			
		||||
            data=data,
 | 
			
		||||
            size=size,
 | 
			
		||||
            sorted=sorted,
 | 
			
		||||
        # accelerate below code by moving to cpp
 | 
			
		||||
        # data = (gptneox_cpp.gptneox_token_data * n_vocab)(
 | 
			
		||||
        #     *[
 | 
			
		||||
        #         gptneox_cpp.gptneox_token_data(
 | 
			
		||||
        #             id=gptneox_cpp.gptneox_token(i),
 | 
			
		||||
        #             logit=logits[i],
 | 
			
		||||
        #             p=gptneox_cpp.c_float(0.0),
 | 
			
		||||
        #         )
 | 
			
		||||
        #         for i in range(n_vocab)
 | 
			
		||||
        #     ]
 | 
			
		||||
        # )
 | 
			
		||||
        # size = gptneox_cpp.c_size_t(n_vocab)
 | 
			
		||||
        # sorted = False
 | 
			
		||||
        # candidates = gptneox_cpp.gptneox_token_data_array(
 | 
			
		||||
        #     data=data,
 | 
			
		||||
        #     size=size,
 | 
			
		||||
        #     sorted=sorted,
 | 
			
		||||
        # )
 | 
			
		||||
        logits = (ctypes.c_float * n_vocab)(*logits)
 | 
			
		||||
        candidates = gptneox_cpp.gptneox_get_candidates(
 | 
			
		||||
            ctx=self.ctx,
 | 
			
		||||
            n_vocab=n_vocab,
 | 
			
		||||
            logits=logits
 | 
			
		||||
        )
 | 
			
		||||
        gptneox_cpp.gptneox_sample_repetition_penalty(
 | 
			
		||||
            ctx=self.ctx,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -529,6 +529,24 @@ _lib.gptneox_token_eos.restype = gptneox_token
 | 
			
		|||
# Sampling functions
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gptneox_get_candidates(
 | 
			
		||||
    ctx: gptneox_context_p,
 | 
			
		||||
    n_vocab: c_int,
 | 
			
		||||
    logits: c_float_p,
 | 
			
		||||
):
 | 
			
		||||
    return _lib.gptneox_get_candidates(
 | 
			
		||||
        ctx, n_vocab, logits
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_lib.gptneox_get_candidates.argtypes = [
 | 
			
		||||
    gptneox_context_p,
 | 
			
		||||
    c_int,
 | 
			
		||||
    c_float_p
 | 
			
		||||
]
 | 
			
		||||
_lib.gptneox_get_candidates.restype = gptneox_token_data_array
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
 | 
			
		||||
# with negative logit fix.
 | 
			
		||||
def gptneox_sample_repetition_penalty(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -33,7 +33,6 @@ _bloom_quantize_type = {"q4_0": 2,
 | 
			
		|||
                        "q4_1": 3}
 | 
			
		||||
_gptneox_quantize_type = {"q4_0": 2,
 | 
			
		||||
                          "q4_1": 3,
 | 
			
		||||
                          "q4_2": 5,
 | 
			
		||||
                          "q5_0": 8,
 | 
			
		||||
                          "q5_1": 9,
 | 
			
		||||
                          "q8_0": 7}
 | 
			
		||||
| 
						 | 
				
			
			@ -42,9 +41,6 @@ _quantize_type = {"llama": _llama_quantize_type,
 | 
			
		|||
                  "bloom": _bloom_quantize_type,
 | 
			
		||||
                  "gptneox": _gptneox_quantize_type}
 | 
			
		||||
 | 
			
		||||
_valid_types = set(list(_llama_quantize_type.keys()) + list(_bloomz_quantize_type.keys()) +
 | 
			
		||||
                   list(_gptneox_quantize_type.keys()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def quantize(input_path: str, output_path: str=None,
 | 
			
		||||
             model_family: str = 'llama', dtype: str='q4_0'):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue