Fix run_decoders bug (#11871)

This commit is contained in:
Yang Wang 2024-08-20 12:04:59 -07:00 committed by GitHub
parent 32f0a77846
commit bdaeee1d63
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -601,7 +601,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
x_np[2].ctypes.data_as(ctypes.c_void_p),
)
t0 = time.perf_counter()
backend_lib.run_decoders(models_ptr, inputs_ptr, 2, 3)
backend_lib.run_decoders(models_ptr, inputs_ptr, self.intra_stages, 3)
t1 = time.perf_counter()
hidden_states = self.backend_decoders[-1].torch_out[0]