Skip to content

Commit 2e8d17a

Browse files
committed
another step
1 parent edc513d commit 2e8d17a

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

_unittests/ut_mlmodel/test_piecewise_decision_tree_experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_criterions(self):
3636
self.assertEqual(w.sum(), X.shape[0])
3737
ind = numpy.arange(y.shape[0]).astype(numpy.int64)
3838
ys = y.astype(float).reshape((y.shape[0], 1))
39+
ys = numpy.ascontiguousarray(ys, dtype=numpy.float64).copy()
3940
_test_criterion_init(c1, ys, w, 1.0, ind, 0, y.shape[0])
4041
_test_criterion_init(c2, ys, w, 1.0, ind, 0, y.shape[0])
4142
return

mlinsights/mlmodel/_piecewise_tree_regression_common.pyx

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -219,14 +219,26 @@ cdef class CommonRegressorCriterion(Criterion):
219219
- (self.weighted_n_left / weight * impurity_left)))
220220

221221

222-
cpdef _test_criterion_init(Criterion criterion,
223-
const DOUBLE_t[:, ::1] y,
224-
DOUBLE_t[:] sample_weight,
225-
double weighted_n_samples,
226-
SIZE_t[:] samples,
227-
SIZE_t start, SIZE_t end):
222+
cdef int _ctest_criterion_init(Criterion criterion,
223+
const DOUBLE_t[:, ::1] y,
224+
DOUBLE_t[:] sample_weight,
225+
double weighted_n_samples,
226+
SIZE_t[:] samples,
227+
SIZE_t start, SIZE_t end):
228228
"Test purposes. Methods cannot be directly called from python."
229-
if criterion.init(y, sample_weight, weighted_n_samples, samples, start, end) != 0:
229+
cdef const DOUBLE_t[:, ::1] y2 = y
230+
return criterion.init(y2, sample_weight, weighted_n_samples, samples, start, end)
231+
232+
233+
def _test_criterion_init(Criterion criterion,
234+
const DOUBLE_t[:, ::1] y,
235+
DOUBLE_t[:] sample_weight,
236+
double weighted_n_samples,
237+
SIZE_t[:] samples,
238+
SIZE_t start, SIZE_t end):
239+
"Test purposes. Methods cannot be directly called from python."
240+
if _ctest_criterion_init(criterion, y, sample_weight, weighted_n_samples,
241+
samples, start, end) != 0:
230242
raise AssertionError("Return is not 0.")
231243

232244

0 commit comments

Comments
 (0)