enable inference mode for deepspeed tp serving (#11742)

This commit is contained in:
Shaojun Liu 2024-08-08 14:38:30 +08:00 committed by GitHub
parent 9e65cf00b3
commit 107f7aafd0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -121,6 +121,8 @@ def load_model(model_path, low_bit):
current_accel = XPU_Accelerator()
set_accelerator(current_accel)
model=model.eval()
# Move model back to xpu
model = model.to(f"xpu:{local_rank}")
model = BenchmarkWrapper(model)