fix rwkv with pip installer (#10591)
This commit is contained in:
		
							parent
							
								
									9a83f21b86
								
							
						
					
					
						commit
						437a349dd6
					
				
					 2 changed files with 8 additions and 4 deletions
				
			
		| 
						 | 
					@ -54,7 +54,7 @@ def extract_key_value(self, hidden, state=None):
 | 
				
			||||||
            self.time_mix_key.data,
 | 
					            self.time_mix_key.data,
 | 
				
			||||||
            self.time_mix_value.data,
 | 
					            self.time_mix_value.data,
 | 
				
			||||||
            self.time_mix_receptance.data,
 | 
					            self.time_mix_receptance.data,
 | 
				
			||||||
        ])
 | 
					        ]).to(dtype=hidden.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    import linear_q4_0
 | 
					    import linear_q4_0
 | 
				
			||||||
    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
					    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
				
			||||||
| 
						 | 
					@ -119,6 +119,8 @@ def rwkv_attention_forward(
 | 
				
			||||||
    layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
 | 
					    layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if hidden.device.type == "xpu":
 | 
					    if hidden.device.type == "xpu":
 | 
				
			||||||
 | 
					        self.time_decay.data = self.time_decay.data.to(dtype=key.dtype)
 | 
				
			||||||
 | 
					        self.time_first.data = self.time_first.data.to(dtype=key.dtype)
 | 
				
			||||||
        rwkv, layer_state = rwkv_linear_attention_xpu(
 | 
					        rwkv, layer_state = rwkv_linear_attention_xpu(
 | 
				
			||||||
            self.time_decay,
 | 
					            self.time_decay,
 | 
				
			||||||
            self.time_first,
 | 
					            self.time_first,
 | 
				
			||||||
| 
						 | 
					@ -162,7 +164,8 @@ def rwkv_ffn_forward(
 | 
				
			||||||
    shifted = shifted.contiguous()
 | 
					    shifted = shifted.contiguous()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if not hasattr(self, "mixed_mix"):
 | 
					    if not hasattr(self, "mixed_mix"):
 | 
				
			||||||
        self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
 | 
					        self.mixed_mix = torch.cat([self.time_mix_key.data,
 | 
				
			||||||
 | 
					                                    self.time_mix_receptance.data]).to(dtype=hidden.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    import linear_q4_0
 | 
					    import linear_q4_0
 | 
				
			||||||
    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
					    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -55,7 +55,7 @@ def extract_key_value(self, hidden, state=None):
 | 
				
			||||||
            self.time_mix_value.data,
 | 
					            self.time_mix_value.data,
 | 
				
			||||||
            self.time_mix_receptance.data,
 | 
					            self.time_mix_receptance.data,
 | 
				
			||||||
            self.time_mix_gate.data,
 | 
					            self.time_mix_gate.data,
 | 
				
			||||||
        ])
 | 
					        ]).to(dtype=hidden.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    import linear_q4_0
 | 
					    import linear_q4_0
 | 
				
			||||||
    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
					    mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
				
			||||||
| 
						 | 
					@ -232,7 +232,8 @@ def rwkv_ffn_forward_wrapper(origin_rwkv_ffn_forward):
 | 
				
			||||||
                shifted = shifted.unsqueeze(1)
 | 
					                shifted = shifted.unsqueeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if not hasattr(self, "mixed_mix"):
 | 
					            if not hasattr(self, "mixed_mix"):
 | 
				
			||||||
                self.mixed_mix = torch.cat([self.time_mix_key.data, self.time_mix_receptance.data])
 | 
					                self.mixed_mix = torch.cat([self.time_mix_key.data,
 | 
				
			||||||
 | 
					                                            self.time_mix_receptance.data]).to(dtype=hidden.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            import linear_q4_0
 | 
					            import linear_q4_0
 | 
				
			||||||
            mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
					            mixed_result = linear_q4_0.rwkv_time_shift(hidden, shifted, self.mixed_mix)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue