* Add mamba cpu example * Add mamba gpu example * Use a smaller model as the example * minor fixes --------- Co-authored-by: Shengsheng Huang <shengsheng.huang@intel.com>
		
			
				
	
	
		
			926 lines
		
	
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			926 lines
		
	
	
	
		
			33 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#
 | 
						|
# Copyright 2016 The BigDL Authors.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
#
 | 
						|
# The code is adapted from: https://github.com/state-spaces/mamba.
 | 
						|
#
 | 
						|
 | 
						|
import json
 | 
						|
import math
 | 
						|
import os
 | 
						|
import time
 | 
						|
from collections import namedtuple
 | 
						|
from dataclasses import dataclass, field
 | 
						|
from functools import partial
 | 
						|
from typing import Optional
 | 
						|
 | 
						|
import torch
 | 
						|
import torch.nn as nn
 | 
						|
import torch.nn.functional as F
 | 
						|
from einops import rearrange, repeat
 | 
						|
from torch import Tensor
 | 
						|
from transformers.generation import (
 | 
						|
    GreedySearchDecoderOnlyOutput,
 | 
						|
    SampleDecoderOnlyOutput,
 | 
						|
    TextStreamer,
 | 
						|
)
 | 
						|
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
 | 
						|
from transformers.utils.hub import cached_file
 | 
						|
 | 
						|
 | 
						|
@dataclass
class MambaConfig:
    """Model-level hyper-parameters for a Mamba language model.

    NOTE(review): the model that consumes this config is not part of this
    section; the per-field notes below describe the apparent intent and should
    be confirmed against that consumer.
    """

    # Hidden width (presumably forwarded to each Mamba mixer as ``d_model``).
    d_model: int = 2560
    # Number of stacked blocks (presumably).
    n_layer: int = 64
    # Vocabulary size before any padding to ``pad_vocab_size_multiple``.
    vocab_size: int = 50277
    # Extra keyword arguments for the SSM mixer; each instance gets its own dict.
    ssm_cfg: dict = field(default_factory=dict)
    # Presumably selects RMSNorm instead of LayerNorm.
    rms_norm: bool = True
    # Presumably toggles a fused add + norm kernel.
    fused_add_norm: bool = False
    # Presumably keeps the residual stream in float32.
    residual_in_fp32: bool = True
    # Presumably the vocab is padded up to a multiple of this value.
    pad_vocab_size_multiple: int = 8
 | 
						|
 | 
						|
 | 
						|
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
 | 
						|
def _init_weights(
 | 
						|
    module,
 | 
						|
    n_layer,
 | 
						|
    initializer_range=0.02,
 | 
						|
    rescale_prenorm_residual=True,
 | 
						|
    n_residuals_per_layer=1,
 | 
						|
):
 | 
						|
    if isinstance(module, nn.Linear):
 | 
						|
        if module.bias is not None:
 | 
						|
            if not getattr(module.bias, "_no_reinit", False):
 | 
						|
                nn.init.zeros_(module.bias)
 | 
						|
    elif isinstance(module, nn.Embedding):
 | 
						|
        nn.init.normal_(module.weight, std=initializer_range)
 | 
						|
 | 
						|
    if rescale_prenorm_residual:
 | 
						|
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
 | 
						|
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
 | 
						|
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
 | 
						|
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
 | 
						|
        #
 | 
						|
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
 | 
						|
        for name, p in module.named_parameters():
 | 
						|
            if name in ["out_proj.weight", "fc2.weight"]:
 | 
						|
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
 | 
						|
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
 | 
						|
                # We need to reinit p since this code could be called multiple times
 | 
						|
                # Having just p *= scale would repeatedly scale it down
 | 
						|
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
 | 
						|
                with torch.no_grad():
 | 
						|
                    p /= math.sqrt(n_residuals_per_layer * n_layer)
 | 
						|
 | 
						|
 | 
						|
