LLM: fix memory access violation (#8519)
This commit is contained in:
parent
60c2c0c3dc
commit
6320bf201e
2 changed files with 4 additions and 4 deletions
|
|
@ -984,7 +984,7 @@ def ggml_type_size(qtype: ctypes.c_int) -> int:
|
||||||
_lib.ggml_type_size.argtypes = [
|
_lib.ggml_type_size.argtypes = [
|
||||||
ctypes.c_int,
|
ctypes.c_int,
|
||||||
]
|
]
|
||||||
_lib.ggml_type_size.restype = ctypes.c_int
|
_lib.ggml_type_size.restype = ctypes.c_size_t
|
||||||
|
|
||||||
|
|
||||||
def ggml_qk_size(qtype: ctypes.c_int) -> int:
|
def ggml_qk_size(qtype: ctypes.c_int) -> int:
|
||||||
|
|
|
||||||
|
|
@ -153,13 +153,13 @@ def ggml_matmul_src1_x_src0_t(src0: torch.Tensor,
|
||||||
if src1.dtype != torch.float32:
|
if src1.dtype != torch.float32:
|
||||||
src1 = src1.float()
|
src1 = src1.float()
|
||||||
|
|
||||||
src0_ptr = src0.data_ptr() + (src0.storage_offset() * src0.element_size())
|
src0_ptr = src0.data_ptr()
|
||||||
src1_ptr = src1.data_ptr() + (src1.storage_offset() * src1.element_size())
|
src1_ptr = src1.data_ptr()
|
||||||
|
|
||||||
result_shape = (src1.shape[0], src0_shape[0])
|
result_shape = (src1.shape[0], src0_shape[0])
|
||||||
|
|
||||||
result_t = torch.empty(result_shape, dtype=torch.float32)
|
result_t = torch.empty(result_shape, dtype=torch.float32)
|
||||||
result_ptr = result_t.data_ptr() + (result_t.storage_offset() * result_t.element_size())
|
result_ptr = result_t.data_ptr()
|
||||||
|
|
||||||
src0_shape = tuple(reversed(src0_shape))
|
src0_shape = tuple(reversed(src0_shape))
|
||||||
src1_shape = tuple(reversed(src1.shape))
|
src1_shape = tuple(reversed(src1.shape))
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue