Skip to content

Commit c867c2f

Browse files
committed
CLN always make raw_predictions.shape=(n_samples, n_trees_per_iteration)
1 parent 1985c94 commit c867c2f

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -413,16 +413,17 @@ def fit(self, X, y, sample_weight=None):
413413

414414
# initialize raw_predictions: those are the accumulated values
415415
# predicted by the trees for the training data. raw_predictions has
416-
# shape (n_trees_per_iteration, n_samples) where
416+
# shape (n_samples, n_trees_per_iteration) where
417417
# n_trees_per_iteration is n_classes in multiclass classification,
418418
# else 1.
419-
# self._baseline_prediction has shape (n_trees_per_iteration, 1)
419+
# self._baseline_prediction has shape (1, n_trees_per_iteration)
420420
self._baseline_prediction = self._loss.fit_intercept_only(
421421
y_true=y_train, sample_weight=sample_weight_train
422-
).reshape((-1, 1))
422+
).reshape((1, -1))
423423
raw_predictions = np.zeros(
424-
shape=(self.n_trees_per_iteration_, n_samples),
424+
shape=(n_samples, self.n_trees_per_iteration_),
425425
dtype=self._baseline_prediction.dtype,
426+
order="F",
426427
)
427428
raw_predictions += self._baseline_prediction
428429

@@ -452,8 +453,9 @@ def fit(self, X, y, sample_weight=None):
452453

453454
if self._use_validation_data:
454455
raw_predictions_val = np.zeros(
455-
shape=(self.n_trees_per_iteration_, X_binned_val.shape[0]),
456+
shape=(X_binned_val.shape[0], self.n_trees_per_iteration_),
456457
dtype=self._baseline_prediction.dtype,
458+
order="F",
457459
)
458460

459461
raw_predictions_val += self._baseline_prediction
@@ -553,15 +555,15 @@ def fit(self, X, y, sample_weight=None):
553555
if self._loss.constant_hessian:
554556
self._loss.gradient(
555557
y_true=y_train,
556-
raw_prediction=raw_predictions.T,
558+
raw_prediction=raw_predictions,
557559
sample_weight=sample_weight_train,
558560
gradient_out=gradient,
559561
n_threads=n_threads,
560562
)
561563
else:
562564
self._loss.gradient_hessian(
563565
y_true=y_train,
564-
raw_prediction=raw_predictions.T,
566+
raw_prediction=raw_predictions,
565567
sample_weight=sample_weight_train,
566568
gradient_out=gradient,
567569
hessian_out=hessian,
@@ -609,7 +611,7 @@ def fit(self, X, y, sample_weight=None):
609611
loss=self._loss,
610612
grower=grower,
611613
y_true=y_train,
612-
raw_prediction=raw_predictions[k, :],
614+
raw_prediction=raw_predictions[:, k],
613615
sample_weight=sample_weight_train,
614616
)
615617

@@ -621,7 +623,7 @@ def fit(self, X, y, sample_weight=None):
621623
# Update raw_predictions with the predictions of the newly
622624
# created tree.
623625
tic_pred = time()
624-
_update_raw_predictions(raw_predictions[k, :], grower, n_threads)
626+
_update_raw_predictions(raw_predictions[:, k], grower, n_threads)
625627
toc_pred = time()
626628
acc_prediction_time += toc_pred - tic_pred
627629

@@ -631,7 +633,7 @@ def fit(self, X, y, sample_weight=None):
631633
# Update raw_predictions_val with the newest tree(s)
632634
if self._use_validation_data:
633635
for k, pred in enumerate(self._predictors[-1]):
634-
raw_predictions_val[k, :] += pred.predict_binned(
636+
raw_predictions_val[:, k] += pred.predict_binned(
635637
X_binned_val,
636638
self._bin_mapper.missing_values_bin_idx_,
637639
n_threads,
@@ -804,7 +806,7 @@ def _check_early_stopping_loss(
804806
self.train_score_.append(
805807
-self._loss(
806808
y_true=y_train,
807-
raw_prediction=raw_predictions.T,
809+
raw_prediction=raw_predictions,
808810
sample_weight=sample_weight_train,
809811
n_threads=n_threads,
810812
)
@@ -814,7 +816,7 @@ def _check_early_stopping_loss(
814816
self.validation_score_.append(
815817
-self._loss(
816818
y_true=y_val,
817-
raw_prediction=raw_predictions_val.T,
819+
raw_prediction=raw_predictions_val,
818820
sample_weight=sample_weight_val,
819821
n_threads=n_threads,
820822
)
@@ -928,7 +930,7 @@ def _raw_predict(self, X, n_threads=None):
928930
929931
Returns
930932
-------
931-
raw_predictions : array, shape (n_trees_per_iteration, n_samples)
933+
raw_predictions : array, shape (n_samples, n_trees_per_iteration)
932934
The raw predicted values.
933935
"""
934936
is_binned = getattr(self, "_in_fit", False)
@@ -942,8 +944,9 @@ def _raw_predict(self, X, n_threads=None):
942944
)
943945
n_samples = X.shape[0]
944946
raw_predictions = np.zeros(
945-
shape=(self.n_trees_per_iteration_, n_samples),
947+
shape=(n_samples, self.n_trees_per_iteration_),
946948
dtype=self._baseline_prediction.dtype,
949+
order="F",
947950
)
948951
raw_predictions += self._baseline_prediction
949952

@@ -979,7 +982,7 @@ def _predict_iterations(self, X, predictors, raw_predictions, is_binned, n_threa
979982
f_idx_map=f_idx_map,
980983
n_threads=n_threads,
981984
)
982-
raw_predictions[k, :] += predict(X)
985+
raw_predictions[:, k] += predict(X)
983986

984987
def _staged_raw_predict(self, X):
985988
"""Compute raw predictions of ``X`` for each iteration.
@@ -995,7 +998,7 @@ def _staged_raw_predict(self, X):
995998
Yields
996999
-------
9971000
raw_predictions : generator of ndarray of shape \
998-
(n_trees_per_iteration, n_samples)
1001+
(n_samples, n_trees_per_iteration)
9991002
The raw predictions of the input samples. The order of the
10001003
classes corresponds to that in the attribute :term:`classes_`.
10011004
"""
@@ -1008,8 +1011,9 @@ def _staged_raw_predict(self, X):
10081011
)
10091012
n_samples = X.shape[0]
10101013
raw_predictions = np.zeros(
1011-
shape=(self.n_trees_per_iteration_, n_samples),
1014+
shape=(n_samples, self.n_trees_per_iteration_),
10121015
dtype=self._baseline_prediction.dtype,
1016+
order="F",
10131017
)
10141018
raw_predictions += self._baseline_prediction
10151019

@@ -1693,7 +1697,7 @@ def predict_proba(self, X):
16931697
The class probabilities of the input samples.
16941698
"""
16951699
raw_predictions = self._raw_predict(X)
1696-
return self._loss.predict_proba(raw_predictions.T)
1700+
return self._loss.predict_proba(raw_predictions)
16971701

16981702
def staged_predict_proba(self, X):
16991703
"""Predict class probabilities at each iteration.
@@ -1713,7 +1717,7 @@ def staged_predict_proba(self, X):
17131717
for each iteration.
17141718
"""
17151719
for raw_predictions in self._staged_raw_predict(X):
1716-
yield self._loss.predict_proba(raw_predictions.T)
1720+
yield self._loss.predict_proba(raw_predictions)
17171721

17181722
def decision_function(self, X):
17191723
"""Compute the decision function of ``X``.

0 commit comments

Comments
 (0)