Skip to content

Commit 97213b5

Browse files
syurkevi authored and 9prady9 committed
FEAT: ndim reduce by key
adds multidimensional tests of varying sizes and types to match one dimensional cases
1 parent 10013a2 commit 97213b5

26 files changed

+4314
-119
lines changed

include/af/algorithm.h

Lines changed: 308 additions & 1 deletion
Large diffs are not rendered by default.

src/api/c/reduce.cpp

Lines changed: 322 additions & 27 deletions
Large diffs are not rendered by default.

src/api/cpp/reduce.cpp

Lines changed: 88 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,24 @@ array sum(const array &in, const int dim, const double nanval) {
2626
return array(out);
2727
}
2828

29+
// Sum-by-key wrapper around af_sum_by_key.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. dim is passed through getFNSD() together with vals.dims()
// before it reaches the C API. Any non-success status is raised by AF_THROW.
void sumByKey(array &keys_out, array &vals_out, const array &keys,
              const array &vals, const int dim) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_sum_by_key(&okeys, &ovals, keys.get(), vals.get(),
                           getFNSD(dim, vals.dims())));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
37+
38+
// Sum-by-key wrapper around af_sum_nan_by_key; nanval is forwarded to the
// C API unchanged.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. Any non-success status is raised by AF_THROW.
//
// NOTE(review): dim is handed over exactly as given and is not routed through
// getFNSD() like the overload without nanval -- confirm that callers are not
// expected to pass a negative dim here.
void sumByKey(array &keys_out, array &vals_out, const array &keys,
              const array &vals, const int dim, const double nanval) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(
        af_sum_nan_by_key(&okeys, &ovals, keys.get(), vals.get(), dim, nanval));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
46+
2947
array product(const array &in, const int dim) {
3048
af_array out = 0;
3149
AF_THROW(af_product(&out, in.get(), getFNSD(dim, in.dims())));
@@ -38,6 +56,24 @@ array product(const array &in, const int dim, const double nanval) {
3856
return array(out);
3957
}
4058

59+
// Product-by-key wrapper around af_product_by_key.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. dim is passed through getFNSD() together with vals.dims()
// before it reaches the C API. Any non-success status is raised by AF_THROW.
void productByKey(array &keys_out, array &vals_out, const array &keys,
                  const array &vals, const int dim) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_product_by_key(&okeys, &ovals, keys.get(), vals.get(),
                               getFNSD(dim, vals.dims())));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
67+
68+
// Product-by-key wrapper around af_product_nan_by_key; nanval is forwarded to
// the C API unchanged.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. Any non-success status is raised by AF_THROW.
//
// NOTE(review): dim is handed over exactly as given and is not routed through
// getFNSD() like the overload without nanval -- confirm that callers are not
// expected to pass a negative dim here.
void productByKey(array &keys_out, array &vals_out, const array &keys,
                  const array &vals, const int dim, const double nanval) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_product_nan_by_key(&okeys, &ovals, keys.get(), vals.get(), dim,
                                   nanval));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
76+
4177
// Alternate name for product(): both arguments are forwarded unchanged.
array mul(const array &in, const int dim) {
    return product(in, dim);
}
4278

4379
array min(const array &in, const int dim) {
@@ -46,12 +82,30 @@ array min(const array &in, const int dim) {
4682
return array(out);
4783
}
4884

85+
// Min-by-key wrapper around af_min_by_key.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. dim is passed through getFNSD() together with vals.dims()
// before it reaches the C API. Any non-success status is raised by AF_THROW.
void minByKey(array &keys_out, array &vals_out, const array &keys,
              const array &vals, const int dim) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_min_by_key(&okeys, &ovals, keys.get(), vals.get(),
                           getFNSD(dim, vals.dims())));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
93+
4994
// Maximum along one dimension: resolves dim via getFNSD() against in.dims(),
// calls af_max, and wraps the returned handle in an array.
array max(const array &in, const int dim) {
    af_array result     = 0;
    const int reduceDim = getFNSD(dim, in.dims());
    AF_THROW(af_max(&result, in.get(), reduceDim));
    return array(result);
}
5499

