disable rwkv5 fp16 (#10699)

commit e438f941f2
parent 6a32216269
Author: Yishuo Wang
Date:   2024-04-09 16:42:11 +08:00 (committed by GitHub)


@@ -36,6 +36,7 @@ import torch
 import torch.nn.functional as F
 from typing import List, Optional
+from ipex_llm.utils.common.log4Error import invalidInputError


 def extract_key_value(self, hidden, state=None):
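For reference, the helper imported here is used as a check-and-raise guard in the second hunk below. A minimal sketch of the behavior implied by that call site, assuming a `(condition, message)` signature; the real implementation lives in `ipex_llm.utils.common.log4Error` and may also log before raising:

```python
# Sketch only; not the ipex-llm source. Semantics inferred from the call
# site in the second hunk: do nothing when `condition` holds, raise otherwise.
def invalidInputError(condition: bool, errMsg: str) -> None:
    if not condition:
        raise ValueError(errMsg)  # assumption: the real helper may log and use a different error type
```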
@@ -265,6 +266,8 @@ def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
+        invalidInputError(self.embeddings.weight.dtype == torch.float,
+                          "Only fp32 is supported for now, fp16 and bf16 are not supported")
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         # change `state` layout and put `num_hidden_layers` to the highest dim
         if input_ids.device.type == "xpu" and use_cache and state is None:
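Taken together, the patch makes the RWKV5 forward pass fail fast unless the embedding weights are fp32, which is what "disable rwkv5 fp16" means in practice. A self-contained sketch of the guard's effect on a toy module; the class and the stand-in `invalidInputError` are illustrative, not the ipex-llm API:

```python
import torch
import torch.nn as nn

def invalidInputError(condition: bool, errMsg: str) -> None:
    # Stand-in for ipex_llm.utils.common.log4Error.invalidInputError (see sketch above).
    if not condition:
        raise ValueError(errMsg)

class ToyRwkvModel(nn.Module):
    # Illustrative stand-in for the RWKV5 model; only the embeddings dtype matters here.
    def __init__(self, dtype=torch.float):
        super().__init__()
        self.embeddings = nn.Embedding(16, 8, dtype=dtype)

    def forward(self, input_ids):
        # The same guard the commit adds at the top of rwkv_model_forward.
        invalidInputError(self.embeddings.weight.dtype == torch.float,
                          "Only fp32 is supported for now, fp16 and bf16 are not supported")
        return self.embeddings(input_ids)

ids = torch.tensor([[1, 2, 3]])
out = ToyRwkvModel(torch.float)(ids)      # fp32 weights: passes the check
try:
    ToyRwkvModel(torch.float16)(ids)      # fp16 weights: rejected by this commit's check
except ValueError as e:
    print(e)
```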