From 55e705a84cfe61f77eae4f5a31add166adedda49 Mon Sep 17 00:00:00 2001 From: Yina Chen <33650826+cyita@users.noreply.github.com> Date: Wed, 30 Aug 2023 11:16:14 +0800 Subject: [PATCH] [LLM] Support the rest of AutoXXX classes in Transformers API (#8815) * add transformers auto models * fix --- .../src/bigdl/llm/transformers/__init__.py | 6 ++++- .../llm/src/bigdl/llm/transformers/model.py | 24 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/__init__.py b/python/llm/src/bigdl/llm/transformers/__init__.py index d6ff9239..b446cfcc 100644 --- a/python/llm/src/bigdl/llm/transformers/__init__.py +++ b/python/llm/src/bigdl/llm/transformers/__init__.py @@ -15,5 +15,9 @@ # from .convert import ggml_convert_quant -from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, AutoModelForSpeechSeq2Seq +from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, \ + AutoModelForSpeechSeq2Seq, AutoModelForQuestionAnswering, \ + AutoModelForSequenceClassification, AutoModelForMaskedLM, \ + AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \ + AutoModelForTokenClassification from .modelling_bigdl import * diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 67c60dd5..4d6231ce 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -302,3 +302,27 @@ class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): class AutoModelForSeq2SeqLM(_BaseAutoModelClass): HF_Model = transformers.AutoModelForSeq2SeqLM + + +class AutoModelForSequenceClassification(_BaseAutoModelClass): + HF_Model = transformers.AutoModelForSequenceClassification + + +class AutoModelForMaskedLM(_BaseAutoModelClass): + HF_Model = transformers.AutoModelForMaskedLM + + +class AutoModelForQuestionAnswering(_BaseAutoModelClass): + HF_Model = transformers.AutoModelForQuestionAnswering + + +class AutoModelForNextSentencePrediction(_BaseAutoModelClass): + HF_Model = transformers.AutoModelForNextSentencePrediction + + +class AutoModelForMultipleChoice(_BaseAutoModelClass): + HF_Model = transformers.AutoModelForMultipleChoice + + +class AutoModelForTokenClassification(_BaseAutoModelClass): + HF_Model = transformers.AutoModelForTokenClassification