def selective_scan(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False):
    """Sequential (pure PyTorch) selective state-space scan.

    Shapes (r = real, c = complex):
        u:          r(B D L)
        delta:      r(B D L)
        A:          c(D N) or r(D N)
        B:          c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        C:          c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
        D:          r(D)
        z:          r(B D L)
        delta_bias: r(D), fp32

    All arithmetic is done in float32 (or complex). The output is cast back to
    ``u``'s original dtype.

    Returns:
        out: r(B D L)
        last_state (only when return_last_state): r(B D dstate) or c(B D dstate)
    """
    in_dtype = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)

    n_batch = u.shape[0]
    n_dim, n_state = A.shape[0], A.shape[1]
    seqlen = u.shape[2]
    B_varies = B.dim() >= 3
    C_varies = C.dim() >= 3

    if not A.is_complex():
        B = B.float()
        C = C.float()
    else:
        # Input-dependent B/C store real and imaginary parts interleaved on the last axis.
        if B_varies:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if C_varies:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))

    # Discretized transition (exp(delta * A)) and input term (delta * B * u), for all timesteps at once.
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not B_varies:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    elif B.dim() == 3:
        deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    else:
        # Grouped B: expand each of the G groups across its D // G channels.
        B = repeat(B, "B G N L -> B (G H) N L", H=n_dim // B.shape[1])
        deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if C_varies and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=n_dim // C.shape[1])

    state = A.new_zeros((n_batch, n_dim, n_state))
    outputs = []
    for t in range(seqlen):
        state = deltaA[:, :, t] * state + deltaB_u[:, :, t]
        if not C_varies:
            y_t = torch.einsum("bdn,dn->bd", state, C)
        elif C.dim() == 3:
            y_t = torch.einsum("bdn,bn->bd", state, C[:, :, t])
        else:
            y_t = torch.einsum("bdn,bdn->bd", state, C[:, :, :, t])
        if y_t.is_complex():
            y_t = y_t.real * 2
        outputs.append(y_t)
    last_state = state if seqlen > 0 else None

    y = torch.stack(outputs, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=in_dtype)
    if return_last_state:
        return out, last_state
    return out
 | 
						|
 | 
						|
 | 
						|
def layer_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
    """Reference LayerNorm over the last axis, with an optional residual add.

    When ``residual`` is given, ``x + residual`` (kept in ``x``'s dtype) is
    normalized. Normalization runs in ``weight``'s dtype. The result is cast back
    to ``x``'s original dtype.

    Returns ``out``, or ``(out, pre_norm_input)`` when ``prenorm`` is True.
    """
    orig_dtype = x.dtype
    if residual is not None:
        x = (x + residual).to(orig_dtype)
    normalized = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    )
    out = normalized.to(orig_dtype)
    if prenorm:
        return out, x
    return out
 | 
						|
 | 
						|
 | 
						|
def rms_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
    """Reference RMSNorm over the last axis, with an optional residual add.

    Computes ``x / sqrt(mean(x**2) + eps) * weight (+ bias)``. When ``residual``
    is given, ``x + residual`` (kept in ``x``'s dtype) is normalized instead. The
    result is cast back to ``x``'s original dtype.

    Returns ``out``, or ``(out, pre_norm_input)`` when ``prenorm`` is True.
    """
    orig_dtype = x.dtype
    if residual is not None:
        x = (x + residual).to(orig_dtype)
    inv_rms = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    scaled = x * inv_rms * weight
    out = scaled if bias is None else scaled + bias
    out = out.to(orig_dtype)
    return (out, x) if prenorm else out
 | 
						|
 | 
						|
 | 
						|
def load_config_hf(model_name):
    """Load ``CONFIG_NAME`` for ``model_name`` via ``cached_file`` and parse it as JSON.

    Args:
        model_name: Hub repo id or local path understood by ``cached_file``.

    Returns:
        The parsed JSON object (a dict for a regular config file).

    Raises:
        FileNotFoundError: if ``cached_file`` could not resolve the config file.
    """
    resolved_archive_file = cached_file(
        model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False
    )
    # With _raise_exceptions_for_missing_entries=False a missing entry yields
    # None instead of an exception; surface that clearly rather than via open(None).
    if resolved_archive_file is None:
        raise FileNotFoundError(f"Could not find {CONFIG_NAME} for model {model_name!r}")
    # JSON is UTF-8 by spec; use a context manager so the handle is always closed.
    with open(resolved_archive_file, encoding="utf-8") as config_file:
        return json.load(config_file)
 | 
						|
 | 
						|
 | 
						|
