From 437a349dd65a0b96ea548f02f39ad18fa6d4ddeb Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 29 Mar 2024 17:56:45 +0800 Subject: [PATCH] fix rwkv with pip installer (#10591) --- python/llm/src/ipex_llm/transformers/models/rwkv4.py | 7 +++++-- python/llm/src/ipex_llm/transformers/models/rwkv5.py | 5 +++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/rwkv4.py b/python/llm/src/ipex_llm/transformers/models/rwkv4.py index b29af810..2c0dc7ae 100644 --- a/python/llm/src/ipex_llm/transformers/models/rwkv4.py +++ b/python/llm/src/ipex_llm/transformers/models/rwkv4.py @@ -54,7 +54,7 @@ def extract_key_value(self, hidden, state=None): self.time_mix_key.data, self.time_mix_value.data, self.time_mix_receptance.data, - ]) + ]).to(dtype=hidden.dtype) import linear_q4_0 mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix) @@ -119,6 +119,8 @@ def rwkv_attention_forward( layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None if hidden.device.type == "xpu": + self.time_decay.data = self.time_decay.data.to(dtype=key.dtype) + self.time_first.data = self.time_first.data.to(dtype=key.dtype) rwkv, layer_state = rwkv_linear_attention_xpu( self.time_decay, self.time_first, @@ -162,7 +164,8 @@ def rwkv_ffn_forward( shifted = shifted.contiguous() if not hasattr(self, "mixed_mix"): - self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data]) + self.mixed_mix = torch.cat([self.time_mix_key.data, + self.time_mix_receptance.data]).to(dtype=hidden.dtype) import linear_q4_0 mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix) diff --git a/python/llm/src/ipex_llm/transformers/models/rwkv5.py b/python/llm/src/ipex_llm/transformers/models/rwkv5.py index 7d4d6d03..358c5a79 100644 --- a/python/llm/src/ipex_llm/transformers/models/rwkv5.py +++ b/python/llm/src/ipex_llm/transformers/models/rwkv5.py @@ -55,7 +55,7 @@ def extract_key_value(self, hidden, state=None): self.time_mix_value.data, self.time_mix_receptance.data, self.time_mix_gate.data, - ]) + ]).to(dtype=hidden.dtype) import linear_q4_0 mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix) @@ -232,7 +232,8 @@ def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward): shifted = shifted.unsqueeze(1) if not hasattr(self, "mixed_mix"): - self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data]) + self.mixed_mix = torch.cat([self.time_mix_key.data, + self.time_mix_receptance.data]).to(dtype=hidden.dtype) import linear_q4_0 mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)