Add llama2 gpu low memory example (#9514)

* Add low memory example

* Minor fixes

* Update readme.md
Zheng, Yi 2023-12-05 17:29:48 +08:00 committed by GitHub
parent 06febb5fa7
commit d154b38bf9
5 changed files with 382 additions and 7 deletions


@@ -137,7 +137,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| Model | CPU Example | GPU Example |
|------------|----------------------------------------------------------------|-----------------------------------------------------------------|
| LLaMA *(such as Vicuna, Guanaco, Koala, Baize, WizardLM, etc.)* | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/vicuna) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/vicuna) |
| LLaMA 2 | [link1](python/llm/example/CPU/Native-Models), [link2](python/llm/example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link1](python/llm/example/GPU/HF-Transformers-AutoModels/Model/llama2), [link2-low GPU memory example](python/llm/example/GPU/PyTorch-Models/Model/llama2#example-2---low-memory-version-predict-tokens-using-generate-api) |
| ChatGLM | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm) | |
| ChatGLM2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm2) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/chatglm2) |
| ChatGLM3 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/chatglm3) | [link](python/llm/example/GPU/HF-Transformers-AutoModels/Model/chatglm3) |


@@ -39,7 +39,7 @@ Over 20 models have been optimized/verified on `bigdl-llm`, including *LLaMA/LLa
| Model | CPU Example | GPU Example |
|------------|----------------------------------------------------------------|-----------------------------------------------------------------|
| LLaMA *(such as Vicuna, Guanaco, Koala, Baize, WizardLM, etc.)* | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/vicuna) | [link](example/GPU/HF-Transformers-AutoModels/Model/vicuna) |
| LLaMA 2 | [link1](example/CPU/Native-Models), [link2](example/CPU/HF-Transformers-AutoModels/Model/llama2) | [link1](example/GPU/HF-Transformers-AutoModels/Model/llama2), [link2-low GPU memory example](example/GPU/PyTorch-Models/Model/llama2#example-2---low-memory-version-predict-tokens-using-generate-api) |
| ChatGLM | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm) | |
| ChatGLM2 | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm2) | [link](example/GPU/HF-Transformers-AutoModels/Model/chatglm2) |
| ChatGLM3 | [link](example/CPU/HF-Transformers-AutoModels/Model/chatglm3) | [link](example/GPU/HF-Transformers-AutoModels/Model/chatglm3) |


@@ -1,10 +1,10 @@
# Llama2
In this directory, you will find examples of how you can use the BigDL-LLM `optimize_model` API to accelerate Llama2 models. For illustration purposes, we use [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) and [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) as reference Llama2 models.
## Requirements
To run these examples with BigDL-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
## Example 1 - Basic Version: Predict Tokens using `generate()` API
In the example [generate.py](./generate.py), we show a basic use case for a Llama2 model to predict the next N tokens using the `generate()` API, with BigDL-LLM INT4 optimizations on Intel GPUs.
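For a quick picture of what the script does, the sketch below shows the basic `optimize_model` + `generate()` flow on an Intel GPU. The model id, prompt, and loading arguments are illustrative; see [generate.py](./generate.py) for the actual implementation.

```python
# Minimal sketch of the Example 1 flow (illustrative; see generate.py for the real script).
import torch
import intel_extension_for_pytorch as ipex  # imported for the 'xpu' device
from transformers import AutoModelForCausalLM, AutoTokenizer
from bigdl.llm import optimize_model

model_path = "meta-llama/Llama-2-7b-chat-hf"  # or a local checkpoint folder
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype='auto',
                                             low_cpu_mem_usage=True)
model = optimize_model(model)   # apply BigDL-LLM INT4 optimizations
model = model.to('xpu')         # move the optimized model to the Intel GPU

tokenizer = AutoTokenizer.from_pretrained(model_path)
input_ids = tokenizer.encode("What is AI?", return_tensors="pt").to('xpu')
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0].cpu(), skip_special_tokens=True))
```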
### 1. Install
We suggest using conda to manage the Python environment. For more information about conda installation, please refer to [here](https://docs.conda.io/en/latest/miniconda.html#).
@@ -43,7 +43,7 @@ In the example, several arguments can be passed to satisfy your requirements:
- `--prompt PROMPT`: argument defining the prompt to be inferred (with integrated prompt format for chat). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
#### Sample Output
#### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
```log
Inference time: xxxx s
@@ -67,3 +67,24 @@ What is AI?
AI, or artificial intelligence, refers to the ability of machines to perform tasks that would typically require human intelligence, such as learning, problem-solving,
```
## Example 2 - Low Memory Version: Predict Tokens using `generate()` API
If you're not able to load the full 4-bit model (e.g. [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)) on a single GPU as shown in [Example 1](#example-1---basic-version-predict-tokens-using-generate-api), you may try this example instead.
In [low_memory_generate.py](./low_memory_generate.py), we show how to load very large models with a very low GPU memory footprint by loading the weights onto the GPU one layer at a time. Note, however, that this can be much slower than the standard approach. The implementation is adapted from [here](https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag).
### 1. Environment setup
Please refer to [Example 1](#example-1---basic-version-predict-tokens-using-generate-api) for more information.
### 2. Run
```bash
python ./low_memory_generate.py --split-weight --splitted-weights-path ${SPLITTED_WEIGHTS_PATH}
```
In the example, in addition to the arguments in [Example 1](#3-run), several other arguments can be passed to satisfy your requirements (a sketch of the equivalent Python calls follows this list):
- `--splitted-weights-path`: argument defining the folder in which the per-layer weights are saved.
- `--split-weight`: argument defining whether to split the weights by layer. If this argument is enabled, per-layer weights will be generated and saved to `--splitted-weights-path`. This argument only needs to be enabled once for the same model.
- `--max-cache-num`: argument defining the maximum number of layer weights kept in the cache. You can adjust this argument based on your GPU memory. It defaults to `200`. For [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf), peak GPU memory is around 3 GB when it is set to `0` and around 15 GB when it is set to `200`.
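For reference, here is a sketch of the equivalent Python calls, mirroring the `__main__` block of [low_memory_generate.py](./low_memory_generate.py); the paths and prompt are placeholders, and the weight-splitting step only needs to run once per model.

```python
# Rough sketch of what low_memory_generate.py does with the arguments above
# (paths and the prompt are placeholders).
import os
import torch
from low_memory_generate import LLAMA2_PROMPT_FORMAT, LowMemoryLlama

model_path = "meta-llama/Llama-2-70b-chat-hf"         # --repo-id-or-model-path
splitted_weights_path = "./llama2-70b-layer-weights"  # --splitted-weights-path

# Equivalent to enabling --split-weight: save per-layer int4 weights to disk once.
os.makedirs(splitted_weights_path, exist_ok=True)
LowMemoryLlama.split_and_convert_to_cpu_int4_weights(model_path, splitted_weights_path)

# --max-cache-num trades GPU memory for speed (see the figures above).
model = LowMemoryLlama(model_path, splitted_weights_path, max_cache_num=200)

with torch.inference_mode():
    prompt = LLAMA2_PROMPT_FORMAT.format(prompt="What is AI?")
    input_ids = model.tokenizer.encode(prompt, return_tensors="pt").to('xpu')
    output = model.generate(input_ids, max_new_tokens=32)
print(model.tokenizer.decode(output[0].cpu(), skip_special_tokens=True))
```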


@@ -0,0 +1,354 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The code is adapted from: https://www.kaggle.com/code/simjeg/platypus2-70b-without-wikipedia-rag.
#
import argparse
import gc
import os
import time
import warnings
from typing import Dict, List, Optional, Union
import intel_extension_for_pytorch as ipex
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from safetensors.torch import load_file, save_file
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from bigdl.llm import optimize_model
from bigdl.llm.transformers.low_bit_linear import FP4Params, LowBitLinear
MAX_LENGTH = 4096
# you could tune the prompt based on your own model,
# here the prompt tuning refers to https://huggingface.co/georgesung/llama2_7b_chat_uncensored#prompt-style
LLAMA2_PROMPT_FORMAT = """### HUMAN:
{prompt}
### RESPONSE:
"""
# Modified based on https://github.com/huggingface/accelerate/blob/d25efa71ce76a5f5911a1fc6c039979d7248596f/src/accelerate/utils/modeling.py#L238
def set_module_tensor_to_device_with_cache(
module: nn.Module,
tensor_name: str,
device: Union[int, str, torch.device],
value: Optional[torch.Tensor],
dtype: Optional[Union[str, torch.dtype]] = None,
cache_dict: Optional[Dict[str, FP4Params]] = None,
max_cache_num: Optional[int] = 100,
):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
Args:
module (`torch.nn.Module`):
The module in which the tensor we want to move lives.
tensor_name (`str`):
The full name of the parameter/buffer.
device (`int`, `str` or `torch.device`):
The device on which to set the tensor.
value (`torch.Tensor`):
The value of the tensor (useful when going from the meta device to any other device).
dtype (`torch.dtype`, *optional*):
If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to
the dtype of the existing parameter in the model.
cache_dict (`Dict`, *optional*):
The cache dict to save layer weights. This can improve the loading speed.
max_cache_num (`int`, *optional*):
The maximum number of weights saved in the cache_dict. You can adjust this number based on your GPU memory.
Default is 100.
"""
original_tensor_name = tensor_name
assert value is not None
# Recurse if needed
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]
# Use cache to load weights
if original_tensor_name in cache_dict:
module._parameters[tensor_name] = cache_dict[original_tensor_name]
return
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
if dtype is None:
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
value = value.to(old_value.dtype)
elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
value = value.to(dtype)
param = module._parameters[tensor_name] if tensor_name in module._parameters else None
param_cls = type(param)
with torch.no_grad():
if isinstance(module, LowBitLinear):
# load cpu int4 weights
new_value = FP4Params(data=value,
requires_grad=False,
quantized=True,
_shape=(module.out_features, module.in_features),
convert_shape_only=False,
qtype=2).to(device)
if len(cache_dict) < max_cache_num:
cache_dict.update({original_tensor_name: new_value})
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
else:
new_value = torch.tensor(value, device=device)
if is_buffer:
module._buffers[tensor_name] = new_value
elif isinstance(module, LowBitLinear):
module._parameters[tensor_name] = new_value
elif value is not None or torch.device(device) != module._parameters[tensor_name].device:
param_cls = type(module._parameters[tensor_name])
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
module._parameters[tensor_name] = new_value
else:
raise NotImplementedError
class LowMemoryLlama(GenerationMixin):
def __init__(self, model_path: str, splitted_weights_path: str, max_cache_num: Optional[int] = 100):
"""
Low-memory version of LlamaForCausalLM: the model is split into layer shards to reduce GPU memory usage.
During the forward pass, the inputs are processed layer by layer, and the GPU memory is freed after each layer.
Parameters
----------
model_path (`str`):
The huggingface repo id, or the path to a huggingface checkpoint folder that includes config.json and tokenizer.model.
splitted_weights_path (`str`):
The folder containing the per-layer int4 weights. You can use `LowMemoryLlama.split_and_convert_to_cpu_int4_weights` to generate those weights.
max_cache_num (`int`, *optional*):
The maximum number of weights saved in the cache_dict. You can adjust this number based on your GPU memory.
Default is 100.
"""
super().__init__()
# Save parameters
self.model_path = model_path
self.splitted_weights_path = splitted_weights_path
self.device = torch.device('xpu:0')
self.layer_weight_cache = {} # initialize weight cache dict
self.max_cache_num = max_cache_num
# Create model
self._create_model()
# Check if `self.splitted_weights_path` exists
self._check_split_weight_path()
# Initialize attention mask and position ids to further improve the inference speed
self._generate_att_mask_and_pos_id()
@classmethod
def split_and_convert_to_cpu_int4_weights(cls, model_path, safetensor_per_layer_path):
model = AutoModelForCausalLM.from_pretrained(model_path, use_safetensors=True)
model = optimize_model(model)
state_dict = model.state_dict()
layer_names = ["model.embed_tokens."] + [f"model.layers.{i}." for i in range(len(model.model.layers))] + ["model.norm.", "lm_head."]
for layer_name in tqdm(layer_names):
local_state_dict = {k: v.contiguous() for k, v in state_dict.items() if k.startswith(layer_name)}
save_name = os.path.join(safetensor_per_layer_path, f"{layer_name}safetensors")
save_file(local_state_dict, save_name)
print(f'Save splitted safetensor weights to {safetensor_per_layer_path}')
def _create_model(self):
# Load config
self.config = AutoConfig.from_pretrained(self.model_path)
if self.config.pretraining_tp > 1:
warnings.warn("Set config.pretraining_tp = 1 to use int4 inference")
self.config.pretraining_tp = 1
if not self.config.use_cache:
warnings.warn("Set config.use_cache to further improve performance")
self.config.use_cache = True
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
# Model initialization
self._init_model()
# Tie with self.model
self.layer_names = ["model.embed_tokens"] + [f"model.layers.{i}" for i in range(len(self.model.model.layers))] + ["model.norm", "lm_head"]
self.generation_config = self.model.generation_config
self.main_input_name = self.model.main_input_name
def _check_split_weight_path(self):
for layer_name in self.layer_names:
split_weight_path = os.path.join(self.splitted_weights_path, f"{layer_name}.safetensors")
if not os.path.exists(split_weight_path):
raise FileNotFoundError(f"Weight file {split_weight_path} is missing."
f"You can generate it using `LowMemoryLlama.split_and_convert_to_cpu_int4_weights`.")
def _generate_att_mask_and_pos_id(self):
self.attention_mask = torch.full((MAX_LENGTH, MAX_LENGTH), torch.finfo(torch.float16).min, device=self.device)
mask_cond = torch.arange(self.attention_mask.size(-1), device=self.device)
self.attention_mask.masked_fill_(mask_cond < (mask_cond + 1).view(self.attention_mask.size(-1), 1), 0)
self.attention_mask = self.attention_mask.to(torch.float16)[None, None, :, :]
self.position_ids = torch.arange(MAX_LENGTH, dtype=torch.long, device=self.device)[None, :]
def _init_model(self):
# Load meta model (no memory used)
with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(self.config)
self.model = optimize_model(self.model)
self.model.tie_weights()
self.layers = [self.model.model.embed_tokens] + list(self.model.model.layers) \
+ [self.model.model.norm, self.model.lm_head]
# Move buffers to device (not that much GPU memory used)
for buffer_name, buffer in self.model.named_buffers():
set_module_tensor_to_device_with_cache(self.model, buffer_name, self.device, value=buffer, dtype=buffer.dtype,
cache_dict=self.layer_weight_cache, max_cache_num=self.max_cache_num)
def move_layer_to_device(self, state_dict):
for param_name, param in state_dict.items():
set_module_tensor_to_device_with_cache(self.model, param_name, self.device, value=param, dtype=param.dtype,
cache_dict=self.layer_weight_cache, max_cache_num=self.max_cache_num)
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
def forward(self,
input_ids: torch.LongTensor = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
*args,
**kwargs,
):
# Reinit model and clean memory
del self.model
gc.collect()
torch.xpu.empty_cache()
self._init_model()
# Send batch to device
inputs = input_ids.to(self.device)
# Set up kv cache
kv_cache = {}
if past_key_values is None:
past_key_values = {}
with torch.inference_mode():
# Generate attention mask and position ids
current_shape = inputs.shape[1]
if past_key_values.get(self.layer_names[1], None):
pre_shape = past_key_values[self.layer_names[1]][0].size(2)
else:
pre_shape = 0
pos = self.position_ids[:, pre_shape : current_shape + pre_shape]
attn = self.attention_mask[:, :, -current_shape:, - current_shape - pre_shape:]
for (layer_name, layer) in tqdm(zip(self.layer_names, self.layers), total=len(self.layers)):
# Load current layer to device
state_dict = load_file(os.path.join(self.splitted_weights_path, f"{layer_name}.safetensors"), device="cpu")
self.move_layer_to_device(state_dict)
# Run layer
if layer_name in ("model.embed_tokens", "model.norm", "lm_head"):
inputs = layer(inputs)
else:
inputs, new_kv_cache = layer(inputs, use_cache=True, past_key_value=past_key_values.get(layer_name, None),
position_ids=pos, attention_mask=attn)
kv_cache[layer_name] = new_kv_cache
# Delete weight before moving to('meta')
for module in layer.modules():
if hasattr(module, "weight"):
del module.weight
# Remove previous layer from memory (including buffers)
layer.to("meta")
result = CausalLMOutputWithPast(
logits=inputs.detach(),
past_key_values=kv_cache,
)
return result
def can_generate(self):
return True
def prepare_inputs_for_generation(self, *args, **kwargs):
return self.model.prepare_inputs_for_generation(*args, **kwargs)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Llama2 model')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`,'
' `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--prompt', type=str, default="What is AI?",
help='Prompt to infer')
parser.add_argument('--n-predict', type=int, default=32,
help='Max tokens to predict')
parser.add_argument('--splitted-weights-path', type=str, required=True,
help='The folder saving per-layer weights. You can use'
' LowMemoryLlama.split_and_convert_to_cpu_int4_weights() to generate those weights.')
parser.add_argument('--split-weight', action='store_true',
help='Whether to split weights by layer. If this argument is enabled, per-layer weights will'
' be generated and saved to `--splitted-weights-path`. This argument only needs to be'
' enabled once for the same model.')
parser.add_argument('--max-cache-num', type=int, default=200,
help='The maximum number of weights saved in the cache_dict. You can adjust this'
' number based on your GPU memory. Default is 200.')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
splitted_weights_path = args.splitted_weights_path
if args.split_weight:
os.makedirs(splitted_weights_path, exist_ok=True)
LowMemoryLlama.split_and_convert_to_cpu_int4_weights(model_path, splitted_weights_path)
model = LowMemoryLlama(model_path, splitted_weights_path, args.max_cache_num)
with torch.inference_mode():
prompt = LLAMA2_PROMPT_FORMAT.format(prompt=args.prompt)
input_ids = model.tokenizer.encode(prompt, return_tensors="pt").to('xpu')
st = time.time()
output = model.generate(input_ids,
max_new_tokens=args.n_predict)
torch.xpu.synchronize()
end = time.time()
output = output.cpu()
output_str = model.tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Inference time: {end-st} s')
print('-'*20, 'Output', '-'*20)
print(output_str)


@@ -226,8 +226,8 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
    invalidInputError(isinstance(model, torch.nn.Module),
                      "model should be an instance of "
                      f"`torch.nn.Module`, but got {type(model)} at last.")
    invalidInputError(model.device.type in ('cpu', 'meta'),
                      "Expect model on device `cpu` or `meta`, "
                      f"but got device type {model.device.type}")
    if kwargs.pop("replace_embedding", False):
        warnings.warn("replace_embedding is deprecated and will be removed in a future version,"