From e438f941f2e6484c5371db9e783dc7d26bba3bee Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 9 Apr 2024 16:42:11 +0800
Subject: [PATCH] disable rwkv5 fp16 (#10699)

---
 python/llm/src/ipex_llm/transformers/models/rwkv5.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/models/rwkv5.py b/python/llm/src/ipex_llm/transformers/models/rwkv5.py
index 358c5a79..5619c16f 100644
--- a/python/llm/src/ipex_llm/transformers/models/rwkv5.py
+++ b/python/llm/src/ipex_llm/transformers/models/rwkv5.py
@@ -36,6 +36,7 @@
 import torch
 import torch.nn.functional as F
 from typing import List, Optional
+from ipex_llm.utils.common.log4Error import invalidInputError
 
 
 def extract_key_value(self, hidden, state=None):
@@ -265,6 +266,8 @@ def rwkv_model_forward_wrapper(origin_rwkv_model_forward):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
+        invalidInputError(self.embeddings.weight.dtype == torch.float,
+                          "Only fp32 is supported for now, fp16 and bf16 are not supported")
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         # change `state` layout and put `num_hidden_layers` to the highest dim
         if input_ids.device.type == "xpu" and use_cache and state is None:
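
The change above is a fail-fast dtype guard: the wrapped forward checks the
embedding weight dtype before running, so fp16/bf16 RWKV5 models error out
immediately with a clear message. Below is a minimal, self-contained sketch of
that wrapper pattern; invalid_input_error, forward_wrapper, and TinyModel are
hypothetical stand-ins for illustration, not the actual ipex-llm helpers.

import torch
import torch.nn as nn


def invalid_input_error(condition: bool, message: str):
    # Stand-in for ipex_llm.utils.common.log4Error.invalidInputError:
    # raise on a failed check instead of continuing silently.
    if not condition:
        raise ValueError(message)


def forward_wrapper(origin_forward):
    # Wrap a module's forward so unsupported dtypes are rejected up front,
    # mirroring how the patch wraps the original RWKV model forward.
    def wrapped_forward(self, input_ids):
        invalid_input_error(self.embeddings.weight.dtype == torch.float,
                            "Only fp32 is supported for now, fp16 and bf16 are not supported")
        return origin_forward(self, input_ids)
    return wrapped_forward


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(16, 8)

    def forward(self, input_ids):
        return self.embeddings(input_ids)


# Patch the class method, as the diff does for the original model forward.
TinyModel.forward = forward_wrapper(TinyModel.forward)

model = TinyModel()
print(model(torch.tensor([1, 2, 3])).shape)  # fp32 weights pass: torch.Size([3, 8])

model.half()  # convert weights to fp16
try:
    model(torch.tensor([1, 2, 3]))
except ValueError as e:
    print(e)  # the guard rejects fp16 weights with the message above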