287 lines
No EOL
9 KiB
Python
287 lines
No EOL
9 KiB
Python
import math
|
|
from typing import Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from transformers.models.auto import AutoModel
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
# from transformers.modeling_layers import GradientCheckpointingLayer
|
|
from transformers.activations import ACT2FN
|
|
from transformers.utils import logging
|
|
|
|
from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig
|
|
|
|
|
|
# Module logger from transformers' logging utilities (not referenced elsewhere in this section).
logger = logging.get_logger(name=__name__)
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization over the last dimension.

    The input is normalized in float32, cast back to its own dtype and, when
    `elementwise_affine` is set, multiplied by a learnable per-feature weight
    initialized to ones. If the weight's dtype differs from the input's, the
    result follows PyTorch type promotion.

    Args:
        dim (`int`): Size of the normalized (last) dimension
        eps (`float`, optional): Value added to the mean square for numerical stability
        elementwise_affine (`bool`, optional): Whether to learn a per-feature scale
        memory_efficient (`bool`, optional): Accepted for signature compatibility; unused
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.elementwise_affine = elementwise_affine
        # Register `weight` either way so `self.weight` always exists (possibly as None).
        scale = nn.Parameter(torch.ones(dim)) if elementwise_affine else None
        self.register_parameter('weight', scale)

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32 for stability, then return to the caller's dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed if self.weight is None else normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
|
|
|
def modulate(x, shift, scale):
    """
    Apply adaLN-style modulation: scale `x` by `1 + scale`, then offset by `shift`.

    `shift` and `scale` must broadcast against `x`.
    """
    scaled = x * (1 + scale)
    return scaled + shift
|
|
|
|
|
|
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Timesteps are first mapped to a fixed sinusoidal embedding of size
    `frequency_embedding_size`, then projected to `hidden_size` by a bias-free
    two-layer MLP (Linear -> SiLU -> Linear).

    Args:
        hidden_size (`int`): Size of the output embedding
        frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=False),
            ACT2FN['silu'],
            nn.Linear(hidden_size, hidden_size, bias=False),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element.
                These may be fractional.
            dim (`int`): The dimension of the output.
            max_period (`int`, optional): Controls the minimum frequency of the embeddings.

        Returns:
            `torch.Tensor`: An [N, dim] Tensor of positional embeddings. The first
            `dim // 2` columns hold cosines and the next `dim // 2` hold sines. A
            zero column is appended when `dim` is odd. The result has `t`'s dtype
            when `t` is floating point, and float32 otherwise.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        # Cast back only for floating-point timesteps. Casting these [-1, 1] values to an
        # integer dtype (e.g. int64 scheduler steps) would truncate them to {-1, 0, 1}.
        if torch.is_floating_point(t):
            embedding = embedding.to(t.dtype)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if not torch.is_floating_point(t):
            # Integer timesteps produce a float32 embedding. Align it with floating-point
            # projection weights (e.g. bf16 models) so the first Linear gets matching dtypes.
            weight = self.mlp[0].weight
            if torch.is_floating_point(weight):
                t_freq = t_freq.to(weight.dtype)
        return self.mlp(t_freq)
|
|
|
|
|
|
class FeedForwardNetwork(nn.Module):
    """
    Gated (SwiGLU) feed-forward network: `down_proj(silu(gate_proj(x)) * up_proj(x))`.

    All three projections are bias-free.

    Args:
        embed_dim (`int`): Width of the input and output features
        ffn_dim (`int`): Width of the gated hidden layer
    """

    def __init__(self, embed_dim, ffn_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Creation order fixes parameter registration order and RNG use at init; keep it.
        self.gate_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.act_fn = ACT2FN['silu']

    def forward(self, x):
        hidden = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(hidden)
|
|
|
|
|
|
class HeadLayer(nn.Module):
    """
    One adaLN-modulated residual feed-forward layer of the diffusion head.

    The condition `c` is mapped (SiLU -> Linear) to a shift, scale and gate.
    The RMS-normalized input is modulated with shift/scale and passed through
    the feed-forward network. The result is multiplied by the gate and added
    back to the input.

    Args:
        embed_dim (`int`): Input dimension
        ffn_dim (`int`): Hidden dimension
        cond_dim (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """

    def __init__(self, embed_dim, ffn_dim, cond_dim, norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.cond_dim = cond_dim
        self.ffn_dim = ffn_dim
        # Submodule creation order (ffn, norm, adaLN) is kept for parameter order / init RNG use.
        self.ffn = FeedForwardNetwork(embed_dim, ffn_dim)
        self.norm = RMSNorm(embed_dim, eps=norm_eps)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_dim, 3 * embed_dim, bias=False),
        )

    def forward(self, x, c):
        shift, scale, gate = self.adaLN_modulation(c).chunk(3, dim=-1)
        branch = self.ffn(modulate(self.norm(x), shift, scale))
        return x + gate * branch
