Skip to content

Added gemm matrix accumulation matrix into interface and tests #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions arrayfire_wrapper/lib/linear_algebra/blas_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from arrayfire_wrapper.dtypes import c_api_value_to_dtype, complex32, complex64, float32, float64
from arrayfire_wrapper.lib._constants import MatProp
from arrayfire_wrapper.lib._utility import call_from_clib
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_type
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_type, copy_array


def dot(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, /) -> AFArray:
Expand All @@ -29,11 +29,17 @@ def dot_all(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, /)
return real.value if imag.value == 0 else real.value + imag.value * 1j


def gemm(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, alpha: Any, beta: Any, /) -> AFArray:
def gemm(lhs: AFArray, rhs: AFArray, lhs_opts: MatProp, rhs_opts: MatProp, alpha: Any, beta: Any, accum: AFArray | None, /) -> AFArray:
"""
source: https://arrayfire.org/docs/group__blas__func__matmul.htm#ga0463ae584163128718237b02faf5caf7
"""
out = AFArray.create_null_pointer()
out = None
if not accum is None:
out = copy_array(accum)
else:
beta = 0.0
out = AFArray.create_null_pointer()

lhs_dtype = c_api_value_to_dtype(get_type(lhs))

type_mapping = {
Expand Down
16 changes: 8 additions & 8 deletions tests/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_gemm_correct_shape_2d(shape_pairs: list) -> None:
y = wrapper.randu(shape_pairs[1], dtype)

result_shape = (shape_pairs[0][0], shape_pairs[1][1])
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)

assert wrapper.get_dims(result)[0:2] == result_shape

Expand All @@ -302,7 +302,7 @@ def test_gemm_correct_shape_3d(shape_pairs: list) -> None:
y = wrapper.randu(shape_pairs[1], dtype)
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2])

result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
assert wrapper.get_dims(result)[0:3] == result_shape


Expand All @@ -322,7 +322,7 @@ def test_gemm_correct_shape_4d(shape_pairs: list) -> None:
y = wrapper.randu(shape_pairs[1], dtype)
result_shape = (shape_pairs[0][0], shape_pairs[1][1], shape_pairs[0][2], shape_pairs[0][3])

result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)
assert wrapper.get_dims(result)[0:4] == result_shape


Expand All @@ -339,7 +339,7 @@ def test_gemm_correct_dtype(dtype: dtypes.Dtype) -> None:
x = wrapper.randu(shape, dtype)
y = wrapper.randu(shape, dtype)

result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
result = wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)

assert dtypes.c_api_value_to_dtype(wrapper.get_type(result)) == dtype

Expand All @@ -361,7 +361,7 @@ def test_gemm_invalid_pair(shape_pairs: list) -> None:
x = wrapper.randu(shape_pairs[0], dtype)
y = wrapper.randu(shape_pairs[1], dtype)

wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)


def test_gemm_empty_shape() -> None:
Expand All @@ -371,7 +371,7 @@ def test_gemm_empty_shape() -> None:
dtype = dtypes.f32

x = wrapper.randu(empty_shape, dtype)
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1)
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1, None)


@pytest.mark.parametrize(
Expand All @@ -390,7 +390,7 @@ def test_gemm_invalid_dtype(dtype_index: int) -> None:
x = wrapper.randu(shape, dtype)
y = wrapper.randu(shape, dtype)

wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1)
wrapper.gemm(x, y, MatProp.NONE, MatProp.NONE, 1, 1, None)


def test_gemm_empty_matrix() -> None:
Expand All @@ -400,7 +400,7 @@ def test_gemm_empty_matrix() -> None:
dtype = dtypes.f32

x = wrapper.randu(empty_shape, dtype)
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1)
wrapper.gemm(x, x, MatProp.NONE, MatProp.NONE, 1, 1, None)


# matmul tests
Expand Down
Loading