Skip to content

Optional out parameter for numpy.dot #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions numpy/add_newdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@
add_newdoc('numpy.core', 'dot',
"""
dot(a, b)
dot(a, b, out)

Dot product of two arrays.

Expand All @@ -1209,13 +1210,21 @@
First argument.
b : array_like
Second argument.
out : ndarray, optional
Output argument. This must have the exact kind that would be returned
if it was not used. In particular, it must have the right type, must be
C-contiguous, and its dtype must be the dtype that would be returned
for `dot(a,b)`. This is a performance feature. Therefore, if these
conditions are not met, an exception is raised, instead of attempting
to be flexible.

Returns
-------
output : ndarray
Returns the dot product of `a` and `b`. If `a` and `b` are both
scalars or both 1-D arrays then a scalar is returned; otherwise
an array is returned.
If `out` is given, then it is returned.

Raises
------
Expand Down
52 changes: 41 additions & 11 deletions numpy/core/blasdot/_dotblas.c
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ _bad_strides(PyArrayObject *ap)
* NB: The first argument is not conjugated.;
*/
static PyObject *
dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwargs)
{
PyObject *op1, *op2;
PyArrayObject *ap1 = NULL, *ap2 = NULL, *ret = NULL;
PyArrayObject *ap1 = NULL, *ap2 = NULL, *out = NULL, *ret = NULL;
int j, l, lda, ldb, ldc;
int typenum, nd;
npy_intp ap1stride = 0;
Expand All @@ -230,8 +230,10 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
PyTypeObject *subtype;
PyArray_Descr *dtype;
MatrixShape ap1shape, ap2shape;
char* kwords[] = {"a", "b", "out", NULL };

if (!PyArg_ParseTuple(args, "OO", &op1, &op2)) {
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|O", kwords,
&op1, &op2, &out)) {
return NULL;
}

Expand All @@ -246,7 +248,10 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
/* This function doesn't handle other types */
if ((typenum != PyArray_DOUBLE && typenum != PyArray_CDOUBLE &&
typenum != PyArray_FLOAT && typenum != PyArray_CFLOAT)) {
return PyArray_Return((PyArrayObject *)PyArray_MatrixProduct(op1, op2));
return PyArray_Return((PyArrayObject *)PyArray_MatrixProduct3(
(PyObject *)op1,
(PyObject *)op2,
(PyObject *)out));
}

dtype = PyArray_DescrFromType(typenum);
Expand Down Expand Up @@ -279,8 +284,9 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
Py_DECREF(tmp1);
Py_DECREF(tmp2);
}
ret = (PyArrayObject *)PyArray_MatrixProduct((PyObject *)ap1,
(PyObject *)ap2);
ret = (PyArrayObject *)PyArray_MatrixProduct3((PyObject *)ap1,
(PyObject *)ap2,
(PyObject *)out);
Py_DECREF(ap1);
Py_DECREF(ap2);
return PyArray_Return(ret);
Expand Down Expand Up @@ -418,10 +424,34 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
subtype = Py_TYPE(ap1);
}

ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
typenum, NULL, NULL, 0, 0,
(PyObject *)
(prior2 > prior1 ? ap2 : ap1));
if (out) {
int d;
/* verify that out is usable */
if (Py_TYPE(out) != subtype ||
PyArray_NDIM(out) != nd ||
PyArray_TYPE(out) != typenum ||
!PyArray_ISCARRAY(out)) {

PyErr_SetString(PyExc_ValueError,
"output array is not acceptable "
"(must have the right type, nr dimensions, and be a C-Array)");
goto fail;
}
for (d = 0; d != nd; ++d) {
if (dimensions[d] != PyArray_DIM(out, d)) {
PyErr_SetString(PyExc_ValueError,
"output array has wrong dimensions");
goto fail;
}
}
Py_INCREF(out);
ret = out;
} else {
ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
typenum, NULL, NULL, 0, 0,
(PyObject *)
(prior2 > prior1 ? ap2 : ap1));
}

if (ret == NULL) {
goto fail;
Expand Down Expand Up @@ -1167,7 +1197,7 @@ static PyObject *dotblas_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args) {
}

static struct PyMethodDef dotblas_module_methods[] = {
{"dot", (PyCFunction)dotblas_matrixproduct, 1, NULL},
{"dot", (PyCFunction)dotblas_matrixproduct, METH_VARARGS|METH_KEYWORDS, NULL},
{"inner", (PyCFunction)dotblas_innerproduct, 1, NULL},
{"vdot", (PyCFunction)dotblas_vdot, 1, NULL},
{"alterdot", (PyCFunction)dotblas_alterdot, 1, NULL},
Expand Down
1 change: 1 addition & 0 deletions numpy/core/code_generators/numpy_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@
'PyArray_TimedeltaToTimedeltaStruct': 218,
'PyArray_DatetimeStructToDatetime': 219,
'PyArray_TimedeltaStructToTimedelta': 220,
'PyArray_MatrixProduct3': 222,
}

ufunc_types_api = {
Expand Down
56 changes: 44 additions & 12 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ PyArray_CanCoerceScalar(int thistype, int neededtype,
* priority of ap1 and ap2 into account.
*/
static PyArrayObject *
new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2,
new_array_for_sum(PyArrayObject* out, PyArrayObject *ap1, PyArrayObject *ap2,
int nd, intp dimensions[], int typenum)
{
PyArrayObject *ret;
Expand All @@ -597,6 +597,28 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2,
prior1 = prior2 = 0.0;
subtype = Py_TYPE(ap1);
}
if (out) {
int d;
/* verify that out is usable */
if (Py_TYPE(out) != subtype ||
PyArray_NDIM(out) != nd ||
PyArray_TYPE(out) != typenum ||
!PyArray_ISCARRAY(out)) {
PyErr_SetString(PyExc_ValueError,
"output array is not acceptable "
"(must have the right type, nr dimensions, and be a C-Array)");
return 0;
}
for (d = 0; d != nd; ++d) {
if (dimensions[d] != PyArray_DIM(out, d)) {
PyErr_SetString(PyExc_ValueError,
"output array has wrong dimensions");
return 0;
}
}
Py_INCREF(out);
return out;
}

ret = (PyArrayObject *)PyArray_New(subtype, nd, dimensions,
typenum, NULL, NULL, 0, 0,
Expand Down Expand Up @@ -666,7 +688,7 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
* Need to choose an output array that can hold a sum
* -- use priority to determine which subtype.
*/
ret = new_array_for_sum(ap1, ap2, nd, dimensions, typenum);
ret = new_array_for_sum(NULL, ap1, ap2, nd, dimensions, typenum);
if (ret == NULL) {
goto fail;
}
Expand Down Expand Up @@ -713,13 +735,12 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
return NULL;
}


/*NUMPY_API
*Numeric.matrixproduct(a,v)
* Numeric.matrixproduct(a,v,out)
* just like inner product but does the swapaxes stuff on the fly
*/
NPY_NO_EXPORT PyObject *
PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
PyArray_MatrixProduct3(PyObject *op1, PyObject *op2, PyObject* out)
{
PyArrayObject *ap1, *ap2, *ret = NULL;
PyArrayIterObject *it1, *it2;
Expand Down Expand Up @@ -788,7 +809,7 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2)

is1 = ap1->strides[ap1->nd-1]; is2 = ap2->strides[matchDim];
/* Choose which subtype to return */
ret = new_array_for_sum(ap1, ap2, nd, dimensions, typenum);
ret = new_array_for_sum(out, ap1, ap2, nd, dimensions, typenum);
if (ret == NULL) {
goto fail;
}
Expand Down Expand Up @@ -845,6 +866,16 @@ PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
return NULL;
}

/*NUMPY_API
*Numeric.matrixproduct(a,v)
* just like inner product but does the swapaxes stuff on the fly
*/
NPY_NO_EXPORT PyObject *
PyArray_MatrixProduct(PyObject *op1, PyObject *op2)
{
return PyArray_MatrixProduct3(op1, op2, NULL);
}

/*NUMPY_API
* Fast Copy and Transpose
*/
Expand Down Expand Up @@ -968,7 +999,7 @@ _pyarray_correlate(PyArrayObject *ap1, PyArrayObject *ap2, int typenum,
* Need to choose an output array that can hold a sum
* -- use priority to determine which subtype.
*/
ret = new_array_for_sum(ap1, ap2, 1, &length, typenum);
ret = new_array_for_sum(NULL, ap1, ap2, 1, &length, typenum);
if (ret == NULL) {
return NULL;
}
Expand Down Expand Up @@ -1850,14 +1881,15 @@ array_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
}

static PyObject *
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds)
{
PyObject *v, *a;
PyObject *v, *a, *o = NULL;
char* kwlist[] = {"a", "b", "out", NULL };

if (!PyArg_ParseTuple(args, "OO", &a, &v)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist, &a, &v, &o)) {
return NULL;
}
return _ARET(PyArray_MatrixProduct(a, v));
return _ARET(PyArray_MatrixProduct3(a, v, o));
}

static PyObject *
Expand Down Expand Up @@ -2766,7 +2798,7 @@ static struct PyMethodDef array_module_methods[] = {
METH_VARARGS, NULL},
{"dot",
(PyCFunction)array_matrixproduct,
METH_VARARGS, NULL},
METH_VARARGS | METH_KEYWORDS, NULL},
{"_fastCopyAndTranspose",
(PyCFunction)array_fastCopyAndTranspose,
METH_VARARGS, NULL},
Expand Down
65 changes: 64 additions & 1 deletion numpy/core/tests/test_blasdot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import sys
from numpy.core import zeros, float64
from numpy.testing import dec, TestCase, assert_almost_equal, assert_
from numpy.testing import dec, TestCase, assert_almost_equal, assert_, \
assert_raises, assert_array_equal, assert_allclose, assert_equal
from numpy.core.multiarray import inner as inner_

DECPREC = 14
Expand All @@ -26,3 +29,63 @@ def test_blasdot_used():
assert_(inner is _dotblas.inner)
assert_(alterdot is _dotblas.alterdot)
assert_(restoredot is _dotblas.restoredot)


def test_dot_2args():
from numpy.core import dot

a = np.array([[1, 2], [3, 4]], dtype=float)
b = np.array([[1, 0], [1, 1]], dtype=float)
c = np.array([[3, 2], [7, 4]], dtype=float)

d = dot(a, b)
assert_allclose(c, d)

def test_dot_3args():
np.random.seed(22)
f = np.random.random_sample((1024, 16))
v = np.random.random_sample((16, 32))

r = np.empty((1024, 32))
for i in xrange(12):
np.dot(f,v,r)
assert_equal(sys.getrefcount(r), 2)
r2 = np.dot(f,v)
assert_array_equal(r2, r)
assert_(r is np.dot(f,v,r))

v = v[:,0].copy() # v.shape == (16,)
r = r[:,0].copy() # r.shape == (1024,)
r2 = np.dot(f,v)
assert_(r is np.dot(f,v,r))
assert_array_equal(r2, r)

def test_dot_3args_errors():
np.random.seed(22)
f = np.random.random_sample((1024, 16))
v = np.random.random_sample((16, 32))

r = np.empty((1024, 31))
assert_raises(ValueError, np.dot, f, v, r)

r = np.empty((1024,))
assert_raises(ValueError, np.dot, f, v, r)

r = np.empty((32,))
assert_raises(ValueError, np.dot, f, v, r)

r = np.empty((32, 1024))
assert_raises(ValueError, np.dot, f, v, r)
assert_raises(ValueError, np.dot, f, v, r.T)

r = np.empty((1024, 64))
assert_raises(ValueError, np.dot, f, v, r[:,::2])
assert_raises(ValueError, np.dot, f, v, r[:,:32])

r = np.empty((1024, 32), dtype=np.float32)
assert_raises(ValueError, np.dot, f, v, r)

r = np.empty((1024, 32), dtype=int)
assert_raises(ValueError, np.dot, f, v, r)


Loading