|
|
|
|
|
|
class FinalLayer(nn.Module):
    """
    Output projection of the diffusion head.

    The condition `c` is mapped (SiLU -> Linear) to a shift and scale. These
    modulate the RMS-normalized input (the norm has no learnable weight), and
    the result is projected to `output_size` by a bias-free Linear.

    Args:
        hidden_size (`int`): Input dimension
        output_size (`int`): Output dimension
        cond_size (`int`): Condition embedding dimension
        norm_eps (`float`, optional): Epsilon for normalization
    """

    def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5):
        super().__init__()
        # Submodule creation order (norm, linear, adaLN) is kept for parameter order / init RNG use.
        self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, output_size, bias=False)
        self.adaLN_modulation = nn.Sequential(
            ACT2FN['silu'],
            nn.Linear(cond_size, 2 * hidden_size, bias=False),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        return self.linear(modulate(self.norm_final(x), shift, scale))
|
|
|
|
|
|
class VibeVoiceDiffusionHead(PreTrainedModel):
    """
    Diffusion head model for vibevoice.

    The forward pass proceeds as follows:
    - project the noisy latents to `config.hidden_size`;
    - build a conditioning vector as the projected `condition` plus a timestep embedding;
    - run `config.head_layers` adaLN-modulated feed-forward layers;
    - project back to `config.latent_size`.

    Args:
        config (`VibeVoiceDiffusionHeadConfig`): Model configuration. The latent
            dimensionality is read from `config.latent_size`.
    """

    config_class = VibeVoiceDiffusionHeadConfig
    supports_gradient_checkpointing = True
    # NOTE(review): this head contains no attention; these flags presumably only let an
    # `attn_implementation` choice pass through from a parent model — confirm before removing.
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.config = config
        self.cond_dim = config.hidden_size
        latent_size = config.latent_size

        self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False)
        self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False)
        self.t_embedder = TimestepEmbedder(self.cond_dim)

        ffn_dim = int(config.hidden_size * config.head_ffn_ratio)

        # Create the intermediate layers
        self.layers = nn.ModuleList([
            HeadLayer(
                embed_dim=config.hidden_size,
                ffn_dim=ffn_dim,
                cond_dim=self.cond_dim,
                norm_eps=config.rms_norm_eps
            )
            for _ in range(config.head_layers)
        ])

        # Final layer for output
        self.final_layer = FinalLayer(
            hidden_size=config.hidden_size,
            output_size=latent_size,
            cond_size=self.cond_dim,
            norm_eps=config.rms_norm_eps
        )

        # `PreTrainedModel.gradient_checkpointing_enable()` toggles modules that expose this
        # attribute (and installs `_gradient_checkpointing_func`). Without it, that call fails
        # despite `supports_gradient_checkpointing = True`.
        self.gradient_checkpointing = False

        self.initialize_weights()

    def initialize_weights(self):
        """
        Initialize the weights of the model.

        - The timestep MLP weights are drawn from N(0, 0.02).
        - Every adaLN modulation projection and the final output projection are zeroed.
        - Other Linear layers keep their PyTorch default initialization.
        """
        # Initialize timestep embedder
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        for layer in self.layers:
            nn.init.constant_(layer.adaLN_modulation[-1].weight, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)

    def forward(
        self,
        noisy_images,
        timesteps,
        condition,
    ):
        """
        Forward pass of the prediction head.

        Args:
            noisy_images (`torch.Tensor`): Noisy images/latents to denoise; last dimension
                must be `config.latent_size`
            timesteps (`torch.Tensor`): Timesteps for diffusion (1-D, one per batch element)
            condition (`torch.Tensor`): Conditioning information; last dimension must be
                `config.hidden_size`. Its projection is added to the timestep embedding, so
                the two must broadcast.

        Returns:
            `torch.Tensor`: The predicted noise/velocity, with last dimension `config.latent_size`
        """
        x = self.noisy_images_proj(noisy_images)
        t = self.t_embedder(timesteps)
        condition = self.cond_proj(condition)
        c = condition + t

        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Recompute each layer's activations in backward to save memory.
                x = self._gradient_checkpointing_func(layer.__call__, x, c)
            else:
                x = layer(x, c)

        x = self.final_layer(x, c)
        return x
|
|
|
|
|
|
# Map the diffusion-head config class to its model class in the `AutoModel` registry.
AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead)

__all__ = ["VibeVoiceDiffusionHead"]