fix rwkv with pip installer (#10591)
This commit is contained in:
parent
9a83f21b86
commit
437a349dd6
2 changed files with 8 additions and 4 deletions
|
|
@@ -54,7 +54,7 @@ def extract_key_value(self, hidden, state=None):
|
|||
self.time_mix_key.data,
|
||||
self.time_mix_value.data,
|
||||
self.time_mix_receptance.data,
|
||||
])
|
||||
]).to(dtype=hidden.dtype)
|
||||
|
||||
import linear_q4_0
|
||||
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
|
||||
|
|
@@ -119,6 +119,8 @@ def rwkv_attention_forward(
|
|||
layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
|
||||
|
||||
if hidden.device.type == "xpu":
|
||||
self.time_decay.data = self.time_decay.data.to(dtype=key.dtype)
|
||||
self.time_first.data = self.time_first.data.to(dtype=key.dtype)
|
||||
rwkv, layer_state = rwkv_linear_attention_xpu(
|
||||
self.time_decay,
|
||||
self.time_first,
|
||||
|
|
@@ -162,7 +164,8 @@ def rwkv_ffn_forward(
|
|||
shifted = shifted.contiguous()
|
||||
|
||||
if not hasattr(self, "mixed_mix"):
|
||||
self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
|
||||
self.mixed_mix = torch.cat([self.time_mix_key.data,
|
||||
self.time_mix_receptance.data]).to(dtype=hidden.dtype)
|
||||
|
||||
import linear_q4_0
|
||||
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
|
||||
|
|
|
|||
|
|
@@ -55,7 +55,7 @@ def extract_key_value(self, hidden, state=None):
|
|||
self.time_mix_value.data,
|
||||
self.time_mix_receptance.data,
|
||||
self.time_mix_gate.data,
|
||||
])
|
||||
]).to(dtype=hidden.dtype)
|
||||
|
||||
import linear_q4_0
|
||||
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
|
||||
|
|
@@ -232,7 +232,8 @@ def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
|
|||
shifted = shifted.unsqueeze(1)
|
||||
|
||||
if not hasattr(self, "mixed_mix"):
|
||||
self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
|
||||
self.mixed_mix = torch.cat([self.time_mix_key.data,
|
||||
self.time_mix_receptance.data]).to(dtype=hidden.dtype)
|
||||
|
||||
import linear_q4_0
|
||||
mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
|
||||
|
|
|
|||
Loading…
Reference in a new issue