From 2f77d485d8b9781520e7a417cca6d628cd2f77d4 Mon Sep 17 00:00:00 2001
From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com>
Date: Thu, 6 Jul 2023 17:50:05 +0800
Subject: [PATCH] Llm: Initial support of langchain transformer int4 API (#8459)

* first commit of transformer int4 and pipeline

* basic examples

temp save for embeddings

support embeddings and docqa example

* fix based on comment

* small fix
---
 .../langchain/{ => native_int4}/docqa.py      |   0
 .../langchain/{ => native_int4}/streamchat.py |   1 -
 .../langchain/transformers_int4/chat.py       |  63 ++++++
 .../langchain/transformers_int4/docqa.py      |  78 +++++++
 .../llm/langchain/embeddings/__init__.py      |   2 +
 .../embeddings/transformersembeddings.py      | 163 +++++++++++++++
 .../src/bigdl/llm/langchain/llms/__init__.py  |   6 +
 .../llm/langchain/llms/transformersllm.py     | 159 ++++++++++++++
 .../langchain/llms/transformerspipelinellm.py | 197 ++++++++++++++++++
 9 files changed, 668 insertions(+), 1 deletion(-)
 rename python/llm/example/langchain/{ => native_int4}/docqa.py (100%)
 rename python/llm/example/langchain/{ => native_int4}/streamchat.py (99%)
 create mode 100644 python/llm/example/langchain/transformers_int4/chat.py
 create mode 100644 python/llm/example/langchain/transformers_int4/docqa.py
 create mode 100644 python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
 create mode 100644 python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
 create mode 100644 python/llm/src/bigdl/llm/langchain/llms/transformerspipelinellm.py

diff --git a/python/llm/example/langchain/docqa.py b/python/llm/example/langchain/native_int4/docqa.py
similarity index 100%
rename from python/llm/example/langchain/docqa.py
rename to python/llm/example/langchain/native_int4/docqa.py
diff --git a/python/llm/example/langchain/streamchat.py b/python/llm/example/langchain/native_int4/streamchat.py
similarity index 99%
rename from python/llm/example/langchain/streamchat.py
rename to python/llm/example/langchain/native_int4/streamchat.py
index f3b32e91..e9d52838 100644
--- a/python/llm/example/langchain/streamchat.py
+++ b/python/llm/example/langchain/native_int4/streamchat.py
@@ -33,7 +33,6 @@ def main(args):
     model_path = args.model_path
     model_family = args.model_family
     n_threads = args.thread_num
-
     template ="""{question}"""
 
     prompt = PromptTemplate(template=template, input_variables=["question"])
diff --git a/python/llm/example/langchain/transformers_int4/chat.py b/python/llm/example/langchain/transformers_int4/chat.py
new file mode 100644
index 00000000..3c704672
--- /dev/null
+++ b/python/llm/example/langchain/transformers_int4/chat.py
@@ -0,0 +1,63 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+import argparse
+
+from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM
+from langchain import PromptTemplate, LLMChain
+from langchain import HuggingFacePipeline
+
+
+def main(args):
+
+    question = args.question
+    model_path = args.model_path
+    template ="""{question}"""
+
+    prompt = PromptTemplate(template=template, input_variables=["question"])
+
+    # llm = TransformersPipelineLLM.from_model_id(
+    #     model_id=model_path,
+    #     task="text-generation",
+    #     model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
+    # )
+
+    llm = TransformersLLM.from_model_id(
+        model_id=model_path,
+        model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
+    )
+
+    llm_chain = LLMChain(prompt=prompt, llm=llm)
+
+    output = llm_chain.run(question)
+    print("====output=====")
+    print(output)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Transformer-int4 style API Simple Example')
+    parser.add_argument('-m','--model-path', type=str, required=True,
+                        help='the path to transformers model')
+    parser.add_argument('-q', '--question', type=str, default='What is AI?',
+                        help='question you want to ask.')
+    args = parser.parse_args()

+    main(args)
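
The chat example above leaves the pipeline-based path commented out. For completeness, a minimal usage sketch of that variant (illustrative only, not part of the diff; "/path/to/model" is a placeholder for a local or Hub transformers checkpoint):

    from bigdl.llm.langchain.llms import TransformersPipelineLLM
    from langchain import LLMChain, PromptTemplate

    # placeholder path; any transformers checkpoint loadable by bigdl-llm should work
    llm = TransformersPipelineLLM.from_model_id(
        model_id="/path/to/model",
        task="text-generation",
        model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
    )
    prompt = PromptTemplate(template="{question}", input_variables=["question"])
    llm_chain = LLMChain(prompt=prompt, llm=llm)
    print(llm_chain.run("What is AI?"))
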
diff --git a/python/llm/example/langchain/transformers_int4/docqa.py b/python/llm/example/langchain/transformers_int4/docqa.py
new file mode 100644
index 00000000..e42fa866
--- /dev/null
+++ b/python/llm/example/langchain/transformers_int4/docqa.py
@@ -0,0 +1,78 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+import argparse
+
+from langchain.vectorstores import Chroma
+from langchain.chains.chat_vector_db.prompts import (CONDENSE_QUESTION_PROMPT,
+                                                     QA_PROMPT)
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.chains.question_answering import load_qa_chain
+from langchain.callbacks.manager import CallbackManager
+
+from bigdl.llm.langchain.llms import TransformersLLM
+from bigdl.llm.langchain.embeddings import TransformersEmbeddings
+
+
+
+def main(args):
+
+    input_path = args.input_path
+    model_path = args.model_path
+    query = args.question
+
+    # split texts of input doc
+    with open(input_path) as f:
+        input_doc = f.read()
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+    texts = text_splitter.split_text(input_doc)
+
+    # create embeddings and store into vectordb
+    embeddings = TransformersEmbeddings.from_model_id(model_id=model_path)
+    docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()
+
+    # get relevant texts
+    docs = docsearch.get_relevant_documents(query)
+
+    bigdl_llm = TransformersLLM.from_model_id(
+        model_id=model_path,
+        model_kwargs={"temperature": 0, "max_length": 1024, "trust_remote_code": True},
+    )
+
+    doc_chain = load_qa_chain(
+        bigdl_llm, chain_type="stuff", prompt=QA_PROMPT
+    )
+
+    output = doc_chain.run(input_documents=docs, question=query)
+    print(output)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Transformer-int4 style API Simple Example')
+    parser.add_argument('-m','--model-path', type=str, required=True,
+                        help='the path to transformers model')
+    parser.add_argument('-i', '--input-path', type=str,
+                        help='the path to the input doc.')
+    parser.add_argument('-q', '--question', type=str, default='What is AI?',
+                        help='question you want to ask.')
+    args = parser.parse_args()

+    main(args)
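
The doc-QA example only drives TransformersEmbeddings indirectly through Chroma. A small direct-usage sketch (not part of the diff; "/path/to/model" is a placeholder checkpoint) of what the new embeddings class returns:

    from bigdl.llm.langchain.embeddings import TransformersEmbeddings

    # placeholder path to a transformers checkpoint loadable by bigdl-llm
    embeddings = TransformersEmbeddings.from_model_id(model_id="/path/to/model")

    doc_vectors = embeddings.embed_documents(["BigDL-LLM overview", "INT4 quantization notes"])
    query_vector = embeddings.embed_query("What is INT4 quantization?")
    # one vector per document, each of the model's hidden-size length
    print(len(doc_vectors), len(doc_vectors[0]), len(query_vector))
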
diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
index f5c9fac3..0483f80b 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
@@ -20,7 +20,9 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bigdlllm import BigdlNativeEmbeddings
+from .transformersembeddings import TransformersEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
+    "TransformersEmbeddings"
 ]
diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
new file mode 100644
index 00000000..51a98dc1
--- /dev/null
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
@@ -0,0 +1,163 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+# This file is adapted from
+# https://github.com/hwchase17/langchain/blob/master/langchain/embeddings/llamacpp.py
+
+# The MIT License
+
+# Copyright (c) Harrison Chase
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+"""Wrapper around bigdl-llm embedding models."""
+from typing import Any, Dict, List, Optional
+import numpy as np
+
+from pydantic import BaseModel, Extra, Field
+
+from langchain.embeddings.base import Embeddings
+
+DEFAULT_MODEL_NAME = "gpt2"
+
+
+class TransformersEmbeddings(BaseModel, Embeddings):
+    """Wrapper around bigdl-llm transformers embedding models.
+
+    To use, you should have the ``transformers`` python package installed.
+
+    Example:
+        .. code-block:: python
+
+            from bigdl.llm.langchain.embeddings import TransformersEmbeddings
+            embeddings = TransformersEmbeddings.from_model_id(model_id)
+    """
+
+    model: Any  #: :meta private:
+    tokenizer: Any  #: :meta private:
+    model_id: str = DEFAULT_MODEL_NAME
+    """Model id to use."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments to pass to the model."""
+    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments to pass when calling the `encode` method of the model."""
+
+    @classmethod
+    def from_model_id(
+        cls,
+        model_id: str,
+        model_kwargs: Optional[dict] = None,
+        **kwargs: Any,
+    ):
+        """Construct object from model_id"""
+        try:
+            from bigdl.llm.transformers import AutoModel
+            from transformers import AutoTokenizer, LlamaTokenizer
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        # TODO: may refactor this code in the future
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except Exception:
+            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
+
+        if "trust_remote_code" in _model_kwargs:
+            _model_kwargs = {
+                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
+            }
+
+        return cls(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            model_kwargs=_model_kwargs,
+            **kwargs,
+        )
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    def embed(self, text: str, **kwargs):
+        """Compute an embedding for a single text using a bigdl-llm transformers model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embedding for the text, as a 1-D numpy array.
+        """
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)  # shape: [1, T]
+        embeddings = self.model(input_ids, return_dict=False)[0]  # shape: [1, T, N]
+        embeddings = embeddings.squeeze(0).detach().numpy()
+        embeddings = np.mean(embeddings, axis=0)
+        return embeddings
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Compute doc embeddings using a bigdl-llm transformers model.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        texts = list(map(lambda x: x.replace("\n", " "), texts))
+        embeddings = [self.embed(text, **self.encode_kwargs).tolist() for text in texts]
+        return embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using a bigdl-llm transformers model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        text = text.replace("\n", " ")
+        embedding = self.embed(text, **self.encode_kwargs)
+        return embedding.tolist()
diff --git a/python/llm/src/bigdl/llm/langchain/llms/__init__.py b/python/llm/src/bigdl/llm/langchain/llms/__init__.py
index 5ec5b38d..c77dbe69 100644
--- a/python/llm/src/bigdl/llm/langchain/llms/__init__.py
+++ b/python/llm/src/bigdl/llm/langchain/llms/__init__.py
@@ -24,11 +24,17 @@ from typing import Dict, Type
 from langchain.llms.base import BaseLLM
 
 from .bigdlllm import BigdlNativeLLM
+from .transformersllm import TransformersLLM
+from .transformerspipelinellm import TransformersPipelineLLM
 
 __all__ = [
     "BigdlNativeLLM",
+    "TransformersLLM",
+    "TransformersPipelineLLM"
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "BigdlNativeLLM": BigdlNativeLLM,
+    "TransformersPipelineLLM": TransformersPipelineLLM,
+    "TransformersLLM": TransformersLLM
 }
\ No newline at end of file
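
TransformersEmbeddings.embed() derives one vector per text by running the tokenized input through the INT4 model and mean-pooling the last hidden states over the token dimension. A rough FP32 equivalence sketch (not part of the diff; uses plain transformers and gpt2 purely for illustration):

    import numpy as np
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModel.from_pretrained("gpt2")

    text = "BigDL-LLM runs large language models with INT4 weights."
    input_ids = tokenizer.encode(text, return_tensors="pt")          # shape: [1, T]
    hidden = model(input_ids, return_dict=False)[0]                  # shape: [1, T, N]
    embedding = np.mean(hidden.squeeze(0).detach().numpy(), axis=0)  # shape: [N]
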
diff --git a/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py b/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
new file mode 100644
index 00000000..d0801041
--- /dev/null
+++ b/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
@@ -0,0 +1,159 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+# This file is adapted from
+# https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py
+
+# The MIT License
+
+# Copyright (c) Harrison Chase
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+import importlib.util
+import logging
+from typing import Any, List, Mapping, Optional
+
+from pydantic import Extra
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+
+DEFAULT_MODEL_ID = "gpt2"
+
+
+class TransformersLLM(LLM):
+    """Wrapper around the BigDL-LLM Transformer-INT4 model.
+
+    Example:
+        .. code-block:: python
+
+            from bigdl.llm.langchain.llms import TransformersLLM
+            llm = TransformersLLM.from_model_id(model_id="THUDM/chatglm-6b")
+    """
+
+    model_id: str = DEFAULT_MODEL_ID
+    """Model name or model path to use."""
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments passed to the model."""
+    model: Any  #: :meta private:
+    """BigDL-LLM Transformer-INT4 model."""
+    tokenizer: Any  #: :meta private:
+    """Huggingface tokenizer model."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @classmethod
+    def from_model_id(
+        cls,
+        model_id: str,
+        model_kwargs: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> LLM:
+        """Construct object from model_id"""
+        try:
+            from bigdl.llm.transformers import (
+                AutoModel,
+                AutoModelForCausalLM,
+                # AutoModelForSeq2SeqLM,
+            )
+            from transformers import AutoTokenizer, LlamaTokenizer
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        # TODO: may refactor this code in the future
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except Exception:
+            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        # TODO: may refactor this code in the future
+        try:
+            model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
+        except Exception:
+            model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
+
+        if "trust_remote_code" in _model_kwargs:
+            _model_kwargs = {
+                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
+            }
+
+        return cls(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            model_kwargs=_model_kwargs,
+            **kwargs,
+        )
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model_id": self.model_id,
+            "model_kwargs": self.model_kwargs,
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        return "BigDL-llm"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+        output = self.model.generate(input_ids, **kwargs)
+        text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
+        if stop is not None:
+            # This is a bit hacky, but it is the simplest way to enforce stop tokens:
+            # they are trimmed from the generated text after the fact.
+            text = enforce_stop_tokens(text, stop)
+        return text
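
TransformersLLM._call() decodes the full generation, strips the echoed prompt, and applies enforce_stop_tokens when stop words are given. A direct-call sketch (not part of the diff; "/path/to/model" is a placeholder):

    from bigdl.llm.langchain.llms import TransformersLLM

    llm = TransformersLLM.from_model_id(
        model_id="/path/to/model",
        model_kwargs={"trust_remote_code": True},
    )

    # the returned text excludes the prompt itself and is cut at the first stop string
    answer = llm("Q: What is AI?\nA:", stop=["\n"])
    print(answer)
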
diff --git a/python/llm/src/bigdl/llm/langchain/llms/transformerspipelinellm.py b/python/llm/src/bigdl/llm/langchain/llms/transformerspipelinellm.py
new file mode 100644
index 00000000..a120acf5
--- /dev/null
+++ b/python/llm/src/bigdl/llm/langchain/llms/transformerspipelinellm.py
@@ -0,0 +1,197 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This makes sure Python is aware there is more than one sub-package within bigdl,
+# physically located elsewhere.
+# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
+# only search the first bigdl package and end up finding only one sub-package.
+
+# This file is adapted from
+# https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py
+
+# The MIT License
+
+# Copyright (c) Harrison Chase
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+import importlib.util
+import logging
+from typing import Any, List, Mapping, Optional
+
+from pydantic import Extra
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+
+DEFAULT_MODEL_ID = "gpt2"
+DEFAULT_TASK = "text-generation"
+VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+
+
+class TransformersPipelineLLM(LLM):
+    """Wrapper around the BigDL-LLM Transformer-INT4 model run via ``transformers.pipeline()``.
+
+    Example:
+        .. code-block:: python
+
+            from bigdl.llm.langchain.llms import TransformersPipelineLLM
+            llm = TransformersPipelineLLM.from_model_id(model_id="decapoda-research/llama-7b-hf", task="text-generation")
+    """
+
+    pipeline: Any  #: :meta private:
+    model_id: str = DEFAULT_MODEL_ID
+    """Model name or model path to use."""
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments passed to the model."""
+    pipeline_kwargs: Optional[dict] = None
+    """Keyword arguments passed to the pipeline."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @classmethod
+    def from_model_id(
+        cls,
+        model_id: str,
+        task: str,
+        model_kwargs: Optional[dict] = None,
+        pipeline_kwargs: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> LLM:
+        """Construct the pipeline object from model_id and task."""
+        try:
+            from bigdl.llm.transformers import (
+                AutoModel,
+                AutoModelForCausalLM,
+                # AutoModelForSeq2SeqLM,
+            )
+            from transformers import AutoTokenizer, LlamaTokenizer
+            from transformers import pipeline as hf_pipeline
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+
+        _model_kwargs = model_kwargs or {}
+        # TODO: may refactor this code in the future
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, **_model_kwargs)
+        except Exception:
+            tokenizer = LlamaTokenizer.from_pretrained(model_id, **_model_kwargs)
+
+        try:
+            if task == "text-generation":
+                model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
+            elif task in ("text2text-generation", "summarization"):
+                # TODO: support this when related PR merged
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)
+            else:
+                raise ValueError(
+                    f"Got invalid task {task}, "
+                    f"currently only {VALID_TASKS} are supported"
+                )
+        except (ImportError, NameError) as e:
+            raise ValueError(
+                f"Could not load the {task} model due to missing dependencies."
+            ) from e
+
+        if "trust_remote_code" in _model_kwargs:
+            _model_kwargs = {
+                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
+            }
+        _pipeline_kwargs = pipeline_kwargs or {}
+        pipeline = hf_pipeline(
+            task=task,
+            model=model,
+            tokenizer=tokenizer,
+            device='cpu',  # only cpu now
+            model_kwargs=_model_kwargs,
+            **_pipeline_kwargs,
+        )
+        if pipeline.task not in VALID_TASKS:
+            raise ValueError(
+                f"Got invalid task {pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        return cls(
+            pipeline=pipeline,
+            model_id=model_id,
+            model_kwargs=_model_kwargs,
+            pipeline_kwargs=_pipeline_kwargs,
+            **kwargs,
+        )
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model_id": self.model_id,
+            "model_kwargs": self.model_kwargs,
+            "pipeline_kwargs": self.pipeline_kwargs,
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        return "BigDL-llm"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        response = self.pipeline(prompt)
+        if self.pipeline.task == "text-generation":
+            # Text generation return includes the starter text.
+            text = response[0]["generated_text"][len(prompt) :]
+        elif self.pipeline.task == "text2text-generation":
+            text = response[0]["generated_text"]
+        elif self.pipeline.task == "summarization":
+            text = response[0]["summary_text"]
+        else:
+            raise ValueError(
+                f"Got invalid task {self.pipeline.task}, "
+                f"currently only {VALID_TASKS} are supported"
+            )
+        if stop is not None:
+            # This is a bit hacky, but it is the simplest way to enforce stop tokens:
+            # they are trimmed from the generated text after the fact.
+            text = enforce_stop_tokens(text, stop)
+        return text
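
As the TODO above notes, only the "text-generation" task is fully wired up until AutoModelForSeq2SeqLM is exported from bigdl.llm.transformers; the other VALID_TASKS entries currently fail at model-load time. A behavioral sketch of the supported path (illustrative only, not part of the diff; "/path/to/model" is a placeholder):

    from bigdl.llm.langchain.llms import TransformersPipelineLLM

    llm = TransformersPipelineLLM.from_model_id(
        model_id="/path/to/model",
        task="text-generation",  # the supported task today
    )
    print(llm("What is AI?"))

    # "summarization" and "text2text-generation" are listed in VALID_TASKS but currently
    # raise a ValueError about missing dependencies until the Seq2Seq auto class lands.
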