Add cpu and gpu examples of Mamba (#9797)

* Add mamba cpu example
* Add mamba gpu example
* Use a smaller model as the example
* minor fixes

Co-authored-by: Shengsheng Huang <shengsheng.huang@intel.com>

parent 937e1f7c74
commit 2347f611cf

8 changed files with 2121 additions and 0 deletions
@@ -182,6 +182,7 @@ Over 40 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| Distil-Whisper | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/distil-whisper) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/distil-whisper) |
| Yi | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/yi) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/yi) |
| BlueLM | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/bluelm) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/bluelm) |
| Mamba | [link](python/llm/example/CPU/PyTorch-Models/Model/mamba) | [link](python/llm/example/GPU/PyTorch-Models/Model/mamba) |
| SOLAR | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/solar) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/solar) |
| Phixtral | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/phixtral) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/phixtral) |
| InternLM2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/internlm2) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/internlm2) |

@@ -74,6 +74,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| Distil-Whisper | [link](example/CPU/HF-Transformers-AutoModels/Model/distil-whisper) | [link](example/GPU/HF-Transformers-AutoModels/Model/distil-whisper) |
| Yi | [link](example/CPU/HF-Transformers-AutoModels/Model/yi) | [link](example/GPU/HF-Transformers-AutoModels/Model/yi) |
| BlueLM | [link](example/CPU/HF-Transformers-AutoModels/Model/bluelm) | [link](example/GPU/HF-Transformers-AutoModels/Model/bluelm) |
| Mamba | [link](example/CPU/PyTorch-Models/Model/mamba) | [link](example/GPU/PyTorch-Models/Model/mamba) |
| SOLAR | [link](example/CPU/HF-Transformers-AutoModels/Model/solar) | [link](example/GPU/HF-Transformers-AutoModels/Model/solar) |
| Phixtral | [link](example/CPU/HF-Transformers-AutoModels/Model/phixtral) | [link](example/GPU/HF-Transformers-AutoModels/Model/phixtral) |
| InternLM2 | [link](example/CPU/HF-Transformers-AutoModels/Model/internlm2) | [link](example/GPU/HF-Transformers-AutoModels/Model/internlm2) |

@@ -86,6 +87,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| Phi-2 | [link](example/CPU/HF-Transformers-AutoModels/Model/phi-2) | [link](example/GPU/HF-Transformers-AutoModels/Model/phi-2) |
| Yuan2 | [link](example/CPU/HF-Transformers-AutoModels/Model/yuan2) | [link](example/GPU/HF-Transformers-AutoModels/Model/yuan2) |
| DeciLM-7B | [link](example/CPU/HF-Transformers-AutoModels/Model/deciLM-7b) | [link](example/GPU/HF-Transformers-AutoModels/Model/deciLM-7b) |

### Working with `bigdl-llm`

<details><summary>Table of Contents</summary>
71 python/llm/example/CPU/PyTorch-Models/Model/mamba/README.md (Normal file)
@@ -0,0 +1,71 @@
# Mamba

In this directory, you will find examples of how you can use the BigDL-LLM `optimize_model` API to accelerate Mamba models. For illustration purposes, we use [state-spaces/mamba-1.4b](https://huggingface.co/state-spaces/mamba-1.4b) and [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) as reference Mamba models.

## Requirements
To run these examples with BigDL-LLM, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.

## Example: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case in which a Mamba model predicts the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations.
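For orientation, the whole script boils down to the following pattern (a condensed sketch of [generate.py](./generate.py), using the example's default model and tokenizer ids); installation steps follow below.

```python
import torch
from bigdl.llm import optimize_model
from transformers import AutoTokenizer

from model import MambaLMHeadModel  # the model.py shipped in this directory

# Load the FP32 Mamba model, then enable BigDL-LLM INT4 optimization;
# the listed SSM projection modules are excluded from low-bit conversion in this example
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b")
model = optimize_model(model, low_bit='asym_int4',
                       modules_to_not_convert=["dt_proj", "x_proj", "out_proj"])

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0]))
```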
### 1. Install
We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).

After installing conda, create a Python environment for BigDL-LLM:
```bash
conda create -n llm python=3.9 # recommend to use Python 3.9
conda activate llm

pip install --pre --upgrade bigdl-llm[all] # install the latest bigdl-llm nightly build with 'all' option
pip install einops # package required by Mamba
```
### 2. Run
After setting up the Python environment, you can run the example with the following steps.

#### 2.1 Client
On client Windows machines, it is recommended to run directly with full utilization of all cores:
```powershell
python ./generate.py
```

More information about arguments can be found in the [Arguments Info](#23-arguments-info) section. The expected output can be found in the [Sample Output](#24-sample-output) section.

#### 2.2 Server
For optimal performance on a server, it is recommended to set several environment variables (refer to [here](../README.md#best-known-configuration-on-linux) for more information), and run the example with all the physical cores of a single socket.

E.g. on Linux,
```bash
# set BigDL-LLM env variables
source bigdl-llm-init

# e.g. for a server with 48 cores per socket
export OMP_NUM_THREADS=48
numactl -C 0-47 -m 0 python ./generate.py
```

More information about arguments can be found in the [Arguments Info](#23-arguments-info) section. The expected output can be found in the [Sample Output](#24-sample-output) section.
#### 2.3 Arguments Info
In the example, several arguments can be passed to satisfy your requirements (a combined usage example follows the list):

- `--repo-id-or-model-path`: str, the Hugging Face repo id of the Mamba model to be downloaded (e.g. `state-spaces/mamba-1.4b` or `state-spaces/mamba-2.8b`), or the path to the Hugging Face checkpoint folder. Defaults to `state-spaces/mamba-1.4b`.
- `--tokenizer-repo-id-or-path`: str, the Hugging Face repo id of the tokenizer used by the Mamba model, or the path to the Hugging Face checkpoint folder. Defaults to `EleutherAI/gpt-neox-20b`.
- `--prompt`: str, the prompt to infer on (with integrated prompt format for chat). Defaults to `'What is AI?'`.
- `--n-predict`: int, the maximum number of tokens to predict. Defaults to `32`.
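For example, to run the larger reference model with a longer completion (the values here are illustrative):

```bash
python ./generate.py --repo-id-or-model-path state-spaces/mamba-2.8b --n-predict 64
```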
#### 2.4 Sample Output
#### [state-spaces/mamba-1.4b](https://huggingface.co/state-spaces/mamba-1.4b)
```log
Inference time: xxxx s
-------------------- Output --------------------
What is AI?

Artificial Intelligence is a field of computer science that deals with the creation of machines that can learn and think like humans. It is a field that has
```

#### [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b)
```log
Inference time: xxxx s
-------------------- Output --------------------
What is AI?

Artificial Intelligence is a field of computer science that focuses on developing intelligent machines. It is a field that is concerned with the creation of machines that can
```
61 python/llm/example/CPU/PyTorch-Models/Model/mamba/generate.py
@@ -0,0 +1,61 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import time
import torch
from bigdl.llm import optimize_model
from transformers import AutoTokenizer

from model import MambaLMHeadModel

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Mamba model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="state-spaces/mamba-1.4b",
                        help='The huggingface repo id for the Mamba model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--tokenizer-repo-id-or-path', type=str, default="EleutherAI/gpt-neox-20b",
                        help='The huggingface repo id for the Mamba tokenizer to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    tokenizer_path = args.tokenizer_repo_id_or_path

    # Load the FP32 Mamba model from the local model.py implementation
    model = MambaLMHeadModel.from_pretrained(model_path)

    # With only one line to enable BigDL-LLM INT4 optimization on the model;
    # the SSM projection modules listed below are kept unconverted
    model = optimize_model(model, low_bit='asym_int4', modules_to_not_convert=["dt_proj", "x_proj", "out_proj"])

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Generate predicted tokens
    with torch.inference_mode():
        input_ids = tokenizer.encode(args.prompt, return_tensors="pt")
        st = time.time()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        end = time.time()
        output_str = tokenizer.decode(output[0])
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Output', '-'*20)
        print(output_str)
926 python/llm/example/CPU/PyTorch-Models/Model/mamba/model.py (Normal file)
@@ -0,0 +1,926 @@
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# The code is adapted from: https://github.com/state-spaces/mamba.
|
||||
#
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from transformers.generation import (
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
SampleDecoderOnlyOutput,
|
||||
TextStreamer,
|
||||
)
|
||||
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils.hub import cached_file
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaConfig:
|
||||
d_model: int = 2560
|
||||
n_layer: int = 64
|
||||
vocab_size: int = 50277
|
||||
ssm_cfg: dict = field(default_factory=dict)
|
||||
rms_norm: bool = True
|
||||
fused_add_norm: bool = False
|
||||
residual_in_fp32: bool = True
|
||||
pad_vocab_size_multiple: int = 8
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
||||
def _init_weights(
|
||||
module,
|
||||
n_layer,
|
||||
initializer_range=0.02,
|
||||
rescale_prenorm_residual=True,
|
||||
n_residuals_per_layer=1,
|
||||
):
|
||||
if isinstance(module, nn.Linear):
|
||||
if module.bias is not None:
|
||||
if not getattr(module.bias, "_no_reinit", False):
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
|
||||
if rescale_prenorm_residual:
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
for name, p in module.named_parameters():
|
||||
if name in ["out_proj.weight", "fc2.weight"]:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
||||
# We need to reinit p since this code could be called multiple times
|
||||
# Having just p *= scale would repeatedly scale it down
|
||||
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
||||
with torch.no_grad():
|
||||
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
||||
|
||||
|
||||
def selective_scan(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False):
|
||||
"""
|
||||
u: r(B D L)
|
||||
delta: r(B D L)
|
||||
A: c(D N) or r(D N)
|
||||
B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
|
||||
D: r(D)
|
||||
z: r(B D L)
|
||||
delta_bias: r(D), fp32
|
||||
|
||||
out: r(B D L)
|
||||
last_state (optional): r(B D dstate) or c(B D dstate)
|
||||
"""
|
||||
dtype_in = u.dtype
|
||||
u = u.float()
|
||||
delta = delta.float()
|
||||
if delta_bias is not None:
|
||||
delta = delta + delta_bias[..., None].float()
|
||||
if delta_softplus:
|
||||
delta = F.softplus(delta)
|
||||
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
|
||||
is_variable_B = B.dim() >= 3
|
||||
is_variable_C = C.dim() >= 3
|
||||
if A.is_complex():
|
||||
if is_variable_B:
|
||||
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
|
||||
if is_variable_C:
|
||||
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
|
||||
else:
|
||||
B = B.float()
|
||||
C = C.float()
|
||||
x = A.new_zeros((batch, dim, dstate))
|
||||
ys = []
|
||||
deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
|
||||
if not is_variable_B:
|
||||
deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
|
||||
else:
|
||||
if B.dim() == 3:
|
||||
deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
|
||||
else:
|
||||
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
|
||||
deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
|
||||
if is_variable_C and C.dim() == 4:
|
||||
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
|
||||
last_state = None
|
||||
for i in range(u.shape[2]):
|
||||
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
|
||||
if not is_variable_C:
|
||||
y = torch.einsum("bdn,dn->bd", x, C)
|
||||
else:
|
||||
if C.dim() == 3:
|
||||
y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
|
||||
else:
|
||||
y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
|
||||
if i == u.shape[2] - 1:
|
||||
last_state = x
|
||||
if y.is_complex():
|
||||
y = y.real * 2
|
||||
ys.append(y)
|
||||
y = torch.stack(ys, dim=2) # (batch dim L)
|
||||
out = y if D is None else y + u * rearrange(D, "d -> d 1")
|
||||
if z is not None:
|
||||
out = out * F.silu(z)
|
||||
out = out.to(dtype=dtype_in)
|
||||
return out if not return_last_state else (out, last_state)
|
||||
|
||||
|
||||
def layer_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
|
||||
dtype = x.dtype
|
||||
if residual is not None:
|
||||
x = (x + residual).to(x.dtype)
|
||||
out = F.layer_norm(
|
||||
x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
|
||||
).to(dtype)
|
||||
return out if not prenorm else (out, x)
|
||||
|
||||
|
||||
def rms_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
|
||||
dtype = x.dtype
|
||||
if residual is not None:
|
||||
x = (x + residual).to(x.dtype)
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
||||
out = out.to(dtype)
|
||||
return out if not prenorm else (out, x)
|
||||
|
||||
|
||||
def load_config_hf(model_name):
|
||||
resolved_archive_file = cached_file(
|
||||
model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False
|
||||
)
|
||||
return json.load(open(resolved_archive_file))
|
||||
|
||||
|
||||
def load_state_dict_hf(model_name, device=None, dtype=None):
|
||||
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
|
||||
resolved_archive_file = cached_file(
|
||||
model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
|
||||
)
|
||||
return torch.load(resolved_archive_file, map_location=mapped_device)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceParams:
|
||||
"""Inference parameters that are passed to the main model in order
|
||||
to efficiently calculate and store the context during inference."""
|
||||
|
||||
max_seqlen: int
|
||||
max_batch_size: int
|
||||
seqlen_offset: int = 0
|
||||
batch_size_offset: int = 0
|
||||
key_value_memory_dict: dict = field(default_factory=dict)
|
||||
lengths_per_sample: Optional[Tensor] = None
|
||||
|
||||
def reset(self, max_seqlen, max_batch_size):
|
||||
self.max_seqlen = max_seqlen
|
||||
self.max_batch_size = max_batch_size
|
||||
self.seqlen_offset = 0
|
||||
if self.lengths_per_sample is not None:
|
||||
self.lengths_per_sample.zero_()
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
||||
def modify_logits_for_top_p_filtering(logits, top_p):
|
||||
"""Set the logits for none top-p values to -inf. Done in-place."""
|
||||
if top_p <= 0.0 or top_p >= 1.0:
|
||||
return
|
||||
# First sort and calculate cumulative sum of probabilities.
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||
# Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept)
|
||||
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
logits.masked_fill_(indices_to_remove, float("-inf"))
|
||||
|
||||
|
||||
def modify_logit_for_repetition_penalty(
|
||||
logits, prev_output_tokens, repetition_penalty=1.0
|
||||
):
|
||||
"""Apply repetition penalty. See https://arxiv.org/abs/1909.05858
|
||||
logits: (batch_size, vocab_size)
|
||||
prev_output_tokens: (batch_size, seq_len)
|
||||
"""
|
||||
if repetition_penalty == 1.0:
|
||||
return logits
|
||||
score = torch.gather(logits, 1, prev_output_tokens)
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
)
|
||||
logits.scatter_(1, prev_output_tokens, score)
|
||||
return logits
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
|
||||
"""Sample from top-k logits.
|
||||
Arguments:
|
||||
logits: Tensor of shape (batch_size, vocab_size)
|
||||
"""
|
||||
if top_k == 1: # Short-circuit for greedy decoding
|
||||
return logits.argmax(dim=-1)
|
||||
else:
|
||||
if top_p > 0.0:
|
||||
assert top_p <= 1.0, "top-p should be in (0, 1]."
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
||||
if temperature != 1.0:
|
||||
logits_top /= temperature
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return indices[
|
||||
torch.arange(indices.shape[0], device=indices.device),
|
||||
torch.multinomial(
|
||||
torch.softmax(logits_top, dim=-1), num_samples=1
|
||||
).squeeze(dim=-1),
|
||||
]
|
||||
else:
|
||||
# Clone so that when we modify for top_p we don't change the original logits
|
||||
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return torch.multinomial(
|
||||
torch.softmax(logits_top, dim=-1), num_samples=1
|
||||
).squeeze(dim=-1)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(
|
||||
input_ids,
|
||||
model,
|
||||
max_new_tokens,
|
||||
top_k=1,
|
||||
top_p=0.0,
|
||||
temperature=1.0,
|
||||
repetition_penalty=1.0,
|
||||
eos_token_id=None,
|
||||
teacher_outputs=None,
|
||||
vocab_size=None,
|
||||
streamer: Optional[TextStreamer] = None,
|
||||
):
|
||||
"""Decoding, either greedy or with top-k or top-p sampling.
|
||||
If top-k = 0, don't limit the number of candidates (pure sampling).
|
||||
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
|
||||
then top-p.
|
||||
We assume that all sequences in the same batch have the same length.
|
||||
|
||||
Arguments:
|
||||
input_ids: (batch, seq_len)
|
||||
max_new_tokens: int
|
||||
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
|
||||
logits, the next token is taken from the teacher_outputs. Useful for testing.
|
||||
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
|
||||
sequences: (batch, max_length)
|
||||
scores: tuples of (batch, vocab_size)
|
||||
"""
|
||||
if streamer is not None:
|
||||
streamer.put(input_ids.cpu())
|
||||
|
||||
max_length = input_ids.shape[1] + max_new_tokens
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
|
||||
inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
|
||||
|
||||
def get_logits(input_ids, inference_params):
|
||||
decoding = inference_params.seqlen_offset > 0
|
||||
if decoding:
|
||||
position_ids = torch.full(
|
||||
(batch_size, 1),
|
||||
inference_params.seqlen_offset,
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
else:
|
||||
position_ids = None
|
||||
logits = model(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
num_last_tokens=1,
|
||||
).logits.squeeze(dim=1)
|
||||
return logits[..., :vocab_size] if vocab_size is not None else logits
|
||||
|
||||
def sample_tokens(logits, inference_params):
|
||||
if (
|
||||
teacher_outputs is None
|
||||
or teacher_output_len <= inference_params.seqlen_offset
|
||||
):
|
||||
token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
else:
|
||||
token = teacher_outputs[:, inference_params.seqlen_offset]
|
||||
# return rearrange(token, "b -> b 1")
|
||||
return token.unsqueeze(1)
|
||||
|
||||
def should_stop(current_token, inference_params):
|
||||
if inference_params.seqlen_offset == 0:
|
||||
return False
|
||||
if eos_token_id is not None and (current_token == eos_token_id).all():
|
||||
return True
|
||||
if inference_params.seqlen_offset >= max_length - 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
scores, sequences = [], [input_ids]
|
||||
sequences_cat = input_ids
|
||||
while not should_stop(sequences[-1], inference_params):
|
||||
scores.append(get_logits(sequences[-1], inference_params))
|
||||
inference_params.seqlen_offset += sequences[-1].shape[1]
|
||||
if repetition_penalty == 1.0:
|
||||
sampled_tokens = sample_tokens(scores[-1], inference_params)
|
||||
else:
|
||||
logits = modify_logit_for_repetition_penalty(
|
||||
scores[-1].clone(), sequences_cat, repetition_penalty
|
||||
)
|
||||
sampled_tokens = sample_tokens(logits, inference_params)
|
||||
sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
|
||||
sequences.append(sampled_tokens)
|
||||
if streamer is not None:
|
||||
streamer.put(sampled_tokens.cpu())
|
||||
if streamer is not None:
|
||||
streamer.end()
|
||||
output_cls = (
|
||||
GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
|
||||
)
|
||||
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
|
||||
|
||||
|
||||
class GenerationMixin:
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids,
|
||||
max_new_tokens,
|
||||
top_k=1,
|
||||
top_p=0.0,
|
||||
temperature=1.0,
|
||||
return_dict_in_generate=False,
|
||||
output_scores=False,
|
||||
**kwargs,
|
||||
):
|
||||
output = decode(
|
||||
input_ids,
|
||||
self,
|
||||
max_new_tokens,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
if not output_scores:
|
||||
output.scores = None
|
||||
return output if return_dict_in_generate else output.sequences
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, residual_in_fp32=False):
|
||||
"""
|
||||
Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
|
||||
|
||||
This Block has a slightly different structure compared to a regular
|
||||
prenorm Transformer block.
|
||||
The standard block is: LN -> MHA/MLP -> Add.
|
||||
[Ref: https://arxiv.org/abs/2002.04745]
|
||||
Here we have: Add -> LN -> Mixer, returning both
|
||||
the hidden_states (output of the mixer) and the residual.
|
||||
This is purely for performance reasons, as we can fuse add and LayerNorm.
|
||||
The residual needs to be provided (except for the very first block).
|
||||
"""
|
||||
super().__init__()
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.mixer = mixer_cls(dim)
|
||||
self.norm = norm_cls(dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Tensor,
|
||||
residual: Optional[Tensor] = None,
|
||||
inference_params=None,
|
||||
):
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
hidden_states: the sequence to the encoder layer (required).
|
||||
residual: hidden_states = Mixer(LN(residual))
|
||||
"""
|
||||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
|
||||
return hidden_states, residual
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.mixer.allocate_inference_cache(
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
|
||||
return rms_norm(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
residual=residual,
|
||||
eps=self.eps,
|
||||
prenorm=prenorm,
|
||||
)
|
||||
|
||||
|
||||
class Mamba(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_state=16,
|
||||
d_conv=4,
|
||||
expand=2,
|
||||
dt_rank="auto",
|
||||
dt_min=0.001,
|
||||
dt_max=0.1,
|
||||
dt_init="random",
|
||||
dt_scale=1.0,
|
||||
dt_init_floor=1e-4,
|
||||
conv_bias=True,
|
||||
bias=False,
|
||||
use_fast_path=True, # Fused kernel options
|
||||
layer_idx=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.d_state = d_state
|
||||
self.d_conv = d_conv
|
||||
self.expand = expand
|
||||
self.d_inner = int(self.expand * self.d_model)
|
||||
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
|
||||
self.use_fast_path = use_fast_path
|
||||
self.layer_idx = layer_idx
|
||||
self.dt_proj_in_feature = self.dt_rank
|
||||
|
||||
self.in_proj = nn.Linear(
|
||||
self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
self.conv1d = nn.Conv1d(
|
||||
in_channels=self.d_inner,
|
||||
out_channels=self.d_inner,
|
||||
bias=conv_bias,
|
||||
kernel_size=d_conv,
|
||||
groups=self.d_inner,
|
||||
padding=d_conv - 1,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
self.activation = "silu"
|
||||
self.act = nn.SiLU()
|
||||
|
||||
self.x_proj = nn.Linear(
|
||||
self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
|
||||
)
|
||||
|
||||
self.dt_proj = nn.Linear(
|
||||
self.dt_rank, self.d_inner, bias=True, **factory_kwargs
|
||||
)
|
||||
|
||||
# Initialize special dt projection to preserve variance at initialization
|
||||
dt_init_std = self.dt_rank**-0.5 * dt_scale
|
||||
if dt_init == "constant":
|
||||
nn.init.constant_(self.dt_proj.weight, dt_init_std)
|
||||
elif dt_init == "random":
|
||||
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
||||
dt = torch.exp(
|
||||
torch.rand(self.d_inner, **factory_kwargs)
|
||||
* (math.log(dt_max) - math.log(dt_min))
|
||||
+ math.log(dt_min)
|
||||
).clamp(min=dt_init_floor)
|
||||
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
||||
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
||||
with torch.no_grad():
|
||||
self.dt_proj.bias.copy_(inv_dt)
|
||||
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
||||
self.dt_proj.bias._no_reinit = True
|
||||
|
||||
# S4D real initialization
|
||||
A = repeat(
|
||||
torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
|
||||
"n -> d n",
|
||||
d=self.d_inner,
|
||||
).contiguous()
|
||||
A_log = torch.log(A) # Keep A_log in fp32
|
||||
self.A_log = nn.Parameter(A_log)
|
||||
self.A_log._no_weight_decay = True
|
||||
|
||||
# D "skip" parameter
|
||||
self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32
|
||||
self.D._no_weight_decay = True
|
||||
|
||||
self.out_proj = nn.Linear(
|
||||
self.d_inner, self.d_model, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, inference_params=None):
|
||||
"""
|
||||
hidden_states: (B, L, D)
|
||||
Returns: same shape as hidden_states
|
||||
"""
|
||||
batch, seqlen, _ = hidden_states.shape
|
||||
|
||||
conv_state, ssm_state = None, None
|
||||
if inference_params is not None:
|
||||
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
|
||||
if inference_params.seqlen_offset > 0:
|
||||
# The states are updated inplace
|
||||
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
|
||||
return out
|
||||
|
||||
# We do matmul and transpose BLH -> HBL at the same time
|
||||
xz = rearrange(
|
||||
self.in_proj(rearrange(hidden_states, "b l d -> d (b l)").t()).t(),
|
||||
"d (b l) -> b d l",
|
||||
l=seqlen,
|
||||
)
|
||||
|
||||
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
||||
# In the backward pass we write dx and dz next to each other to avoid torch.cat
|
||||
x, z = xz.chunk(2, dim=1)
|
||||
# Compute short convolution
|
||||
if conv_state is not None:
|
||||
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
|
||||
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
|
||||
conv_state.copy_(
|
||||
F.pad(x, (self.d_conv - x.shape[-1], 0))
|
||||
) # Update state (B D W)
|
||||
# if causal_conv1d_fn is None:
|
||||
x = self.act(self.conv1d(x)[..., :seqlen])
|
||||
|
||||
# We're careful here about the layout, to avoid extra transposes.
|
||||
# We want dt to have d as the slowest moving dimension
|
||||
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
|
||||
x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
|
||||
dt, B, C = torch.split(
|
||||
x_dbl, [self.dt_proj_in_feature, self.d_state, self.d_state], dim=-1
|
||||
)
|
||||
|
||||
dt = self.dt_proj(dt).t()
|
||||
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
|
||||
B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||
C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
|
||||
assert self.activation in ["silu", "swish"]
|
||||
y = selective_scan(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
self.D.float(),
|
||||
z=z,
|
||||
delta_bias=None,
|
||||
delta_softplus=True,
|
||||
return_last_state=ssm_state is not None,
|
||||
)
|
||||
if ssm_state is not None:
|
||||
y, last_state = y
|
||||
ssm_state.copy_(last_state)
|
||||
y = rearrange(y, "b d l -> b l d")
|
||||
out = self.out_proj(y)
|
||||
return out
|
||||
|
||||
def step(self, hidden_states, conv_state, ssm_state):
|
||||
dtype = hidden_states.dtype
|
||||
assert (
|
||||
hidden_states.shape[1] == 1
|
||||
), "Only support decoding with 1 token at a time for now"
|
||||
xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
|
||||
x, z = xz.chunk(2, dim=-1) # (B D)
|
||||
|
||||
# Conv step
|
||||
conv_state.copy_(
|
||||
torch.roll(conv_state, shifts=-1, dims=-1)
|
||||
) # Update state (B D W)
|
||||
conv_state[:, :, -1] = x
|
||||
x = torch.sum(
|
||||
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
|
||||
) # (B D)
|
||||
if self.conv1d.bias is not None:
|
||||
x = x + self.conv1d.bias
|
||||
x = self.act(x).to(dtype=dtype)
|
||||
|
||||
x_db = self.x_proj(x) # (B dt_rank+2*d_state)
|
||||
dt, B, C = torch.split(
|
||||
x_db, [self.dt_proj_in_feature, self.d_state, self.d_state], dim=-1
|
||||
)
|
||||
dt = self.dt_proj(dt)
|
||||
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
|
||||
|
||||
# SSM step
|
||||
# Discretize A and B
|
||||
dt = F.softplus(dt)
|
||||
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
|
||||
dB = torch.einsum("bd,bn->bdn", dt, B)
|
||||
ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
|
||||
y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
|
||||
y = y + self.D.to(dtype) * x
|
||||
y = y * self.act(z) # (B D)
|
||||
|
||||
out = self.out_proj(y)
|
||||
return out.unsqueeze(1), conv_state, ssm_state
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
device = self.out_proj.weight.device
|
||||
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
|
||||
conv_state = torch.zeros(
|
||||
batch_size,
|
||||
self.d_model * self.expand,
|
||||
self.d_conv,
|
||||
device=device,
|
||||
dtype=conv_dtype,
|
||||
)
|
||||
ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
|
||||
# ssm_dtype = torch.float32
|
||||
ssm_state = torch.zeros(
|
||||
batch_size,
|
||||
self.d_model * self.expand,
|
||||
self.d_state,
|
||||
device=device,
|
||||
dtype=ssm_dtype,
|
||||
)
|
||||
return conv_state, ssm_state
|
||||
|
||||
def _get_states_from_cache(
|
||||
self, inference_params, batch_size, initialize_states=False
|
||||
):
|
||||
assert self.layer_idx is not None
|
||||
if self.layer_idx not in inference_params.key_value_memory_dict:
|
||||
batch_shape = (batch_size,)
|
||||
conv_state = torch.zeros(
|
||||
batch_size,
|
||||
self.d_model * self.expand,
|
||||
self.d_conv,
|
||||
device=self.conv1d.weight.device,
|
||||
dtype=self.conv1d.weight.dtype,
|
||||
)
|
||||
ssm_state = torch.zeros(
|
||||
batch_size,
|
||||
self.d_model * self.expand,
|
||||
self.d_state,
|
||||
device=self.dt_proj.weight.device,
|
||||
dtype=self.dt_proj.weight.dtype,
|
||||
# dtype=torch.float32,
|
||||
)
|
||||
inference_params.key_value_memory_dict[self.layer_idx] = (
|
||||
conv_state,
|
||||
ssm_state,
|
||||
)
|
||||
else:
|
||||
conv_state, ssm_state = inference_params.key_value_memory_dict[
|
||||
self.layer_idx
|
||||
]
|
||||
# TODO: What if batch size changes between generation, and we reuse the same states?
|
||||
if initialize_states:
|
||||
conv_state.zero_()
|
||||
ssm_state.zero_()
|
||||
return conv_state, ssm_state
|
||||
|
||||
|
||||
def create_block(
|
||||
d_model,
|
||||
ssm_cfg=None,
|
||||
norm_epsilon=1e-5,
|
||||
rms_norm=False,
|
||||
residual_in_fp32=False,
|
||||
layer_idx=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
if ssm_cfg is None:
|
||||
ssm_cfg = {}
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
|
||||
norm_cls = partial(
|
||||
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
||||
)
|
||||
block = Block(
|
||||
d_model,
|
||||
mixer_cls,
|
||||
norm_cls=norm_cls,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
)
|
||||
block.layer_idx = layer_idx
|
||||
return block
|
||||
|
||||
|
||||
class MixerModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
n_layer: int,
|
||||
vocab_size: int,
|
||||
ssm_cfg=None,
|
||||
norm_epsilon: float = 1e-5,
|
||||
rms_norm: bool = False,
|
||||
initializer_cfg=None,
|
||||
residual_in_fp32=False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
|
||||
self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
create_block(
|
||||
d_model,
|
||||
ssm_cfg=ssm_cfg,
|
||||
norm_epsilon=norm_epsilon,
|
||||
rms_norm=rms_norm,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
layer_idx=i,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for i in range(n_layer)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
|
||||
d_model, eps=norm_epsilon, **factory_kwargs
|
||||
)
|
||||
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
)
|
||||
)
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return {
|
||||
i: layer.allocate_inference_cache(
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
for i, layer in enumerate(self.layers)
|
||||
}
|
||||
|
||||
def forward(self, input_ids, inference_params=None):
|
||||
hidden_states = self.embedding(input_ids)
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
hidden_states, residual, inference_params=inference_params
|
||||
)
|
||||
residual = (hidden_states + residual) if residual is not None else hidden_states
|
||||
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||
def __init__(
|
||||
self,
|
||||
config: MambaConfig,
|
||||
initializer_cfg=None,
|
||||
device='cpu',
|
||||
dtype=torch.float32,
|
||||
) -> None:
|
||||
self.config = config
|
||||
d_model = config.d_model
|
||||
n_layer = config.n_layer
|
||||
vocab_size = config.vocab_size
|
||||
ssm_cfg = config.ssm_cfg
|
||||
rms_norm = config.rms_norm
|
||||
residual_in_fp32 = config.residual_in_fp32
|
||||
pad_vocab_size_multiple = config.pad_vocab_size_multiple
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
super().__init__()
|
||||
if vocab_size % pad_vocab_size_multiple != 0:
|
||||
vocab_size += pad_vocab_size_multiple - (
|
||||
vocab_size % pad_vocab_size_multiple
|
||||
)
|
||||
self.backbone = MixerModel(
|
||||
d_model=d_model,
|
||||
n_layer=n_layer,
|
||||
vocab_size=vocab_size,
|
||||
ssm_cfg=ssm_cfg,
|
||||
rms_norm=rms_norm,
|
||||
initializer_cfg=initializer_cfg,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(
|
||||
partial(
|
||||
_init_weights,
|
||||
n_layer=n_layer,
|
||||
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||
)
|
||||
)
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self.lm_head.weight = self.backbone.embedding.weight
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.backbone.allocate_inference_cache(
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0
|
||||
):
|
||||
"""
|
||||
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||
"""
|
||||
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, device='cpu', dtype=torch.float32, **kwargs):
|
||||
config_data = load_config_hf(pretrained_model_name)
|
||||
config = MambaConfig(**config_data)
|
||||
model = cls(config, device=device, dtype=dtype, **kwargs)
|
||||
model.load_state_dict(
|
||||
load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
|
||||
)
|
||||
return model
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
"""
|
||||
Minimal implementation of save_pretrained for MambaLMHeadModel.
|
||||
Save the model and its configuration file to a directory.
|
||||
"""
|
||||
# Ensure save_directory exists
|
||||
if not os.path.exists(save_directory):
|
||||
os.makedirs(save_directory)
|
||||
|
||||
# Save the model's state_dict
|
||||
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
||||
torch.save(self.state_dict(), model_path)
|
||||
|
||||
# Save the configuration of the model
|
||||
config_path = os.path.join(save_directory, "config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(self.config.__dict__, f)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
65 python/llm/example/GPU/PyTorch-Models/Model/mamba/README.md (Normal file)
@@ -0,0 +1,65 @@
# Mamba

In this directory, you will find examples of how you can use the BigDL-LLM `optimize_model` API to accelerate Mamba models. For illustration purposes, we use [state-spaces/mamba-1.4b](https://huggingface.co/state-spaces/mamba-1.4b) and [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) as reference Mamba models.

## Requirements
To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.

## Example: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case in which a Mamba model predicts the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs.
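For orientation, the flow of [generate.py](./generate.py) on an Intel GPU boils down to the pattern below (a condensed sketch of the script, using the example's defaults); installation and OneAPI setup are described in the following sections.

```python
import torch
import intel_extension_for_pytorch as ipex  # registers the 'xpu' device
from bigdl.llm import optimize_model
from transformers import AutoTokenizer

from model import MambaLMHeadModel  # the model.py shipped in this directory

# Load the FP32 Mamba model, enable BigDL-LLM INT4 optimization, then move it to the Intel GPU
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b")
model = optimize_model(model, low_bit='asym_int4',
                       modules_to_not_convert=["dt_proj", "x_proj"])
model = model.to('xpu')

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to('xpu')
    output = model.generate(input_ids, max_new_tokens=32)  # the first call also serves as a warmup
    print(tokenizer.decode(output.cpu()[0]))
```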
### 1. Install
We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).

After installing conda, create a Python environment for BigDL-LLM:
```bash
conda create -n llm python=3.9 # recommend to use Python 3.9
conda activate llm

# below command will install intel_extension_for_pytorch==2.0.110+xpu as default
# you can install specific ipex/torch version for your need
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip install einops # package required by Mamba
```

### 2. Configure OneAPI environment variables
```bash
source /opt/intel/oneapi/setvars.sh
```
### 3. Run

For optimal performance on Arc, it is recommended to set several environment variables:
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```

Then run the example:
```bash
python ./generate.py
```

In the example, several arguments can be passed to satisfy your requirements (a combined usage example follows the list):

- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: the Hugging Face repo id of the Mamba model to be downloaded (e.g. `state-spaces/mamba-1.4b` or `state-spaces/mamba-2.8b`), or the path to the Hugging Face checkpoint folder. Defaults to `state-spaces/mamba-1.4b`.
- `--tokenizer-repo-id-or-path`: the Hugging Face repo id of the tokenizer used by the Mamba model, or the path to the Hugging Face checkpoint folder. Defaults to `EleutherAI/gpt-neox-20b`.
- `--prompt PROMPT`: the prompt to infer on (with integrated prompt format for chat). Defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: the maximum number of tokens to predict. Defaults to `32`.
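For example, to try the larger reference model with a longer completion (the values here are illustrative):

```bash
python ./generate.py --repo-id-or-model-path state-spaces/mamba-2.8b --n-predict 64
```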
#### Sample Output
#### [state-spaces/mamba-1.4b](https://huggingface.co/state-spaces/mamba-1.4b)
```log
Inference time: xxxx s
-------------------- Output --------------------
What is AI?

Artificial Intelligence (AI) is a broad term that describes the use of artificial intelligence (AI) to create artificial intelligence (AI). AI is a
```

#### [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b)
```log
Inference time: xxxx s
-------------------- Output --------------------
What is AI?

Artificial Intelligence is a field of study that focuses on creating machines that can perform intelligent functions. These functions are performed by machines that are smarter than humans
```
69 python/llm/example/GPU/PyTorch-Models/Model/mamba/generate.py
@@ -0,0 +1,69 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import time
import torch
import intel_extension_for_pytorch as ipex
from bigdl.llm import optimize_model
from transformers import AutoTokenizer

from model import MambaLMHeadModel

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Mamba model')
    parser.add_argument('--repo-id-or-model-path', type=str, default="state-spaces/mamba-1.4b",
                        help='The huggingface repo id for the Mamba model to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--tokenizer-repo-id-or-path', type=str, default="EleutherAI/gpt-neox-20b",
                        help='The huggingface repo id for the Mamba tokenizer to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    tokenizer_path = args.tokenizer_repo_id_or_path

    # Load the FP32 Mamba model from the local model.py implementation
    model = MambaLMHeadModel.from_pretrained(model_path)

    # With only one line to enable BigDL-LLM INT4 optimization on the model;
    # the SSM projection modules listed below are kept unconverted
    model = optimize_model(model, low_bit='asym_int4', modules_to_not_convert=["dt_proj", "x_proj"])

    # Move the optimized model to the Intel GPU
    model = model.to('xpu')

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # Generate predicted tokens
    with torch.inference_mode():
        input_ids = tokenizer.encode(args.prompt, return_tensors="pt").to('xpu')
        # ipex model needs a warmup run, then inference time can be measured accurately
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        st = time.time()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()
        output = output.cpu()
        output_str = tokenizer.decode(output[0])
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Output', '-'*20)
        print(output_str)
926 python/llm/example/GPU/PyTorch-Models/Model/mamba/model.py (Normal file)
@@ -0,0 +1,926 @@
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# The code is adapted from: https://github.com/state-spaces/mamba.
|
||||
#
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
from transformers.generation import (
|
||||
GreedySearchDecoderOnlyOutput,
|
||||
SampleDecoderOnlyOutput,
|
||||
TextStreamer,
|
||||
)
|
||||
from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils.hub import cached_file
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaConfig:
|
||||
d_model: int = 2560
|
||||
n_layer: int = 64
|
||||
vocab_size: int = 50277
|
||||
ssm_cfg: dict = field(default_factory=dict)
|
||||
rms_norm: bool = True
|
||||
fused_add_norm: bool = False
|
||||
residual_in_fp32: bool = True
|
||||
pad_vocab_size_multiple: int = 8
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
||||
def _init_weights(
|
||||
module,
|
||||
n_layer,
|
||||
initializer_range=0.02,
|
||||
rescale_prenorm_residual=True,
|
||||
n_residuals_per_layer=1,
|
||||
):
|
||||
if isinstance(module, nn.Linear):
|
||||
if module.bias is not None:
|
||||
if not getattr(module.bias, "_no_reinit", False):
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
|
||||
if rescale_prenorm_residual:
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
for name, p in module.named_parameters():
|
||||
if name in ["out_proj.weight", "fc2.weight"]:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
|
||||
# We need to reinit p since this code could be called multiple times
|
||||
# Having just p *= scale would repeatedly scale it down
|
||||
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
|
||||
with torch.no_grad():
|
||||
p /= math.sqrt(n_residuals_per_layer * n_layer)
|
||||
|
||||
|
||||

def selective_scan(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False):
    """
    Shapes (r = real tensor, c = complex tensor; B = batch, D = dim, N = dstate, L = seqlen, G = groups):
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if A.is_complex():
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
    else:
        B = B.float()
        C = C.float()
    x = A.new_zeros((batch, dim, dstate))
    ys = []
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        if y.is_complex():
            y = y.real * 2
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
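
# Example (illustrative sketch, not executed; tensor sizes chosen arbitrarily):
#
#   u = torch.randn(1, 8, 16)        # input, (batch, dim, seqlen)
#   delta = torch.rand(1, 8, 16)     # step sizes, (batch, dim, seqlen)
#   A = -torch.rand(8, 4)            # state matrix, (dim, dstate)
#   B = torch.randn(1, 4, 16)        # time-varying input projection, (batch, dstate, seqlen)
#   C = torch.randn(1, 4, 16)        # time-varying output projection, (batch, dstate, seqlen)
#   y = selective_scan(u, delta, A, B, C, delta_softplus=True)
#   # y.shape == (1, 8, 16), i.e. the same (batch, dim, seqlen) layout as u
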

def layer_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
    dtype = x.dtype
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(
        x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
    ).to(dtype)
    return out if not prenorm else (out, x)

def rms_norm(x, weight, bias, residual=None, eps=1e-6, prenorm=False):
    dtype = x.dtype
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    out = out.to(dtype)
    return out if not prenorm else (out, x)
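
# Note (illustrative): both helpers implement the "add + norm" pattern used by Block below.
# When `residual` is given, the input is first added to it, and when `prenorm=True` the
# pre-normalization sum is returned alongside the normalized output. rms_norm computes
# x * weight / sqrt(mean(x**2) + eps), i.e. RMSNorm without mean centering.
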

def load_config_hf(model_name):
    resolved_archive_file = cached_file(
        model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False
    )
    return json.load(open(resolved_archive_file))

def load_state_dict_hf(model_name, device=None, dtype=None):
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    resolved_archive_file = cached_file(
        model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
    )
    return torch.load(resolved_archive_file, map_location=mapped_device)
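
# Usage sketch (illustrative; assumes a Hugging Face repo id such as "state-spaces/mamba-1.4b"
# that provides config.json and pytorch_model.bin):
#
#   config_data = load_config_hf("state-spaces/mamba-1.4b")   # plain dict from config.json
#   state_dict = load_state_dict_hf("state-spaces/mamba-1.4b")
#
# Both helpers rely on transformers' `cached_file`, so the files are downloaded once and then
# served from the local cache.
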

@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()
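
# Illustrative sketch: `decode` below creates one InferenceParams per generation call, e.g.
# InferenceParams(max_seqlen=prompt_len + max_new_tokens, max_batch_size=batch), and advances
# `seqlen_offset` after each forward pass; Mamba layers key their (conv_state, ssm_state)
# caches into `key_value_memory_dict` by `layer_idx`.
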

# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits of tokens outside the top-p nucleus to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability falls below the (1 - top_p) threshold
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter the sorted mask back to the original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
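
# Worked example (illustrative): with logits [2.0, 1.0, 0.0] and top_p=0.9, the softmax
# probabilities are roughly [0.665, 0.245, 0.090]. Sorting ascending and accumulating gives
# [0.090, 0.335, 1.0]; entries with cumulative probability <= 1 - 0.9 = 0.1 (here only the
# smallest logit) are masked to -inf, so sampling is restricted to the top ~90% nucleus.
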

def modify_logit_for_repetition_penalty(
    logits, prev_output_tokens, repetition_penalty=1.0
):
    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
    logits: (batch_size, vocab_size)
    prev_output_tokens: (batch_size, seq_len)
    """
    if repetition_penalty == 1.0:
        return logits
    score = torch.gather(logits, 1, prev_output_tokens)
    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
    score = torch.where(
        score < 0, score * repetition_penalty, score / repetition_penalty
    )
    logits.scatter_(1, prev_output_tokens, score)
    return logits

def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(
                    torch.softmax(logits_top, dim=-1), num_samples=1
                ).squeeze(dim=-1),
            ]
        else:
            # Clone so that when we modify for top_p we don't change the original logits
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(
                torch.softmax(logits_top, dim=-1), num_samples=1
            ).squeeze(dim=-1)
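
# Example (illustrative sketch):
#
#   logits = torch.randn(2, 50280)                       # (batch, vocab)
#   greedy = sample(logits)                              # top_k=1 -> argmax, shape (2,)
#   sampled = sample(logits, top_k=50, top_p=0.9, temperature=0.8)
#
# With top_k=1 the call short-circuits to greedy decoding; otherwise top-k is applied first
# and the surviving logits are further filtered by top-p before multinomial sampling.
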

@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_new_tokens,
    top_k=1,
    top_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    streamer: Optional[TextStreamer] = None,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_new_tokens: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors, each of shape (batch, vocab_size)
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    max_length = input_ids.shape[1] + max_new_tokens

    batch_size = input_ids.shape[0]
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=1,
        ).logits.squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if (
            teacher_outputs is None
            or teacher_output_len <= inference_params.seqlen_offset
        ):
            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # return rearrange(token, "b -> b 1")
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    scores, sequences = [], [input_ids]
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        scores.append(get_logits(sequences[-1], inference_params))
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(scores[-1], inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                scores[-1].clone(), sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
        sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    output_cls = (
        GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    )
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
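
# Note (illustrative): the first loop iteration feeds the whole prompt through the model
# (prefill); after that `seqlen_offset > 0`, so every subsequent call decodes a single token
# against the cached conv/ssm states. The returned object mirrors transformers' greedy/sample
# outputs: `.sequences` is (batch, prompt_len + generated_len) and `.scores` holds one logits
# tensor per generated token.
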

class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_new_tokens,
        top_k=1,
        top_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids,
            self,
            max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            **kwargs,
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
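
# Usage sketch (illustrative; `model` is a MambaLMHeadModel and `tokenizer` is assumed to be a
# compatible Hugging Face tokenizer, e.g. the GPT-NeoX tokenizer used with the Mamba checkpoints):
#
#   input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to(model.device)
#   output_ids = model.generate(input_ids, max_new_tokens=32)          # greedy by default
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
#
# Passing return_dict_in_generate=True together with output_scores=True also returns the
# per-step logits.
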

class Block(nn.Module):
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, residual_in_fp32=False):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
        """
        residual = (hidden_states + residual) if residual is not None else hidden_states
        hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
        )

class Mamba(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.dt_proj_in_feature = self.dt_rank

        self.in_proj = nn.Linear(
            self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs
        )

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )

        self.dt_proj = nn.Linear(
            self.dt_rank, self.d_inner, bias=True, **factory_kwargs
        )

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs)
            * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(
            self.d_inner, self.d_model, bias=bias, **factory_kwargs
        )

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        Returns: same shape as hidden_states
        """
        batch, seqlen, _ = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj(rearrange(hidden_states, "b l d -> d (b l)").t()).t(),
            "d (b l) -> b d l",
            l=seqlen,
        )

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        x, z = xz.chunk(2, dim=1)
        # Compute short convolution
        if conv_state is not None:
            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            conv_state.copy_(
                F.pad(x, (self.d_conv - x.shape[-1], 0))
            )  # Update state (B D W)
        # The fused causal_conv1d kernel is not used here; fall back to the plain nn.Conv1d path
        x = self.act(self.conv1d(x)[..., :seqlen])

        # We're careful here about the layout, to avoid extra transposes.
        # We want dt to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
        dt, B, C = torch.split(
            x_dbl, [self.dt_proj_in_feature, self.d_state, self.d_state], dim=-1
        )

        dt = self.dt_proj(dt).t()
        dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
        B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
        assert self.activation in ["silu", "swish"]
        y = selective_scan(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=None,
            delta_softplus=True,
            return_last_state=ssm_state is not None,
        )
        if ssm_state is not None:
            y, last_state = y
            ssm_state.copy_(last_state)
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        dtype = hidden_states.dtype
        assert (
            hidden_states.shape[1] == 1
        ), "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        conv_state.copy_(
            torch.roll(conv_state, shifts=-1, dims=-1)
        )  # Update state (B D W)
        conv_state[:, :, -1] = x
        x = torch.sum(
            conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
        )  # (B D)
        if self.conv1d.bias is not None:
            x = x + self.conv1d.bias
        x = self.act(x).to(dtype=dtype)

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(
            x_db, [self.dt_proj_in_feature, self.d_state, self.d_state], dim=-1
        )
        dt = self.dt_proj(dt)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        # Discretize A and B
        dt = F.softplus(dt)
        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
        dB = torch.einsum("bd,bn->bdn", dt, B)
        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
        y = y + self.D.to(dtype) * x
        y = y * self.act(z)  # (B D)

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size,
            self.d_model * self.expand,
            self.d_conv,
            device=device,
            dtype=conv_dtype,
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size,
            self.d_model * self.expand,
            self.d_state,
            device=device,
            dtype=ssm_dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(
        self, inference_params, batch_size, initialize_states=False
    ):
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            batch_shape = (batch_size,)
            conv_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_model * self.expand,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (
                conv_state,
                ssm_state,
            )
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[
                self.layer_idx
            ]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state
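
# Example (illustrative sketch, small arbitrary sizes):
#
#   layer = Mamba(d_model=64, d_state=16, d_conv=4, expand=2, layer_idx=0)
#   x = torch.randn(2, 10, 64)       # (batch, seqlen, d_model)
#   y = layer(x)                     # same shape: (2, 10, 64)
#
# During generation, `forward` is only used for the prompt; once InferenceParams.seqlen_offset > 0
# each new token goes through `step`, which updates the cached conv_state/ssm_state in place.
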

def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block

class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(
                batch_size, max_seqlen, dtype=dtype, **kwargs
            )
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        residual = (hidden_states + residual) if residual is not None else hidden_states
        hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        return hidden_states
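
# Note (illustrative): MixerModel is the token-level backbone: embedding -> n_layer Mamba
# blocks (each wrapped in Block with LayerNorm/RMSNorm) -> final norm. It returns hidden
# states of shape (batch, seqlen, d_model); the language-model head is added by
# MambaLMHeadModel below.
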

class MambaLMHeadModel(nn.Module, GenerationMixin):
    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device='cpu',
        dtype=torch.float32,
    ) -> None:
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (
                vocab_size % pad_vocab_size_multiple
            )
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(
            batch_size, max_seqlen, dtype=dtype, **kwargs
        )

    def forward(
        self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0
    ):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
        return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device='cpu', dtype=torch.float32, **kwargs):
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(
            load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        )
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)

    @property
    def device(self):
        return next(self.parameters()).device
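
# End-to-end usage sketch (illustrative; assumes a state-spaces checkpoint id and a matching
# tokenizer, here the EleutherAI/gpt-neox-20b tokenizer commonly paired with those checkpoints):
#
#   from transformers import AutoTokenizer
#
#   model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-1.4b")
#   tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
#   input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids
#   output = model.generate(input_ids, max_new_tokens=32)
#   print(tokenizer.decode(output[0], skip_special_tokens=True))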