hot-fix redundant import funasr (#12277)
This commit is contained in:
parent
a0c6432899
commit
08cb065370
1 changed files with 9 additions and 2 deletions
|
|
@ -93,6 +93,9 @@ class _BaseAutoModelClass:
|
||||||
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
|
warnings.warn("`torch_dtype` will be ignored, `torch.float` will be used")
|
||||||
kwargs["torch_dtype"] = torch.float32
|
kwargs["torch_dtype"] = torch.float32
|
||||||
|
|
||||||
|
if hasattr(cls, "get_cls_model"):
|
||||||
|
cls.HF_Model = cls.get_cls_model()
|
||||||
|
|
||||||
low_bit = kwargs.pop("load_in_low_bit", "sym_int4")
|
low_bit = kwargs.pop("load_in_low_bit", "sym_int4")
|
||||||
qtype_map = {
|
qtype_map = {
|
||||||
"sym_int4": "sym_int4_rtn",
|
"sym_int4": "sym_int4_rtn",
|
||||||
|
|
@ -574,8 +577,6 @@ class AutoModelForTokenClassification(_BaseAutoModelClass):
|
||||||
|
|
||||||
|
|
||||||
class FunAsrAutoModel(_BaseAutoModelClass):
|
class FunAsrAutoModel(_BaseAutoModelClass):
|
||||||
import funasr
|
|
||||||
HF_Model = funasr.AutoModel
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.model = self.from_pretrained(*args, **kwargs)
|
self.model = self.from_pretrained(*args, **kwargs)
|
||||||
|
|
@ -583,6 +584,12 @@ class FunAsrAutoModel(_BaseAutoModelClass):
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
return getattr(self.model, name)
|
return getattr(self.model, name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_cls_model(cls):
|
||||||
|
import funasr
|
||||||
|
cls_model = funasr.AutoModel
|
||||||
|
return cls_model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def optimize_npu_model(cls, *args, **kwargs):
|
def optimize_npu_model(cls, *args, **kwargs):
|
||||||
from ipex_llm.transformers.npu_models.convert_mp import optimize_funasr
|
from ipex_llm.transformers.npu_models.convert_mp import optimize_funasr
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue