hot-fix redundant import funasr (#12277)
This commit is contained in:
parent
a0c6432899
commit
08cb065370
1 changed files with 9 additions and 2 deletions
|
|
@ -93,6 +93,9 @@ class _BaseAutoModelClass:
|
|||
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
|
||||
kwargs["torch_dtype"] = torch.float32
|
||||
|
||||
if hasattr(cls, "get_cls_model"):
|
||||
cls.HF_Model = cls.get_cls_model()
|
||||
|
||||
low_bit = kwargs.pop("load_in_low_bit", "sym_int4")
|
||||
qtype_map = {
|
||||
"sym_int4": "sym_int4_rtn",
|
||||
|
|
@ -574,8 +577,6 @@ class AutoModelForTokenClassification(_BaseAutoModelClass):
|
|||
|
||||
|
||||
class FunAsrAutoModel(_BaseAutoModelClass):
|
||||
import funasr
|
||||
HF_Model = funasr.AutoModel
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.model = self.from_pretrained(*args, **kwargs)
|
||||
|
|
@ -583,6 +584,12 @@ class FunAsrAutoModel(_BaseAutoModelClass):
|
|||
def __getattr__(self, name):
|
||||
return getattr(self.model, name)
|
||||
|
||||
@classmethod
|
||||
def get_cls_model(cls):
|
||||
import funasr
|
||||
cls_model = funasr.AutoModel
|
||||
return cls_model
|
||||
|
||||
@classmethod
|
||||
def optimize_npu_model(cls, *args, **kwargs):
|
||||
from ipex_llm.transformers.npu_models.convert_mp import optimize_funasr
|
||||
|
|
|
|||
Loading…
Reference in a new issue