def load_state_dict_hf(model_name, device=None, dtype=None):
    """Load the PyTorch checkpoint ``WEIGHTS_NAME`` of ``model_name``.

    Args:
        model_name: Hub repo id or local path understood by ``cached_file``.
        device: target device for the returned tensors (None keeps where loaded).
        dtype: target dtype for floating-point tensors (None keeps stored dtype).

    When ``dtype`` is neither None nor float32, the checkpoint is first loaded
    on CPU and cast there, then moved to ``device``. This avoids materialising
    the stored-precision copy on the target device.

    Returns:
        The state dict with floating-point tensors cast to ``dtype`` and all
        tensors on ``device``.

    Raises:
        FileNotFoundError: if ``cached_file`` could not resolve the weights file.
    """
    # If not fp32, then we don't want to load directly to the target device.
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(
        model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
    )
    if resolved_archive_file is None:
        raise FileNotFoundError(f"Could not find {WEIGHTS_NAME} for model {model_name!r}")
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources; consider
    # weights_only=True where the installed torch version supports it.
    state_dict = torch.load(resolved_archive_file, map_location=mapped_device)

    def _place(value):
        # Cast floating tensors first (on CPU when dtype != fp32) to save memory,
        # then move; leave integer/bool buffers' dtype and non-tensors untouched.
        if not isinstance(value, torch.Tensor):
            return value
        if dtype is not None and value.is_floating_point():
            value = value.to(dtype=dtype)
        if device is not None:
            value = value.to(device=device)
        return value

    return {key: _place(value) for key, value in state_dict.items()}
 | 
						|
 | 
						|
 | 
						|
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    # Capacity of the cache in tokens and sequences.
    max_seqlen: int
    max_batch_size: int
    # Number of tokens already consumed; callers advance this after each model call.
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    # Per-layer cached state; each instance gets its own dict.
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Prepare for a new generation with the given capacities.

        Sets ``seqlen_offset`` back to 0 and zeroes ``lengths_per_sample`` in
        place if it exists. ``batch_size_offset`` and ``key_value_memory_dict``
        are left unchanged.
        """
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is None:
            return
        self.lengths_per_sample.zero_()
 | 
						|
 | 
						|
 | 
						|
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Nucleus (top-p) filtering: set logits outside the nucleus to -inf, in place.

    Does nothing when ``top_p <= 0`` or ``top_p >= 1``. Expects a 2-D tensor;
    the scatter is done along dim 1. Always returns None.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    ascending_logits, order = torch.sort(logits, descending=False)
    cumulative = ascending_logits.softmax(dim=-1).cumsum(dim=-1)
    # Starting from the least likely token, drop everything whose running
    # probability mass stays within 1 - top_p.
    drop_sorted = cumulative <= (1 - top_p)
    # Map the mask from sorted order back to the original vocabulary positions.
    drop = drop_sorted.scatter(1, order, drop_sorted)
    logits.masked_fill_(drop, float("-inf"))
 | 
						|
 | 
						|
 | 
						|
def modify_logit_for_repetition_penalty(
    logits, prev_output_tokens, repetition_penalty=1.0
):
    """Apply the CTRL repetition penalty (https://arxiv.org/abs/1909.05858) in place.

    logits: (batch_size, vocab_size)
    prev_output_tokens: (batch_size, seq_len), token ids already produced.

    For each previously seen token, negative logits are multiplied by the
    penalty and non-negative ones divided by it. This makes them less likely
    when ``repetition_penalty > 1``. With a penalty of exactly 1.0 the tensor is
    untouched. Returns ``logits`` (the same object) in both cases.
    """
    if repetition_penalty == 1.0:
        return logits
    seen_scores = torch.gather(logits, 1, prev_output_tokens)
    penalized = torch.where(
        seen_scores < 0,
        seen_scores * repetition_penalty,
        seen_scores / repetition_penalty,
    )
    logits.scatter_(1, prev_output_tokens, penalized)
    return logits
 | 
						|
 | 
						|
 | 
						|
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Pick one token id per row of ``logits`` (shape (batch_size, vocab_size)).

    * ``top_k == 1``: greedy argmax; ``top_p`` and ``temperature`` are ignored.
    * ``top_k > 1``: keep the ``min(top_k, vocab_size)`` largest logits, apply
      temperature, then top-p, then sample.
    * ``top_k <= 0``: apply temperature and top-p to the full vocabulary and
      sample (the caller's ``logits`` are never modified).

    Raises AssertionError when ``top_p > 1`` in the sampling paths.
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."

    if top_k <= 0:
        # Work on a copy so top-p masking never touches the caller's logits.
        candidate_logits = logits.clone() if temperature == 1.0 else logits / temperature
        modify_logits_for_top_p_filtering(candidate_logits, top_p)
        probs = torch.softmax(candidate_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(dim=-1)

    k = min(top_k, logits.size(-1))  # Safety check
    candidate_logits, candidate_ids = torch.topk(logits, k, dim=-1)
    if temperature != 1.0:
        candidate_logits /= temperature
    modify_logits_for_top_p_filtering(candidate_logits, top_p)
    rows = torch.arange(candidate_ids.shape[0], device=candidate_ids.device)
    picks = torch.multinomial(
        torch.softmax(candidate_logits, dim=-1), num_samples=1
    ).squeeze(dim=-1)
    # Translate positions within the top-k window back to vocabulary ids.
    return candidate_ids[rows, picks]
 | 
						|
 | 
						|
 | 
						|
@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_new_tokens,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    streamer: Optional[TextStreamer] = None,
):
    """Autoregressively extend ``input_ids`` with greedy or top-k/top-p sampling.

    ``top_k == 0`` means no top-k limit (pure sampling). When both top-k and
    top-p are active, top-k is applied first. All sequences in the batch are
    assumed to share the same length.

    Arguments:
        input_ids: (batch, seq_len) prompt token ids.
        model: callable as ``model(ids, position_ids=..., inference_params=...,
            num_last_tokens=1)`` returning an object with ``.logits``.
        max_new_tokens: int, generation budget. The first step always runs,
            so at least one token is produced even when this is 0.
        repetition_penalty: applied to a copy of the logits before sampling.
        eos_token_id: stop once every row's latest token equals it.
        teacher_outputs (optional): (batch, seq_len). While available, the next
            token is taken from it (indexed by the current offset) instead of
            being sampled. Useful for testing.
        vocab_size: if given, logits are truncated to the first vocab_size entries.
        streamer: receives the prompt, then each new token (on CPU), then end().

    Returns: GreedySearchDecoderOnlyOutput when top_k == 1, otherwise
    SampleDecoderOnlyOutput, with:
        sequences: (batch, prompt_len + generated_len)
        scores: tuple of raw per-step logits (batch, vocab_size), recorded
            before any repetition penalty.
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, prompt_len = input_ids.shape[0], input_ids.shape[1]
    max_length = prompt_len + max_new_tokens
    teacher_output_len = 0 if teacher_outputs is None else teacher_outputs.shape[1]
    inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def next_logits(tokens):
        # Prefill (offset 0) passes no positions; decode steps pass the current offset.
        offset = inference_params.seqlen_offset
        position_ids = None
        if offset > 0:
            position_ids = torch.full(
                (batch_size, 1),
                offset,
                dtype=torch.long,
                device=tokens.device,
            )
        step_logits = model(
            tokens,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=1,
        ).logits.squeeze(dim=1)
        if vocab_size is None:
            return step_logits
        return step_logits[..., :vocab_size]

    def pick_token(step_logits):
        # Teacher forcing while teacher tokens remain, otherwise sample.
        offset = inference_params.seqlen_offset
        if teacher_outputs is not None and offset < teacher_output_len:
            token = teacher_outputs[:, offset]
        else:
            token = sample(step_logits, top_k=top_k, top_p=top_p, temperature=temperature)
        return token.unsqueeze(1)  # (batch, 1)

    def finished(last_tokens):
        offset = inference_params.seqlen_offset
        if offset == 0:
            return False
        all_eos = eos_token_id is not None and bool((last_tokens == eos_token_id).all())
        return all_eos or offset >= max_length - 1

    scores = []
    chunks = [input_ids]
    history = input_ids  # all tokens so far; only maintained when a penalty is applied
    last_tokens = input_ids
    while not finished(last_tokens):
        step_logits = next_logits(last_tokens)
        scores.append(step_logits)
        # Advance by the number of tokens just fed before choosing the next one.
        inference_params.seqlen_offset += last_tokens.shape[1]
        if repetition_penalty != 1.0:
            penalized = modify_logit_for_repetition_penalty(
                step_logits.clone(), history, repetition_penalty
            )
            last_tokens = pick_token(penalized)
            history = torch.cat([history, last_tokens], dim=1)
        else:
            last_tokens = pick_token(step_logits)
        chunks.append(last_tokens)
        if streamer is not None:
            streamer.put(last_tokens.cpu())

    if streamer is not None:
        streamer.end()
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(chunks, dim=1), scores=tuple(scores))
 | 
						|
 | 
						|
 | 
						|
