Fix bug in VibeVoiceProcessor.from_pretrained
This commit is contained in:
parent
5d09c31021
commit
21c35b4701
1 changed file with 21 additions and 6 deletions
|
@@ -56,23 +56,38 @@ class VibeVoiceProcessor:
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
from transformers.utils import cached_file
|
||||||
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
|
||||||
from vibevoice.modular.modular_vibevoice_text_tokenizer import (
|
from vibevoice.modular.modular_vibevoice_text_tokenizer import (
|
||||||
VibeVoiceTextTokenizer,
|
VibeVoiceTextTokenizer,
|
||||||
VibeVoiceTextTokenizerFast
|
VibeVoiceTextTokenizerFast
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load processor configuration
|
# Try to load from local path first, then from HF hub
|
||||||
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
|
config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
|
||||||
|
config = None
|
||||||
|
|
||||||
if os.path.exists(config_path):
|
if os.path.exists(config_path):
|
||||||
|
# Local path exists
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
|
# Try to load from HF hub
|
||||||
config = {
|
try:
|
||||||
"speech_tok_compress_ratio": 3200,
|
config_file = cached_file(
|
||||||
"db_normalize": True,
|
pretrained_model_name_or_path,
|
||||||
}
|
"preprocessor_config.json",
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not load preprocessor_config.json from {pretrained_model_name_or_path}: {e}")
|
||||||
|
logger.warning("Using default configuration")
|
||||||
|
config = {
|
||||||
|
"speech_tok_compress_ratio": 3200,
|
||||||
|
"db_normalize": True,
|
||||||
|
}
|
||||||
|
|
||||||
# Extract main processor parameters
|
# Extract main processor parameters
|
||||||
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
|
speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
|
||||||
|
|
Loading…
Reference in a new issue