100+
// Max-by-key wrapper around af_max_by_key.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. dim is passed through getFNSD() together with vals.dims()
// before it reaches the C API. Any non-success status is raised by AF_THROW.
void maxByKey(array &keys_out, array &vals_out, const array &keys,
              const array &vals, const int dim) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_max_by_key(&okeys, &ovals, keys.get(), vals.get(),
                           getFNSD(dim, vals.dims())));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
108+
55109
// 2.1 compatibility
56110
// Lower-case spelling: forwards to allTrue() with the same arguments.
array alltrue(const array &in, const int dim) {
    return allTrue(in, dim);
}
57111
array allTrue(const array &in, const int dim) {
@@ -60,6 +114,15 @@ array allTrue(const array &in, const int dim) {
60114
return array(out);
61115
}
62116

117+
// All-true-by-key wrapper around af_all_true_by_key.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. dim is passed through getFNSD() together with vals.dims()
// before it reaches the C API. Any non-success status is raised by AF_THROW.
void allTrueByKey(array &keys_out, array &vals_out, const array &keys,
                  const array &vals, const int dim) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_all_true_by_key(&okeys, &ovals, keys.get(), vals.get(),
                                getFNSD(dim, vals.dims())));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
125+
63126
// 2.1 compatibility
64127
// Lower-case spelling: forwards to anyTrue() with the same arguments.
array anytrue(const array &in, const int dim) {
    return anyTrue(in, dim);
}
65128
array anyTrue(const array &in, const int dim) {
@@ -68,12 +131,30 @@ array anyTrue(const array &in, const int dim) {
68131
return array(out);
69132
}
70133

134+
// Any-true-by-key wrapper around af_any_true_by_key.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. dim is passed through getFNSD() together with vals.dims()
// before it reaches the C API. Any non-success status is raised by AF_THROW.
void anyTrueByKey(array &keys_out, array &vals_out, const array &keys,
                  const array &vals, const int dim) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_any_true_by_key(&okeys, &ovals, keys.get(), vals.get(),
                                getFNSD(dim, vals.dims())));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
142+
71143
// Count along one dimension: resolves dim via getFNSD() against in.dims(),
// calls af_count, and wraps the returned handle in an array.
array count(const array &in, const int dim) {
    af_array result     = 0;
    const int reduceDim = getFNSD(dim, in.dims());
    AF_THROW(af_count(&result, in.get(), reduceDim));
    return array(result);
}
76148

