Modified backend wrap() to accept output arg #2328
Please add tests.
src/backend/cpu/wrap.cpp
Outdated
in.eval();
const dim4 idims = in.dims();
const dim4 odims(ox, oy, idims[2], idims[3]);
out = createValueArray<T>(odims, scalar<T>(0));
I think you meant to omit these lines, correct? As written, wrap will write its results to the memory returned by createValueArray and not to the space already provided in the out argument. I see that you omitted this in the CUDA and OpenCL backends, so I guess it was just an accident that you left it here.
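The difference between rebinding an output argument and writing into it can be shown with a small standalone sketch. This is not ArrayFire code: `Buffer`, `makeZeros`, `fillRebinding`, and `fillInPlace` are hypothetical names, and the shared-pointer handle is only an assumed stand-in for how a backend array refers to its storage.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a backend array: a handle that shares storage.
struct Buffer {
    std::shared_ptr<std::vector<float>> data;
};

// Allocates fresh zero-filled storage (loosely analogous to createValueArray).
Buffer makeZeros(std::size_t n) {
    return Buffer{std::make_shared<std::vector<float>>(n, 0.0f)};
}

// Problematic pattern: `out` is rebound to new storage, so the buffer the
// caller originally supplied is never written.
void fillRebinding(Buffer &out, float value) {
    assert(out.data != nullptr);
    out = makeZeros(out.data->size());
    for (float &v : *out.data) v = value;
}

// Intended pattern: write into the storage the caller already provided.
void fillInPlace(Buffer &out, float value) {
    assert(out.data != nullptr);
    for (float &v : *out.data) v = value;
}
```

With two handles sharing one buffer, calling `fillRebinding` through one handle leaves the other handle's data untouched, while `fillInPlace` updates the shared storage that both handles see.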
1c43ebd to a7a179c
21583a5 to 559d59f
1d28ee5 to 0b88f98
src/api/c/wrap.cpp
Outdated
const dim_t sx, const dim_t sy,
const dim_t px, const dim_t py,
const bool is_column) {
if (out == 0) return AF_ERR_ARG;
I think you should do the same conversion you did in the approx_v2 functions here for the error messages.
src/api/c/wrap.cpp
Outdated
dim_t nx = (ox + 2 * px - wx) / sx + 1;
dim_t ny = (oy + 2 * py - wy) / sy + 1;
ARG_ASSERT(4, wx > 0);
These indices will be incorrect in the older versions of approx. These checks should really be in the API-level functions. Unfortunately, duplicating them everywhere would be error-prone. I am not sure how to fix this cleanly.
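The window-count expression from the diff, and the kind of argument check under discussion, can be sketched in isolation. The following is only an illustration, not ArrayFire code: `WrapCheck`, `checkAxis`, and `windowCount` are hypothetical names, `dim_t` is a local stand-in type, and the stride and padding checks are assumptions that go beyond the `wx > 0` check shown above. Returning which parameter failed, rather than a positional index, is one possible way to let each API-level function report its own argument number.

```cpp
#include <cassert>

using dim_t = long long;  // assumption: local stand-in for ArrayFire's dim_t

// Hypothetical result codes that name the offending parameter, not its position.
enum class WrapCheck { Ok, BadWindow, BadStride, BadPadding };

// Validates the per-axis parameters (window, stride, padding).
WrapCheck checkAxis(dim_t w, dim_t s, dim_t p) {
    if (w <= 0) return WrapCheck::BadWindow;
    if (s <= 0) return WrapCheck::BadStride;
    if (p < 0) return WrapCheck::BadPadding;
    return WrapCheck::Ok;
}

// Window positions along one axis: (o + 2 * p - w) / s + 1, as in the diff.
dim_t windowCount(dim_t o, dim_t w, dim_t s, dim_t p) {
    assert(checkAxis(w, s, p) == WrapCheck::Ok);
    return (o + 2 * p - w) / s + 1;
}
```

For example, an output extent of 8 with a window of 4, stride 2, and no padding gives (8 - 4) / 2 + 1 = 3 window positions along that axis.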
test/wrap.cpp
Outdated
AF_ERR_ARG);
}

class ArgDim {
How is this different from dim4?
test/wrap.cpp
Outdated
public:
WindowDims *wc_;
StrideDims *sc_;
PadDims *pc_;
Do these types really need to be pointers?
src/api/c/wrap.cpp
Outdated
af_array output;
if (allocate_out) { *out = createHandle(out_dims, in_type); }

DIM_ASSERT(1, getInfo(*out).dims() == out_dims);
DIM_ASSERT(1, getInfo(*out).dims() == out_dims);
// The out pointer can be passed in to the function by the user
DIM_ASSERT(1, getInfo(*out).dims() == out_dims);
src/backend/cpu/kernel/wrap.hpp
Outdated
@@ -25,6 +29,8 @@ void wrap_dim(Param<T> out, CParam<T> in, const dim_t wx, const dim_t wy,
af::dim4 istrides = in.strides();
af::dim4 ostrides = out.strides();

std::fill(begin(out), end(out), scalar<T>(0));
I don't like this. The out array should be zeroed before we get here. Instead of createHandle, you should use createValueArray.
If the out array is passed into the function, we can stipulate that the values will be summed into the output array. If callers want only the wrapped values, they need to pass a constant array or zero it beforehand.
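The "summed into the output array" contract can be illustrated with a one-dimensional analogue. This is a minimal sketch under stated assumptions, not the ArrayFire kernel: `wrap1dAccumulate` is a hypothetical name, padding is assumed to be zero, and the windows are assumed to be stored back to back in a flat vector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// 1D analogue of wrap with no padding: window i covers out[i*s .. i*s + w - 1].
// Values are accumulated into `out`; the function never zeroes it, so the
// caller decides what the starting contents are.
void wrap1dAccumulate(std::vector<float> &out,
                      const std::vector<float> &windows,
                      std::size_t w, std::size_t s) {
    assert(w > 0 && s > 0 && out.size() >= w);
    const std::size_t n = (out.size() - w) / s + 1;  // number of windows
    assert(windows.size() == n * w);
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < w; ++j) {
            out[i * s + j] += windows[i * w + j];
        }
    }
}
```

With an output of length 4, a window of 2, a stride of 1, and three windows of ones, a zeroed output ends up as {1, 2, 2, 1}, while an output pre-filled with 10 ends up as {11, 12, 12, 11}. Under this contract, zeroing is the caller's responsibility.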
24fc926 to 6833a3a
Minor cleanup.
9581ce8 to bf4c9e7
bf4c9e7 to 7d847a4
No description provided.