class GenerationMixin:
    """Adds a ``generate`` method, implemented on top of ``decode``, to a model class."""

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference state. The base implementation raises NotImplementedError."""
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_new_tokens,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        """Generate tokens after ``input_ids`` using ``decode`` with this model.

        Extra keyword arguments (e.g. ``repetition_penalty``, ``eos_token_id``,
        ``streamer``) are forwarded to ``decode``. ``scores`` on the output is
        cleared unless ``output_scores`` is True.

        Returns the full decode output object when ``return_dict_in_generate``
        is True, otherwise just its ``sequences`` tensor.
        """
        result = decode(
            input_ids,
            self,
            max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            **kwargs,
        )
        if not output_scores:
            result.scores = None
        if return_dict_in_generate:
            return result
        return result.sequences
 | 
						|
 | 
						|
 | 
						|
class Block(nn.Module):
    """A mixer wrapped with a norm and a residual connection.

    The textbook prenorm Transformer block is LN -> MHA/MLP -> Add
    [Ref: https://arxiv.org/abs/2002.04745]. This block instead does
    Add -> LN -> Mixer and returns both the mixer output (the new
    hidden_states) and the updated residual. That ordering lets add and
    LayerNorm be fused. Every block except the first receives a residual.
    """

    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, residual_in_fp32=False):
        """Build ``mixer_cls(dim)`` and ``norm_cls(dim)``.

        If ``residual_in_fp32`` is set, the returned residual is cast to float32.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
    ):
        r"""Run Add -> Norm -> Mixer.

        Args:
            hidden_states: output of the previous block (or the embeddings).
            residual: running residual from the previous block, or None for the
                first block.

        Returns:
            (mixer_output, new_residual) with new_residual = hidden_states + residual.
        """
        if residual is None:
            residual = hidden_states
        else:
            residual = hidden_states + residual
        # The norm runs in its own parameter dtype.
        normed = self.norm(residual.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        return self.mixer(normed, inference_params=inference_params), residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
 | 
						|
 | 
						|
 | 
						|
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable scale and no bias."""

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        """Create a ``hidden_size`` scale vector (initialized to ones) and no bias."""
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # Registered as None so ``self.bias`` exists and rms_norm skips the bias add.
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the scale to ones."""
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        """Normalize ``x`` (plus ``residual`` if given) over the last axis.

        Returns ``out``, or ``(out, pre_norm_input)`` when ``prenorm`` is True.
        With ``residual_in_fp32`` and ``prenorm``, the returned pre-norm residual
        is cast to float32. Previously this flag was accepted but silently
        ignored. The cast matches ``Block``'s ``residual_in_fp32`` handling.
        """
        result = rms_norm(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
        )
        if prenorm and residual_in_fp32:
            out, pre_norm = result
            return out, pre_norm.to(torch.float32)
        return result
 | 
						|
 | 
						|
 | 
						|
class Mamba(nn.Module):
 | 
						|
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        """Build one Mamba mixer layer.

        Args:
            d_model: input/output width.
            d_state: SSM state size N per channel.
            d_conv: depthwise causal convolution width.
            expand: inner width is ``int(expand * d_model)``.
            dt_rank: rank of the dt projection; "auto" means ``ceil(d_model / 16)``.
            dt_min, dt_max, dt_init_floor: range/floor for the initial dt, which is
                sampled log-uniformly and encoded into ``dt_proj.bias``.
            dt_init: "random" (uniform) or "constant" init of ``dt_proj.weight``;
                anything else raises NotImplementedError.
            dt_scale: multiplier on the dt weight init scale ``dt_rank ** -0.5``.
            conv_bias, bias: whether the conv / in_proj+out_proj carry biases.
            use_fast_path: stored only; the forward/step paths here do not read it.
            layer_idx: stored for cache lookup (presumably used by
                ``_get_states_from_cache``).
            device, dtype: factory arguments for the created parameters.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(expand * d_model)
        self.dt_rank = dt_rank if dt_rank != "auto" else math.ceil(d_model / 16)
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.dt_proj_in_feature = self.dt_rank
        inner, rank = self.d_inner, self.dt_rank

        # NOTE: submodules are created in a fixed order; their default inits
        # consume the RNG, so reordering would change initial weights.
        # in_proj yields both the SSM branch x and the gate z.
        self.in_proj = nn.Linear(d_model, inner * 2, bias=bias, **factory_kwargs)
        # Depthwise conv padded by d_conv - 1; forward keeps only the first
        # seqlen outputs, which makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=inner,
            out_channels=inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        self.activation = "silu"
        self.act = nn.SiLU()
        # x_proj produces [dt (rank) | B (d_state) | C (d_state)] per position.
        self.x_proj = nn.Linear(inner, rank + d_state * 2, bias=False, **factory_kwargs)
        self.dt_proj = nn.Linear(rank, inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization.
        weight_scale = rank**-0.5 * dt_scale
        if dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -weight_scale, weight_scale)
        elif dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, weight_scale)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max.
        log_dt_min, log_dt_max = math.log(dt_min), math.log(dt_max)
        dt = torch.exp(
            torch.rand(inner, **factory_kwargs) * (log_dt_max - log_dt_min) + log_dt_min
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # _init_weights zeroes Linear biases unless they carry this marker.
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization: row d of A is [1, 2, ..., d_state]; store log(A) in fp32.
        state_index = torch.arange(1, d_state + 1, dtype=torch.float32, device=device)
        A = repeat(state_index, "n -> d n", d=inner).contiguous()
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # D "skip" parameter, kept in fp32.
        self.D = nn.Parameter(torch.ones(inner, device=device))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(inner, d_model, bias=bias, **factory_kwargs)
 | 
						|
 | 
						|
    def forward(self, hidden_states, inference_params=None):
        """Apply the Mamba mixer.

        hidden_states: (B, L, D)
        Returns: same shape as hidden_states

        With ``inference_params``, the layer's conv/SSM states come from
        ``self._get_states_from_cache``. Once ``seqlen_offset > 0`` a single
        ``step`` is taken, updating those states in place. Otherwise (prefill)
        the full sequence is scanned and the last conv window and final SSM
        state are written into the cache.
        """
        batch, seqlen, _ = hidden_states.shape

        conv_state = ssm_state = None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                step_out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return step_out

        # Project and move to (B, 2*d_inner, L) with one matmul plus transposes.
        flat_tokens = rearrange(hidden_states, "b l d -> d (b l)").t()
        xz = rearrange(self.in_proj(flat_tokens).t(), "d (b l) -> b d l", l=seqlen)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        x, z = xz.chunk(2, dim=1)

        if conv_state is not None:
            # F.pad left-pads with zeros when seqlen < d_conv and truncates
            # otherwise, so the cache holds exactly the last d_conv inputs (B D W).
            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
        # Drop the trailing padded outputs to keep the convolution causal.
        x = self.act(self.conv1d(x)[..., :seqlen])

        # dt/B/C are laid out with L as the fastest-moving dimension for the scan.
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(
            x_dbl, [self.dt_proj_in_feature, self.d_state, self.d_state], dim=-1
        )
        dt = rearrange(self.dt_proj(dt).t(), "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()

        assert self.activation in ["silu", "swish"]
        keep_state = ssm_state is not None
        scan_result = selective_scan(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=None,
            delta_softplus=True,
            return_last_state=keep_state,
        )
        if keep_state:
            y, last_state = scan_result
            ssm_state.copy_(last_state)
        else:
            y = scan_result
        return self.out_proj(rearrange(y, "b d l -> b l d"))
 | 
						|
 | 
						|
    def step(self, hidden_states, conv_state, ssm_state):
        """Advance the layer by exactly one token, updating the caches in place.

        hidden_states: (B, 1, D). conv_state holds the last d_conv inner inputs
        per channel. ssm_state is the per-channel SSM state (B, d_inner, d_state)
        as indexed below.

        Returns (out of shape (B, 1, D), conv_state, ssm_state).
        """
        dtype = hidden_states.dtype
        assert (
            hidden_states.shape[1] == 1
        ), "Only support decoding with 1 token at a time for now"
        x, z = self.in_proj(hidden_states.squeeze(1)).chunk(2, dim=-1)  # each (B D)

        # Conv step: slide the window left by one slot and append the new input.
        conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # (B D W)
        conv_state[:, :, -1] = x
        kernel = rearrange(self.conv1d.weight, "d 1 w -> d w")
        x = torch.sum(conv_state * kernel, dim=-1)  # (B D)
        if self.conv1d.bias is not None:
            x = x + self.conv1d.bias
        x = self.act(x).to(dtype=dtype)

        dt, B, C = torch.split(
            self.x_proj(x),  # (B dt_rank+2*d_state)
            [self.dt_proj_in_feature, self.d_state, self.d_state],
            dim=-1,
        )
        dt = F.softplus(self.dt_proj(dt))
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step: discretize A and B, then h <- dA * h + x * dB.
        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
        dB = torch.einsum("bd,bn->bdn", dt, B)
        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
        y = (y + self.D.to(dtype) * x) * self.act(z)  # (B D)

        return self.out_proj(y).unsqueeze(1), conv_state, ssm_state
 | 
						|
 | 
						|
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Create zero-filled decoding states for this layer.

        Returns ``(conv_state, ssm_state)`` shaped
        ``(batch_size, d_model * expand, d_conv)`` and
        ``(batch_size, d_model * expand, d_state)``, both placed on the device
        of ``out_proj.weight``. An explicit ``dtype`` applies to both tensors;
        otherwise they take the dtype of ``conv1d.weight`` and
        ``dt_proj.weight`` respectively. ``max_seqlen`` and ``kwargs`` are
        accepted for interface compatibility and are not used here.
        """
        target_device = self.out_proj.weight.device
        inner_dim = self.d_model * self.expand
        if dtype is None:
            conv_dtype = self.conv1d.weight.dtype
            ssm_dtype = self.dt_proj.weight.dtype
        else:
            conv_dtype = ssm_dtype = dtype
        conv_state = torch.zeros(
            batch_size, inner_dim, self.d_conv, device=target_device, dtype=conv_dtype
        )
        ssm_state = torch.zeros(
            batch_size, inner_dim, self.d_state, device=target_device, dtype=ssm_dtype
        )
        return conv_state, ssm_state
 | 
						|
 | 
						|
    def _get_states_from_cache(
 | 
						|
        self, inference_params, batch_size, initialize_states=False
 | 
						|
    ):
 | 
						|
        assert self.layer_idx is not None
 | 
						|
        if self.layer_idx not in inference_params.key_value_memory_dict:
 | 
						|
            batch_shape = (batch_size,)
 | 
						|
            conv_state = torch.zeros(
 | 
						|
                batch_size,
 | 
						|
                self.d_model * self.expand,
 | 
						|
                self.d_conv,
 | 
						|
                device=self.conv1d.weight.device,
 | 
						|
                dtype=self.conv1d.weight.dtype,
 | 
						|
            )
 | 
						|
            ssm_state = torch.zeros(
 | 
						|
                batch_size,
 | 
						|
                self.d_model * self.expand,
 | 
						|
                self.d_state,
 | 
						|
                device=self.dt_proj.weight.device,
 | 
						|
                dtype=self.dt_proj.weight.dtype,
 | 
						|
                # dtype=torch.float32,
 | 
						|
            )
 | 
						|
            inference_params.key_value_memory_dict[self.layer_idx] = (
 | 
						|
                conv_state,
 | 
						|
                ssm_state,
 | 
						|
            )
 | 
						|
        else:
 | 
						|
            conv_state, ssm_state = inference_params.key_value_memory_dict[
 | 
						|
                self.layer_idx
 | 
						|
            ]
 | 
						|
            # TODO: What if batch size changes between generation, and we reuse the same states?
 | 
						|
            if initialize_states:
 | 
						|
                conv_state.zero_()
 | 
						|
                ssm_state.zero_()
 | 
						|
        return conv_state, ssm_state
 | 
						|
 | 
						|
 | 
						|
def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """Build one residual ``Block`` around a ``Mamba`` mixer.

    ``ssm_cfg`` (empty when None) is forwarded as keyword arguments to
    ``Mamba`` along with ``layer_idx``. The norm is ``RMSNorm`` when
    ``rms_norm`` is true, otherwise ``nn.LayerNorm``, built with
    ``eps=norm_epsilon``. ``device``/``dtype`` are passed to both the mixer
    and the norm. The returned block also gets a ``layer_idx`` attribute.
    """
    tensor_kwargs = dict(device=device, dtype=dtype)
    mixer_kwargs = {} if ssm_cfg is None else ssm_cfg
    norm_type = RMSNorm if rms_norm else nn.LayerNorm
    block = Block(
        d_model,
        partial(Mamba, layer_idx=layer_idx, **mixer_kwargs, **tensor_kwargs),
        norm_cls=partial(norm_type, eps=norm_epsilon, **tensor_kwargs),
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
 | 
						|
 | 
						|
 | 
						|
class MixerModel(nn.Module):
    """Token embedding, a stack of Mamba ``Block`` layers, and a final norm.

    Maps token ids to normalized hidden states; it has no vocabulary
    projection of its own.
    """

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        tensor_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        # Submodules are created and registered in this order (embedding,
        # layers, norm_f); the order fixes parameter names and init order.
        self.embedding = nn.Embedding(vocab_size, d_model, **tensor_kwargs)
        blocks = [
            create_block(
                d_model,
                ssm_cfg=ssm_cfg,
                norm_epsilon=norm_epsilon,
                rms_norm=rms_norm,
                residual_in_fp32=residual_in_fp32,
                layer_idx=idx,
                **tensor_kwargs,
            )
            for idx in range(n_layer)
        ]
        self.layers = nn.ModuleList(blocks)
        final_norm_cls = RMSNorm if rms_norm else nn.LayerNorm
        self.norm_f = final_norm_cls(d_model, eps=norm_epsilon, **tensor_kwargs)

        init_kwargs = {} if initializer_cfg is None else initializer_cfg
        self.apply(partial(_init_weights, n_layer=n_layer, **init_kwargs))

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate every layer's decoding cache, keyed by layer index."""
        caches = {}
        for idx, block in enumerate(self.layers):
            caches[idx] = block.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
        return caches

    def forward(self, input_ids, inference_params=None):
        """Embed ``input_ids``, run all layers, and apply the final norm.

        Each layer returns ``(hidden_states, residual)``; the last pair is
        summed (or ``hidden_states`` alone when there is no residual yet) and
        cast to ``norm_f``'s weight dtype before normalization.
        """
        hidden_states = self.embedding(input_ids)
        residual = None
        for block in self.layers:
            hidden_states, residual = block(
                hidden_states, residual, inference_params=inference_params
            )
        if residual is not None:
            residual = hidden_states + residual
        else:
            residual = hidden_states
        return self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
 | 
						|
 | 
						|
 | 
						|
class MambaLMHeadModel(nn.Module, GenerationMixin):
    """Causal language model: a ``MixerModel`` backbone plus an LM head.

    The vocabulary size is padded up to a multiple of
    ``config.pad_vocab_size_multiple`` before the embedding and head are
    built, and the head's weight is tied to the backbone embedding
    (see ``tie_weights``).
    """

    # Output container returned by ``forward``. Built once at class level
    # instead of on every call, so all outputs share a single type.
    _CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device='cpu',
        dtype=torch.float32,
    ) -> None:
        # Initialize nn.Module before assigning any attribute on it.
        super().__init__()
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        # Round the vocabulary up to the next multiple of pad_vocab_size_multiple.
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Share one weight tensor between the LM head and the token embedding."""
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate decoding-cache allocation to the backbone."""
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens

        Returns a namedtuple named ``CausalLMOutput`` with a single ``logits`` field.
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        return self._CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device='cpu', dtype=torch.float32, **kwargs):
        """Build a model from a pretrained config and load its weights.

        Extra ``kwargs`` are forwarded to the constructor.
        """
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.

        Writes ``pytorch_model.bin`` (the state dict) and ``config.json``
        (the config object's ``__dict__``) into ``save_directory``.
        """
        # Create the directory if needed; exist_ok avoids the race between a
        # separate existence check and the creation call.
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device
 |