disable rwkv5 fp16 (#10699)
This commit is contained in:
parent 6a32216269
commit e438f941f2
1 changed file with 3 additions and 0 deletions
@@ -36,6 +36,7 @@ import torch
 import torch.nn.functional as F

 from typing import List, Optional
+from ipex_llm.utils.common.log4Error import invalidInputError


 def extract_key_value(self, hidden, state=None):
@@ -265,6 +266,8 @@ def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
+        invalidInputError(self.embeddings.weight.dtype == torch.float,
+                          "Only fp32 is supported for now, fp16 and bf16 are not supported")
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         # change `state` layout and put `num_hidden_layers` to the highest dim
         if input_ids.device.type == "xpu" and use_cache and state is None:
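For context, the added lines act as an input guard: the optimized RWKV5 forward path only supports fp32 weights for now, so an fp16 or bf16 checkpoint is rejected up front with a descriptive error instead of failing later inside the kernels. The sketch below illustrates the pattern only; the standalone `invalidInputError` here is a simplified stand-in for the real helper in `ipex_llm.utils.common.log4Error` (which is assumed to also log before raising), and `check_rwkv_dtype` is a hypothetical name, not part of the patch.

```python
import torch

# Illustrative stand-in for ipex_llm.utils.common.log4Error.invalidInputError;
# assumed behavior: raise (and log) when the condition is falsy.
def invalidInputError(condition: bool, err_msg: str) -> None:
    if not condition:
        raise RuntimeError(err_msg)

def check_rwkv_dtype(embeddings: torch.nn.Embedding) -> None:
    # Mirrors the guard added in this commit: only fp32 embeddings pass.
    invalidInputError(embeddings.weight.dtype == torch.float,
                      "Only fp32 is supported for now, fp16 and bf16 are not supported")

# Hypothetical usage
emb_fp32 = torch.nn.Embedding(10, 8)        # defaults to torch.float32
check_rwkv_dtype(emb_fp32)                  # passes silently

emb_fp16 = torch.nn.Embedding(10, 8).half() # fp16 weights
try:
    check_rwkv_dtype(emb_fp16)
except RuntimeError as e:
    print(e)                                # "Only fp32 is supported for now, ..."
```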