From 2b037b91ee6e62eb9d71d3856b636f332e25c974 Mon Sep 17 00:00:00 2001 From: Edwin Solis Date: Sat, 7 Jun 2025 22:21:47 -0700 Subject: [PATCH] Added gemm matrix accumulation matrix into interface and tests --- .../lib/linear_algebra/blas_operations.py | 12 +++++++++--- tests/test_blas.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/arrayfire_wrapper/lib/linear_algebra/blas_operations.py b/arrayfire_wrapper/lib/linear_algebra/blas_operations.py index 66e43f8..4ba5a66 100644 --- a/arrayfire_wrapper/lib/linear_algebra/blas_operations.py +++ b/arrayfire_wrapper/lib/linear_algebra/blas_operations.py @@ -5,7 +5,7 @@ from arrayfire_wrapper.dtypes import c_api_value_to_dtype, complex32, complex64, float32, float64 from arrayfire_wrapper.lib._constants import MatProp from arrayfire_wrapper.lib._utility import call_from_clib -from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_type +from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_type, copy_array def dot(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, /) -> AFArray: @@ -29,11 +29,17 @@ def dot_all(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, /) return real.value if imag.value == 0 else real.value + imag.value * 1j -def gemm(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, alpha: Any, beta: Any, /) -> AFArray: +def gemm(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, alpha: Any, beta: Any, accum: AFArray | None, /) -> AFArray: """ source: https://arrayfire.org/docs/group__blas__func__matmul.htm#ga0463ae584163128718237b02faf5caf7 """ - out = AFArray.create_null_pointer() + out = None + if not accum is None: + out = copy_array(accum) + else: + beta = 0.0 + out = AFArray.create_null_pointer() + lhs_dtype = c_api_value_to_dtype(get_type(lhs)) type_mapping = { diff --git a/tests/test_blas.py b/tests/test_blas.py index 615dea4..3ad9c22 100644 --- a/tests/test_blas.py +++ b/tests/test_blas.py @@ -281,7 +281,7 @@ def test_gemm_correct_shape_2d(shape_pairs: list) -> None: y = wrapper.randu(shape_pairs[1], dtype) result_shape = (shape_pairs[0][0], shape_pairs[1][1]) - result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1) + result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None) assert wrapper.get_dims(result)[0:2] == result_shape @@ -302,7 +302,7 @@ def test_gemm_correct_shape_3d(shape_pairs: list) -> None: y = wrapper.randu(shape_pairs[1], dtype) result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2]) - result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1) + result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None) assert wrapper.get_dims(result)[0:3] == result_shape @@ -322,7 +322,7 @@ def test_gemm_correct_shape_4d(shape_pairs: list) -> None: y = wrapper.randu(shape_pairs[1], dtype) result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2], shape_pairs[0][3]) - result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1) + result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None) assert wrapper.get_dims(result)[0:4] == result_shape @@ -339,7 +339,7 @@ def test_gemm_correct_dtype(dtype: dtypes.Dtype) -> None: x = wrapper.randu(shape, dtype) y = wrapper.randu(shape, dtype) - result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1) + result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None) assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype @@ -361,7 +361,7 @@ def test_gemm_invalid_pair(shape_pairs: list) -> None: x = wrapper.randu(shape_pairs[0], dtype) y = wrapper.randu(shape_pairs[1], dtype) - wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1) + wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None) def test_gemm_empty_shape() -> None: @@ -371,7 +371,7 @@ def test_gemm_empty_shape() -> None: dtype = dtypes.f32 x = wrapper.randu(empty_shape, dtype) - wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1) + wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1, None) @pytest.mark.parametrize( @@ -390,7 +390,7 @@ def test_gemm_invalid_dtype(dtype_index: int) -> None: x = wrapper.randu(shape, dtype) y = wrapper.randu(shape, dtype) - wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1) + wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None) def test_gemm_empty_matrix() -> None: @@ -400,7 +400,7 @@ def test_gemm_empty_matrix() -> None: dtype = dtypes.f32 x = wrapper.randu(empty_shape, dtype) - wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1) + wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1, None) # matmul tests