LLM: fix memory access violation (#8519)

This commit is contained in:
Yishuo Wang 2023-07-13 17:08:08 +08:00 committed by GitHub
parent 60c2c0c3dc
commit 6320bf201e
2 changed files with 4 additions and 4 deletions

View file

@ -984,7 +984,7 @@ def ggml_type_size(qtype: ctypes.c_int) -> int:
_lib.ggml_type_size.argtypes = [
ctypes.c_int,
]
_lib.ggml_type_size.restype = ctypes.c_int
_lib.ggml_type_size.restype = ctypes.c_size_t
def ggml_qk_size(qtype: ctypes.c_int) -> int:

View file

@ -153,13 +153,13 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
if src1.dtype != torch.float32:
src1 = src1.float()
src0_ptr = src0.data_ptr() + (src0.storage_offset() * src0.element_size())
src1_ptr = src1.data_ptr() + (src1.storage_offset() * src1.element_size())
src0_ptr = src0.data_ptr()
src1_ptr = src1.data_ptr()
result_shape = (src1.shape[0], src0_shape[0])
result_t = torch.empty(result_shape, dtype=torch.float32)
result_ptr = result_t.data_ptr() + (result_t.storage_offset() * result_t.element_size())
result_ptr = result_t.data_ptr()
src0_shape = tuple(reversed(src0_shape))
src1_shape = tuple(reversed(src1.shape))