fix rwkv with pip installer (#10591)

This commit is contained in:
Yishuo Wang 2024-03-29 17:56:45 +08:00 committed by GitHub
parent 9a83f21b86
commit 437a349dd6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 4 deletions

View file

@@ -54,7 +54,7 @@ def extract_key_value(self, hidden, state=None):
self.time_mix_key.data,
self.time_mix_value.data,
self.time_mix_receptance.data,
- ])
+ ]).to(dtype=hidden.dtype)
import linear_q4_0
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
@@ -119,6 +119,8 @@ def rwkv_attention_forward(
layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
if hidden.device.type == "xpu":
+ self.time_decay.data = self.time_decay.data.to(dtype=key.dtype)
+ self.time_first.data = self.time_first.data.to(dtype=key.dtype)
rwkv, layer_state = rwkv_linear_attention_xpu(
self.time_decay,
self.time_first,
@@ -162,7 +164,8 @@ def rwkv_ffn_forward(
shifted = shifted.contiguous()
if not hasattr(self, "mixed_mix"):
- self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+ self.mixed_mix = torch.cat([self.time_mix_key.data,
+ self.time_mix_receptance.data]).to(dtype=hidden.dtype)
import linear_q4_0
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)

View file

@@ -55,7 +55,7 @@ def extract_key_value(self, hidden, state=None):
self.time_mix_value.data,
self.time_mix_receptance.data,
self.time_mix_gate.data,
- ])
+ ]).to(dtype=hidden.dtype)
import linear_q4_0
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
@@ -232,7 +232,8 @@ def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
shifted = shifted.unsqueeze(1)
if not hasattr(self, "mixed_mix"):
- self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
+ self.mixed_mix = torch.cat([self.time_mix_key.data,
+ self.time_mix_receptance.data]).to(dtype=hidden.dtype)
import linear_q4_0
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)