LLM: add device id to benchmark utils. (#10877)

This commit is contained in:
Cengguang Zhang 2024-04-25 14:01:51 +08:00 committed by GitHub
parent 1ce8d7bcd9
commit cd369c2715
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -2443,7 +2443,7 @@ class BenchmarkWrapper:
if self.device.type == "xpu":
torch.xpu.synchronize()
-memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024**3))
+memory_every_token.append(torch.xpu.memory.memory_reserved(self.device) / (1024**3))
self.peak_memory = np.max(memory_every_token)
end = time.perf_counter()
if first_token_time is None:
@@ -2758,7 +2758,7 @@ class BenchmarkWrapper:
if self.device.type == "xpu":
torch.xpu.synchronize()
-memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024 ** 3))
+memory_every_token.append(torch.xpu.memory.memory_reserved(self.device) / (1024 ** 3))
self.peak_memory = np.max(memory_every_token)
end = time.perf_counter()
if first_token_time is None:
@@ -3099,7 +3099,7 @@ class BenchmarkWrapper:
if self.device.type == "xpu":
torch.xpu.synchronize()
-memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024 ** 3))
+memory_every_token.append(torch.xpu.memory.memory_reserved(self.device) / (1024 ** 3))
self.peak_memory = np.max(memory_every_token)
end = time.perf_counter()
if first_token_time is None:
@@ -3471,7 +3471,7 @@ class BenchmarkWrapper:
if self.device.type == "xpu":
torch.xpu.synchronize()
-memory_every_token.append(torch.xpu.memory.memory_reserved() / (1024 ** 3))
+memory_every_token.append(torch.xpu.memory.memory_reserved(self.device) / (1024 ** 3))
self.peak_memory = np.max(memory_every_token)
end = time.perf_counter()
if first_token_time is None: