diff --git a/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py b/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
index 0f09b346..6c8cc780 100644
--- a/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
+++ b/python/llm/src/bigdl/llm/serving/bigdl_llm_model.py
@@ -256,10 +256,42 @@ class BigDLLLMAdapter(BaseModelAdapter):
         return model, tokenizer
 
 
+class BigDLLMLOWBITAdapter(BaseModelAdapter):
+    """Model adapter for bigdl-llm backend low-bit models.
+
+    Claims any model path containing ``bigdl-lowbit`` and loads the model
+    through ``AutoModelForCausalLM.load_low_bit``.
+    """
+
+    def match(self, model_path: str):
+        """Return True when ``model_path`` contains ``bigdl-lowbit``."""
+        marker = "bigdl-lowbit"
+        return marker in model_path
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        """Load the tokenizer and the low-bit model found at ``model_path``.
+
+        Only ``revision`` (default ``"main"``) is read from
+        ``from_pretrained_kwargs``, and it is applied to the tokenizer only.
+        Returns a ``(model, tokenizer)`` tuple.
+        """
+        tokenizer_options = {
+            "use_fast": False,
+            "revision": from_pretrained_kwargs.get("revision", "main"),
+        }
+        tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_options)
+        print("Customized bigdl-llm loader")
+        # Function-scope import: only executed when a model is actually loaded.
+        from bigdl.llm.transformers import AutoModelForCausalLM
+        low_bit_model = AutoModelForCausalLM.load_low_bit(model_path)
+        return low_bit_model, tokenizer
+
+
 def patch_fastchat():
     global is_fastchat_patched
     if is_fastchat_patched:
         return
+    register_model_adapter(BigDLLMLOWBITAdapter)
     register_model_adapter(BigDLLLMAdapter)
     mapping_fastchat = _get_patch_map()
 