hot-fix redundant import funasr (#12277)

This commit is contained in:
SONG Ge 2024-10-25 19:40:39 +08:00 committed by GitHub
parent a0c6432899
commit 08cb065370
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -93,6 +93,9 @@ class _BaseAutoModelClass:
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
kwargs["torch_dtype"] = torch.float32
if hasattr(cls, "get_cls_model"):
cls.HF_Model = cls.get_cls_model()
low_bit = kwargs.pop("load_in_low_bit", "sym_int4")
qtype_map = {
"sym_int4": "sym_int4_rtn",
@ -574,8 +577,6 @@ class AutoModelForTokenClassification(_BaseAutoModelClass):
class FunAsrAutoModel(_BaseAutoModelClass):
import funasr
HF_Model = funasr.AutoModel
def __init__(self, *args, **kwargs):
self.model = self.from_pretrained(*args, **kwargs)
@ -583,6 +584,12 @@ class FunAsrAutoModel(_BaseAutoModelClass):
def __getattr__(self, name):
return getattr(self.model, name)
@classmethod
def get_cls_model(cls):
import funasr
cls_model = funasr.AutoModel
return cls_model
@classmethod
def optimize_npu_model(cls, *args, **kwargs):
from ipex_llm.transformers.npu_models.convert_mp import optimize_funasr