149+
// Count-by-key wrapper around af_count_by_key.
//
// keys_out / vals_out are overwritten with arrays built from the handles the
// C API returns. dim is passed through getFNSD() together with vals.dims()
// before it reaches the C API. Any non-success status is raised by AF_THROW.
void countByKey(array &keys_out, array &vals_out, const array &keys,
                const array &vals, const int dim) {
    // Start from null handles (as the single-output wrappers in this file do
    // with `af_array out = 0;`) so they never hold indeterminate values.
    af_array okeys = 0;
    af_array ovals = 0;
    AF_THROW(af_count_by_key(&okeys, &ovals, keys.get(), vals.get(),
                             getFNSD(dim, vals.dims())));
    keys_out = array(okeys);
    vals_out = array(ovals);
}
157+
77158
void min(array &val, array &idx, const array &in, const int dim) {
78159
af_array out = 0;
79160
af_array loc = 0;
@@ -107,15 +188,15 @@ void max(array &val, array &idx, const array &in, const int dim) {
107188
INSTANTIATE_CPLX(fnC, fnCPP, af_cdouble, double)
108189

109190
#define INSTANTIATE_REAL(fnC, fnCPP, T) \
110-
template<> \
191+
template <> \
111192
AFAPI T fnCPP(const array &in) { \
112193
double rval, ival; \
113194
AF_THROW(af_##fnC##_all(&rval, &ival, in.get())); \
114195
return (T)(rval); \
115196
}
116197

117198
#define INSTANTIATE_CPLX(fnC, fnCPP, T, Tr) \
118-
template<> \
199+
template <> \
119200
AFAPI T fnCPP(const array &in) { \
120201
double rval, ival; \
121202
AF_THROW(af_##fnC##_all(&rval, &ival, in.get())); \
@@ -138,15 +219,15 @@ INSTANTIATE_REAL(any_true, anyTrue, bool);
138219
#undef INSTANTIATE_CPLX
139220

140221
#define INSTANTIATE_REAL(fnC, fnCPP, T) \
141-
template<> \
222+
template <> \
142223
AFAPI T fnCPP(const array &in, const double nanval) { \
143224
double rval, ival; \
144225
AF_THROW(af_##fnC##_all(&rval, &ival, in.get(), nanval)); \
145226
return (T)(rval); \
146227
}
147228

148229
#define INSTANTIATE_CPLX(fnC, fnCPP, T, Tr) \
149-
template<> \
230+
template <> \
150231
AFAPI T fnCPP(const array &in, const double nanval) { \
151232
double rval, ival; \
152233
AF_THROW(af_##fnC##_all(&rval, &ival, in.get(), nanval)); \
@@ -162,7 +243,7 @@ INSTANTIATE(product_nan, product)
162243
#undef INSTANTIATE
163244

164245
#define INSTANTIATE_COMPAT(fnCPP, fnCompat, T) \
165-
template<> \
246+
template <> \
166247
AFAPI T fnCompat(const array &in) { \
167248
return fnCPP<T>(in); \
168249
}
@@ -194,15 +275,15 @@ INSTANTIATE_COMPAT(anyTrue, anytrue, bool)
194275
#undef INSTANTIATE_COMPAT
195276

196277
#define INSTANTIATE_REAL(fn, T) \
197-
template<> \
278+
template <> \
198279
AFAPI void fn(T *val, unsigned *idx, const array &in) { \
199280
double rval, ival; \
200281
AF_THROW(af_i##fn##_all(&rval, &ival, idx, in.get())); \
201282
*val = (T)(rval); \
202283
}
203284

204285
#define INSTANTIATE_CPLX(fn, T, Tr) \
205-
template<> \
286+
template <> \
206287
AFAPI void fn(T *val, unsigned *idx, const array &in) { \
207288
double rval, ival; \
208289
AF_THROW(af_i##fn##_all(&rval, &ival, idx, in.get())); \

src/api/unified/algorithm.cpp

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,24 @@ ALGO_HAPI_DEF(af_diff2)
3030

3131
#undef ALGO_HAPI_DEF
3232

33+
// Generates one by-key reduction entry point named `bykey_fn`.
// Each generated function runs CHECK_ARRAYS on both inputs and then returns
// whatever CALL yields for the same argument list.
// NOTE(review): CALL presumably dispatches to the active backend's function
// of the same name -- confirm against its definition.
#define ALGO_HAPI_DEF_BYKEY(bykey_fn)                                  \
    af_err bykey_fn(af_array *keys_out, af_array *vals_out,            \
                    const af_array keys, const af_array vals,          \
                    const int dim) {                                   \
        CHECK_ARRAYS(keys);                                            \
        CHECK_ARRAYS(vals);                                            \
        return CALL(keys_out, vals_out, keys, vals, dim);              \
    }

// Reductions taking (keys_out, vals_out, keys, vals, dim).
ALGO_HAPI_DEF_BYKEY(af_sum_by_key)
ALGO_HAPI_DEF_BYKEY(af_product_by_key)
ALGO_HAPI_DEF_BYKEY(af_min_by_key)
ALGO_HAPI_DEF_BYKEY(af_max_by_key)
ALGO_HAPI_DEF_BYKEY(af_all_true_by_key)
ALGO_HAPI_DEF_BYKEY(af_any_true_by_key)
ALGO_HAPI_DEF_BYKEY(af_count_by_key)

#undef ALGO_HAPI_DEF_BYKEY
50+
3351
#define ALGO_HAPI_DEF(af_func_nan) \
3452
af_err af_func_nan(af_array *out, const af_array in, const int dim, \
3553
const double nanval) { \
@@ -42,6 +60,20 @@ ALGO_HAPI_DEF(af_product_nan)
4260

4361
#undef ALGO_HAPI_DEF
4462

63+
// Generates one by-key reduction entry point named `bykey_nan_fn` that also
// takes a `nanval` argument.
// Each generated function runs CHECK_ARRAYS on both inputs and then returns
// whatever CALL yields for the same argument list.
// NOTE(review): CALL presumably dispatches to the active backend's function
// of the same name -- confirm against its definition.
#define ALGO_HAPI_DEF_BYKEY(bykey_nan_fn)                              \
    af_err bykey_nan_fn(af_array *keys_out, af_array *vals_out,        \
                        const af_array keys, const af_array vals,      \
                        const int dim, const double nanval) {          \
        CHECK_ARRAYS(keys);                                            \
        CHECK_ARRAYS(vals);                                            \
        return CALL(keys_out, vals_out, keys, vals, dim, nanval);      \
    }

// Reductions taking (keys_out, vals_out, keys, vals, dim, nanval).
ALGO_HAPI_DEF_BYKEY(af_sum_nan_by_key)
ALGO_HAPI_DEF_BYKEY(af_product_nan_by_key)

#undef ALGO_HAPI_DEF_BYKEY
76+
4577
#define ALGO_HAPI_DEF(af_func_all) \
4678
af_err af_func_all(double *real, double *imag, const af_array in) { \
4779
CHECK_ARRAYS(in); \

src/backend/cpu/kernel/reduce.hpp

Lines changed: 108 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
namespace cpu {
1515
namespace kernel {
1616

17-
template<af_op_t op, typename Ti, typename To, int D>
17+
template <af_op_t op, typename Ti, typename To, int D>
1818
struct reduce_dim {
1919
void operator()(Param<To> out, const dim_t outOffset, CParam<Ti> in,
2020
const dim_t inOffset, const int dim, bool change_nan,
@@ -34,7 +34,7 @@ struct reduce_dim {
3434
}
3535
};
3636

37-
template<af_op_t op, typename Ti, typename To>
37+
template <af_op_t op, typename Ti, typename To>
3838
struct reduce_dim<op, Ti, To, 0> {
3939
Transform<Ti, To, op> transform;
4040
Binary<To, op> reduce;
@@ -44,20 +44,120 @@ struct reduce_dim<op, Ti, To, 0> {
4444
const af::dim4 istrides = in.strides();
4545
const af::dim4 idims = in.dims();
4646

47-
To* const outPtr = out.get() + outOffset;
48-
Ti const* const inPtr = in.get() + inOffset;
47+
To *const outPtr = out.get() + outOffset;
48+
Ti const *const inPtr = in.get() + inOffset;
4949
dim_t stride = istrides[dim];
5050

5151
To out_val = Binary<To, op>::init();
5252
for (dim_t i = 0; i < idims[dim]; i++) {
53-
To in_val = transform(inPtr[i * stride]);
53+
To in_val = transform(inPtr[i * stride]);
5454
if (change_nan) in_val = IS_NAN(in_val) ? nanval : in_val;
55-
out_val = reduce(in_val, out_val);
55+
out_val = reduce(in_val, out_val);
5656
}
5757

5858
*outPtr = out_val;
5959
}
6060
};
6161

62-
} // namespace kernel
63-
} // namespace cpu
62+
template <typename Tk>
63+
void n_reduced_keys(Param<Tk> okeys, CParam<Tk> keys, int *n_reduced) {
64+
const af::dim4 kstrides = keys.strides();
65+
const af::dim4 kdims = keys.dims();
66+
67+
Tk *const outKeysPtr = okeys.get();
68+
Tk const *const inKeysPtr = keys.get();
69+
70+
int nkeys = 0;
71+
Tk current_key = inKeysPtr[0];
72+
for (dim_t i = 0; i < kdims[0]; i++) {
73+
Tk keyval = inKeysPtr[i];
74+
75+
if (keyval != current_key) {
76+
outKeysPtr[nkeys] = current_key;
77+
current_key = keyval;
78+
++nkeys;
79+
}
80+
81+
if (i == (kdims[0] - 1)) { outKeysPtr[nkeys] = current_key; }
82+
}
83+
84+
*n_reduced = nkeys + 1;
85+
}
86+
87+
template <af_op_t op, typename Ti, typename Tk, typename To, int D>
88+
struct reduce_dim_by_key {
89+
void operator()(Param<To> ovals, const dim_t ovOffset, CParam<Tk> keys,
90+
CParam<Ti> vals, const dim_t vOffset, int *n_reduced,
91+
const int dim, bool change_nan, double nanval) {
92+
static const int D1 = D - 1;
93+
reduce_dim_by_key<op, Ti, Tk, To, D1> reduce_by_key_dim_next;
94+
95+
const af::dim4 ovstrides = ovals.strides();
96+
const af::dim4 vstrides = vals.strides();
97+
const af::dim4 vdims = ovals.dims();
98+
99+
if (D1 == dim) {
100+
reduce_by_key_dim_next(ovals, ovOffset, keys, vals, vOffset,
101+
n_reduced, dim, change_nan, nanval);
102+
} else {
103+
for (dim_t i = 0; i < vdims[D1]; i++) {
104+
reduce_by_key_dim_next(ovals, ovOffset + (i * ovstrides[D1]),
105+
keys, vals, vOffset + (i * vstrides[D1]),
106+
n_reduced, dim, change_nan, nanval);
107+
}
108+
}
109+
}
110+
};
111+
112+
template <af_op_t op, typename Ti, typename Tk, typename To>
113+
struct reduce_dim_by_key<op, Ti, Tk, To, 0> {
114+
Transform<Ti, To, op> transform;
115+
Binary<To, op> reduce;
116+
void operator()(Param<To> ovals, const dim_t ovOffset, CParam<Tk> keys,
117+
CParam<Ti> vals, const dim_t vOffset, int *n_reduced,
118+
const int dim, bool change_nan, double nanval) {
119+
const af::dim4 kstrides = keys.strides();
120+
const af::dim4 kdims = keys.dims();
121+
122+
const af::dim4 vstrides = vals.strides();
123+
const af::dim4 vdims = vals.dims();
124+
125+
const af::dim4 ovstrides = ovals.strides();
126+
const af::dim4 ovdims = ovals.dims();
127+
128+
Tk const *const inKeysPtr = keys.get();
129+
Ti const *const inValsPtr = vals.get();
130+
To *const outValsPtr = ovals.get();
131+
132+
int keyidx = 0;
133+
Tk current_key = inKeysPtr[0];
134+
To out_val = reduce.init();
135+
136+
dim_t istride = vstrides[dim];
137+
dim_t ostride = ovstrides[dim];
138+
139+
for (dim_t i = 0; i < vdims[dim]; i++) {
140+
dim_t off = vOffset;
141+
Tk keyval = inKeysPtr[i];
142+
143+
if (keyval == current_key) {
144+
To in_val = transform(inValsPtr[vOffset + (i * istride)]);
145+
if (change_nan) in_val = IS_NAN(in_val) ? nanval : in_val;
146+
out_val = reduce(in_val, out_val);
147+
148+
} else {
149+
outValsPtr[ovOffset + (keyidx * ostride)] = out_val;
150+
151+
current_key = keyval;
152+
out_val = transform(inValsPtr[vOffset + (i * istride)]);
153+
++keyidx;
154+
}
155+
156+
if (i == (vdims[dim] - 1)) {
157+
outValsPtr[ovOffset + (keyidx * ostride)] = out_val;
158+
}
159+
}
160+
}
161+
};
162+
}
163+
}

0 commit comments

Comments
 (0)