Llm: Initial support of langchain transformer int4 API (#8459)
* first commit of transformer int4 and pipeline
* basic examples; temp save for embeddings; support embeddings and docqa example
* fix based on comment
* small fix
parent 14626fe05b
commit 2f77d485d8

9 changed files with 668 additions and 1 deletion
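In short, the commit lets a LangChain chain run on a local model that bigdl-llm loads with INT4 quantization. A minimal usage sketch of the new entry point, assuming a local model path ("/path/to/model" is a placeholder, not from the commit):

    from bigdl.llm.langchain.llms import TransformersLLM

    # "/path/to/model" stands in for any local transformers checkpoint.
    llm = TransformersLLM.from_model_id(model_id="/path/to/model")
    print(llm("What is AI?"))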
@@ -33,7 +33,6 @@ def main(args):
     model_path = args.model_path
     model_family = args.model_family
     n_threads = args.thread_num
-
     template = """{question}"""
 
     prompt = PromptTemplate(template=template, input_variables=["question"])
python/llm/example/langchain/transformers_int4/chat.py (new file, 63 lines)
@@ -0,0 +1,63 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This is to make sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

import argparse

from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM
from langchain import PromptTemplate, LLMChain
from langchain import HuggingFacePipeline


def main(args):
    question = args.question
    model_path = args.model_path
    template = """{question}"""

    prompt = PromptTemplate(template=template, input_variables=["question"])

    # llm = TransformersPipelineLLM.from_model_id(
    #     model_id=model_path,
    #     task="text-generation",
    #     model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
    # )

    llm = TransformersLLM.from_model_id(
        model_id=model_path,
        model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
    )

    llm_chain = LLMChain(prompt=prompt, llm=llm)

    output = llm_chain.run(question)
    print("====output=====")
    print(output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transformer-int4 style API Simple Example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='the path to transformers model')
    parser.add_argument('-q', '--question', type=str, default='What is AI?',
                        help='question you want to ask.')
    args = parser.parse_args()

    main(args)
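chat.py keeps a commented-out variant that drives the same chain through the pipeline-backed wrapper; spelled out under the same assumptions (the model path is a placeholder), it would read:

    from bigdl.llm.langchain.llms import TransformersPipelineLLM

    # Mirrors the commented-out block in chat.py; "/path/to/model" is hypothetical.
    llm = TransformersPipelineLLM.from_model_id(
        model_id="/path/to/model",
        task="text-generation",
        model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
    )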
python/llm/example/langchain/transformers_int4/docqa.py (new file, 78 lines)
@@ -0,0 +1,78 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This is to make sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

import argparse

from langchain.vectorstores import Chroma
from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
                                                     QA_PROMPT)
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain.callbacks.manager import CallbackManager

from bigdl.llm.langchain.llms import TransformersLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings


def main(args):
    input_path = args.input_path
    model_path = args.model_path
    query = args.question

    # split texts of input doc
    with open(input_path) as f:
        input_doc = f.read()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_text(input_doc)

    # create embeddings and store into vectordb
    embeddings = TransformersEmbeddings.from_model_id(model_id=model_path)
    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()

    # get relevant texts
    docs = docsearch.get_relevant_documents(query)

    bigdl_llm = TransformersLLM.from_model_id(
        model_id=model_path,
        model_kwargs={"temperature": 0, "max_length": 1024, "trust_remote_code": True},
    )

    doc_chain = load_qa_chain(
        bigdl_llm, chain_type="stuff", prompt=QA_PROMPT
    )

    output = doc_chain.run(input_documents=docs, question=query)
    print(output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transformer-int4 style API Simple Example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='the path to transformers model')
    parser.add_argument('-i', '--input-path', type=str,
                        help='the path to the input doc.')
    parser.add_argument('-q', '--question', type=str, default='What is AI?',
                        help='question you want to ask.')
    args = parser.parse_args()

    main(args)
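docqa.py first splits the input document into roughly 1000-character chunks before embedding; the splitter can be exercised on its own with toy input (the text below is made up):

    from langchain.text_splitter import CharacterTextSplitter

    # CharacterTextSplitter splits on a separator (blank lines by default)
    # and merges pieces back together up to chunk_size characters.
    doc = "\n\n".join(f"Paragraph {i}: some text about AI." for i in range(50))
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_text(doc)
    print(len(chunks), max(len(c) for c in chunks))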
python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
@@ -20,7 +20,9 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bigdlllm import BigdlNativeEmbeddings
+from .transformersembeddings import TransformersEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
+    "TransformersEmbeddings"
 ]
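With this export in place, the new embeddings class is importable from the package root; a quick smoke test, assuming a local model path ("/path/to/model" is a placeholder):

    from bigdl.llm.langchain.embeddings import TransformersEmbeddings

    # embed_query returns a plain list of floats (length = model hidden size).
    embeddings = TransformersEmbeddings.from_model_id(model_id="/path/to/model")
    vec = embeddings.embed_query("What is AI?")
    print(len(vec))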
python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py (new file, 163 lines)
@@ -0,0 +1,163 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This is to make sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

# This file is adapted from
# https://github.com/hwchase17/langchain/blob/master/langchain/embeddings/llamacpp.py

# The MIT License

# Copyright (c) Harrison Chase

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""Wrapper around BigdlLLM embedding models."""
from typing import Any, Dict, List, Optional

import numpy as np
from pydantic import BaseModel, Extra, Field

from langchain.embeddings.base import Embeddings

DEFAULT_MODEL_NAME = "gpt2"


class TransformersEmbeddings(BaseModel, Embeddings):
    """Wrapper around bigdl-llm transformers embedding models.

    To use, you should have the ``transformers`` python package installed.

    Example:
        .. code-block:: python

            from bigdl.llm.langchain.embeddings import TransformersEmbeddings
            embeddings = TransformersEmbeddings.from_model_id(model_id)
    """

    model: Any  #: :meta private:
    tokenizer: Any  #: :meta private:
    model_id: str = DEFAULT_MODEL_NAME
    """Model id to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method of the model."""

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        model_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ):
        """Construct object from model_id."""
        try:
            from bigdl.llm.transformers import AutoModel
            from transformers import AutoTokenizer, LlamaTokenizer
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        _model_kwargs = model_kwargs or {}
        # TODO: may refactor this code in the future
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
        except Exception:
            # fall back for models (e.g. llama) whose tokenizer cannot be
            # resolved through AutoTokenizer
            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)

        model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)

        if "trust_remote_code" in _model_kwargs:
            _model_kwargs = {
                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
            }

        return cls(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            model_kwargs=_model_kwargs,
            **kwargs,
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed(self, text: str, **kwargs: Any):
        """Compute an embedding for a single piece of text.

        Args:
            text: The text to embed.

        Returns:
            The mean-pooled embedding for the text, as a numpy array.
        """
        # **kwargs absorbs self.encode_kwargs forwarded by embed_documents
        # and embed_query; they are currently unused here.
        input_ids = self.tokenizer.encode(text, return_tensors="pt")  # shape: [1, T]
        embeddings = self.model(input_ids, return_dict=False)[0]  # shape: [1, T, N]
        embeddings = embeddings.squeeze(0).detach().numpy()
        embeddings = np.mean(embeddings, axis=0)
        return embeddings

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        texts = list(map(lambda x: x.replace("\n", " "), texts))
        embeddings = [self.embed(text, **self.encode_kwargs).tolist() for text in texts]
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a bigdl-llm transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        embedding = self.embed(text, **self.encode_kwargs)
        return embedding.tolist()
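The embed method mean-pools per-token hidden states into a single vector per text; here is the pooling step in isolation, with made-up numbers:

    import numpy as np

    # A [T, N] matrix of per-token hidden states (T=3 tokens, N=2 dims) is
    # averaged over the token axis, yielding one N-dimensional vector.
    token_states = np.array([[0.0, 1.0],
                             [2.0, 3.0],
                             [4.0, 5.0]])
    sentence_vec = np.mean(token_states, axis=0)
    print(sentence_vec)  # [2. 3.]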
python/llm/src/bigdl/llm/langchain/llms/__init__.py
@@ -24,11 +24,17 @@ from typing import Dict, Type
 from langchain.llms.base import BaseLLM
 
 from .bigdlllm import BigdlNativeLLM
+from .transformersllm import TransformersLLM
+from .transformerspipelinellm import TransformersPipelineLLM
 
 __all__ = [
     "BigdlNativeLLM",
+    "TransformersLLM",
+    "TransformersPipelineLLM"
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "BigdlNativeLLM": BigdlNativeLLM,
+    "TransformersPipelineLLM": TransformersPipelineLLM,
+    "TransformersLLM": TransformersLLM
 }
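type_to_cls_dict lets callers resolve a wrapper class from its string name; a small sketch of that lookup (the model path is hypothetical):

    from bigdl.llm.langchain.llms import type_to_cls_dict

    # Pick the wrapper class by name, then construct it as usual.
    llm_cls = type_to_cls_dict["TransformersLLM"]
    llm = llm_cls.from_model_id(model_id="/path/to/model")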
python/llm/src/bigdl/llm/langchain/llms/transformersllm.py (new file, 159 lines)
@@ -0,0 +1,159 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This is to make sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

# This file is adapted from
# https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py

# The MIT License

# Copyright (c) Harrison Chase

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


import importlib.util
import logging
from typing import Any, List, Mapping, Optional

from pydantic import Extra

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

DEFAULT_MODEL_ID = "gpt2"


class TransformersLLM(LLM):
    """Wrapper around the BigDL-LLM Transformer-INT4 model.

    Example:
        .. code-block:: python

            from bigdl.llm.langchain.llms import TransformersLLM
            llm = TransformersLLM.from_model_id(model_id="THUDM/chatglm-6b")
    """

    model_id: str = DEFAULT_MODEL_ID
    """Model name or model path to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments passed to the model."""
    model: Any  #: :meta private:
    """BigDL-LLM Transformer-INT4 model."""
    tokenizer: Any  #: :meta private:
    """Huggingface tokenizer model."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        model_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> LLM:
        """Construct object from model_id."""
        try:
            from bigdl.llm.transformers import (
                AutoModel,
                AutoModelForCausalLM,
                # AutoModelForSeq2SeqLM,
            )
            from transformers import AutoTokenizer, LlamaTokenizer
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        _model_kwargs = model_kwargs or {}
        # TODO: may refactor this code in the future
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
        except Exception:
            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)

        # TODO: may refactor this code in the future
        try:
            model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
        except Exception:
            # fall back for models (e.g. chatglm) that are not causal-LM classes
            model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)

        if "trust_remote_code" in _model_kwargs:
            _model_kwargs = {
                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
            }

        return cls(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            model_kwargs=_model_kwargs,
            **kwargs,
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "model_kwargs": self.model_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        return "BigDL-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        output = self.model.generate(input_ids, **kwargs)
        text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt):]
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
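_call decodes the generated ids, drops the echoed prompt, and, when stop sequences are given, truncates with langchain's enforce_stop_tokens. A minimal stand-in for that truncation (ours adds re.escape for safety; langchain's helper splits on the raw patterns):

    import re

    def enforce_stop_tokens_sketch(text: str, stop: list) -> str:
        # Cut the text at the first occurrence of any stop sequence.
        return re.split("|".join(map(re.escape, stop)), text)[0]

    print(enforce_stop_tokens_sketch("Paris is nice.\nQuestion: next?", ["\nQuestion:"]))
    # -> "Paris is nice."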
python/llm/src/bigdl/llm/langchain/llms/transformerspipelinellm.py (new file, 197 lines)
@@ -0,0 +1,197 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This is to make sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.

# This file is adapted from
# https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py

# The MIT License

# Copyright (c) Harrison Chase

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


import importlib.util
import logging
from typing import Any, List, Mapping, Optional

from pydantic import Extra

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")


class TransformersPipelineLLM(LLM):
    """Wrapper around the BigDL-LLM Transformer-INT4 model in transformers.pipeline().

    Example:
        .. code-block:: python

            from bigdl.llm.langchain.llms import TransformersPipelineLLM
            llm = TransformersPipelineLLM.from_model_id(
                model_id="decapoda-research/llama-7b-hf", task="text-generation"
            )
    """

    pipeline: Any  #: :meta private:
    model_id: str = DEFAULT_MODEL_ID
    """Model name or model path to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments passed to the model."""
    pipeline_kwargs: Optional[dict] = None
    """Keyword arguments passed to the pipeline."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @classmethod
    def from_model_id(
        cls,
        model_id: str,
        task: str,
        model_kwargs: Optional[dict] = None,
        pipeline_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> LLM:
        """Construct the pipeline object from model_id and task."""
        try:
            from bigdl.llm.transformers import (
                AutoModel,
                AutoModelForCausalLM,
                # AutoModelForSeq2SeqLM,
            )
            from transformers import AutoTokenizer, LlamaTokenizer
            from transformers import pipeline as hf_pipeline
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        _model_kwargs = model_kwargs or {}
        # TODO: may refactor this code in the future
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
        except Exception:
            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)

        try:
            if task == "text-generation":
                model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
            elif task in ("text2text-generation", "summarization"):
                # TODO: support this when the related PR is merged;
                # AutoModelForSeq2SeqLM is not exported by bigdl.llm.transformers
                # yet, so this branch currently raises NameError.
                model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
            else:
                raise ValueError(
                    f"Got invalid task {task}, "
                    f"currently only {VALID_TASKS} are supported"
                )
        except ImportError as e:
            raise ValueError(
                f"Could not load the {task} model due to missing dependencies."
            ) from e

        if "trust_remote_code" in _model_kwargs:
            _model_kwargs = {
                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
            }
        _pipeline_kwargs = pipeline_kwargs or {}
        pipeline = hf_pipeline(
            task=task,
            model=model,
            tokenizer=tokenizer,
            device='cpu',  # only cpu now
            model_kwargs=_model_kwargs,
            **_pipeline_kwargs,
        )
        if pipeline.task not in VALID_TASKS:
            raise ValueError(
                f"Got invalid task {pipeline.task}, "
                f"currently only {VALID_TASKS} are supported"
            )
        return cls(
            pipeline=pipeline,
            model_id=model_id,
            model_kwargs=_model_kwargs,
            pipeline_kwargs=_pipeline_kwargs,
            **kwargs,
        )

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_id": self.model_id,
            "model_kwargs": self.model_kwargs,
            "pipeline_kwargs": self.pipeline_kwargs,
        }

    @property
    def _llm_type(self) -> str:
        return "BigDL-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        response = self.pipeline(prompt)
        if self.pipeline.task == "text-generation":
            # Text generation return includes the starter text.
            text = response[0]["generated_text"][len(prompt):]
        elif self.pipeline.task == "text2text-generation":
            text = response[0]["generated_text"]
        elif self.pipeline.task == "summarization":
            text = response[0]["summary_text"]
        else:
            raise ValueError(
                f"Got invalid task {self.pipeline.task}, "
                f"currently only {VALID_TASKS} are supported"
            )
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text
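For the text-generation task, _call removes the echoed prompt by slicing on the prompt length; illustrated stand-alone with made-up strings:

    # A text-generation pipeline echoes the prompt in generated_text.
    prompt = "What is AI?"
    generated_text = "What is AI? AI is the simulation of human intelligence."
    completion = generated_text[len(prompt):]
    print(completion)  # " AI is the simulation of human intelligence."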