[WIP] ENH Tree Splitter: 50% performance improvement with radix sort and feature ranks #24239

Closed · wants to merge 6 commits

Conversation

@pedroilidio (Author) commented Aug 23, 2022

Since sorting subsets of X's columns is performed multiple times while searching for the best split, we can achieve around a 50% reduction in split-search duration by computing the ranks of each feature once and using them for subsequent sorts instead of the actual feature values in Xf: because the ranks are integers, integer sorting algorithms such as radix sort can run in linear time.
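
For concreteness, here is a minimal NumPy/Python sketch of the idea. It is only illustrative: the PR implements this in Cython inside the splitter and uses a radix sort, whereas the sketch below uses a single-pass counting sort, and precompute_ranks / sort_samples_by_feature are names made up for this example.

import numpy as np
from scipy.stats import rankdata


def precompute_ranks(X):
    # 0-based ordinal rank of each value within its own column (ties broken by order).
    return rankdata(X, method="ordinal", axis=0).astype(np.intp) - 1


def sort_samples_by_feature(samples, f, X_ranks):
    # Counting sort of a node's samples by their precomputed rank for feature f:
    # linear time, no float comparisons.
    keys = X_ranks[samples, f]                     # integer keys in [0, n_samples)
    counts = np.zeros(X_ranks.shape[0] + 1, dtype=np.intp)
    for k in keys:                                 # histogram of keys
        counts[k + 1] += 1
    start = np.cumsum(counts)                      # first free slot of each bucket
    out = np.empty_like(samples)
    for s, k in zip(samples, keys):                # stable scatter into buckets
        out[start[k]] = s
        start[k] += 1
    return out


X = np.random.rand(8, 3).astype(np.float32)
X_ranks = precompute_ranks(X)
node_samples = np.array([5, 2, 7, 0])
ordered = sort_samples_by_feature(node_samples, 1, X_ranks)
assert np.all(np.diff(X[ordered, 1]) >= 0)         # node samples ordered by feature 1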

Benchmarks (script from #22868)

Test durations in seconds comparing this PR to sklearn 1.1:

[figure: table of test durations, this PR vs scikit-learn 1.1]

Pending discussion topics

  • The extra memory needed to store the feature ranks may break existing code that uses large datasets. Should this be an optional feature, then? A FasterBestSplitter class could be created, letting users opt into the faster version through the splitter parameter of decision-tree-based estimators. I haven't found a good way to implement this without copy-pasting the whole BestSplitter.node_split(), though. A BestSplitter._sort() method could perhaps be added, but wrapping the two sorting procedures (radix sort and the current introsort) in the same method seems a little complicated;
  • We should be able to pass a precomputed X_ranks array to the Splitter to avoid redundant calculation in ensembles or cross-validation;
  • In this implementation, I initially used numpy.argsort twice to compute the ranks. Is there a better option? NumPy already uses radix sort for integer types. Edit: I found a way of sorting only once, based on SciPy's rankdata. I don't see much room for improvement, but ideas are always appreciated;
  • In the first tree node, sorting the ranks simply yields the sorted indices (the output of np.argsort). Since those sorted indices are naturally obtained when determining the ranks, could we reuse them in the root node? I imagine this would require storing not only the ranks but also the sorted indices, which adds a lot of memory load for very small gains. An alternative could be to calculate the ranks in the root node itself, but I don't see any intuitive way of doing so;
  • Two auxiliary arrays were created, but some algorithms, such as American flag sort, claim to perform in-place sorting with similar time complexity. I did not explore those;
  • A sparse splitter version should also be implemented; delegated to a future PR.

The branch passes all tree/tests/test_tree.py tests.

Downsides: we currently use auxiliary vectors declared in BestSplitter, and we have to hold two whole matrices, X_ranks and X_order. Further strategies to reduce memory usage are welcome.

this branch:
n_samples=1000 with 0.066 +/- 0.017
n_samples=5000 with 0.491 +/- 0.124
n_samples=10000 with 1.276 +/- 0.048
n_samples=20000 with 3.520 +/- 0.416
n_samples=50000 with 12.558 +/- 2.260

sklearn/main:
n_samples=1000 with 0.100 +/- 0.015
n_samples=5000 with 0.972 +/- 0.279
n_samples=10000 with 2.492 +/- 0.188
n_samples=20000 with 7.367 +/- 1.070
n_samples=50000 with 26.034 +/- 3.585
@pedroilidio (Author) commented Aug 24, 2022

Regarding item 3, scipy.stats.rankdata consistently outperforms applying numpy.argsort twice, so we'd better switch to it. At this point, I think we can drop item 4's idea to avoid complicating the code.

I'd be glad to hear opinions and ideas from anyone out there and discuss the feasibility of this PR.

Edit: It turns out that rankdata uses np.argsort under the hood, but only once, in a more clever way.
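
For reference, a quick sanity check (not part of the PR) that the two ranking routes agree, assuming tie-free data such as continuous random values:

import numpy as np
from scipy.stats import rankdata

rng = np.random.default_rng(0)
a = rng.random((1000, 50))

ranks_np = np.argsort(np.argsort(a, axis=0), axis=0)    # two sorts per column
ranks_sp = rankdata(a, method='ordinal', axis=0) - 1    # one sort per column

assert np.array_equal(ranks_np, ranks_sp)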

[figure: timing comparison, numpy.argsort (applied twice) vs scipy.stats.rankdata]

Benchmark script
from argparse import ArgumentParser
from timeit import repeat
import numpy as np
from scipy.stats import rankdata
import matplotlib.pyplot as plt

argparser = ArgumentParser()
argparser.add_argument('--out', '-o', default='benchmark_results.tsv')
argparser.add_argument('--n-iter', default=8, type=int)
argparser.add_argument('--n-times', default=1, type=int)
argparser.add_argument('--min', default=10, type=int)
argparser.add_argument('--max', default=17, type=int)
argparser.add_argument('--cols', default=500, type=int)
argparser.add_argument('--repeat', '-r', default=20, type=int)
args = argparser.parse_args()

nn = np.logspace(args.min, args.max, args.n_iter, base=2, dtype=int)
outfile = open(args.out, 'w')
outfile.write('n\tnp_mean\tnp_std\tsp_mean\tsp_std\n')

np_means, np_stdevs, sp_means, sp_stdevs = [], [], [], []

for n in nn:
    print('*** n =', n)
    a = np.random.rand(n, args.cols)
    np_res = repeat('np.argsort(np.argsort(a, axis=0), axis=0)',
                    number=args.n_times, repeat=args.repeat, globals=globals())
    sp_res = repeat("rankdata(a, method='ordinal', axis=0)-1",
                    number=args.n_times, repeat=args.repeat, globals=globals())

    np_mean, np_stdev = np.mean(np_res), np.std(np_res)
    sp_mean, sp_stdev = np.mean(sp_res), np.std(sp_res)

    print(f'np.argsort:     {np_mean:.4f} ({np_stdev:.4f})')
    print(f'scipy.rankdata: {sp_mean:.4f} ({sp_stdev:.4f})')
    outfile.write(f'{n}\t{np_mean}\t{np_stdev}\t{sp_mean}\t{sp_stdev}\n')

    np_means.append(np_mean)
    np_stdevs.append(np_stdev)

    sp_means.append(sp_mean)
    sp_stdevs.append(sp_stdev)

outfile.close()

# nn, np_means, np_stdevs , sp_means, sp_stdevs = np.loadtxt(args.out, skiprows=1).T

plt.figure(figsize=(10, 3))

plt.errorbar(nn, np_means, fmt='-o', yerr=np_stdevs, label='numpy.argsort')
plt.errorbar(nn, sp_means, fmt='--o', yerr=sp_stdevs, label='scipy.stats.rankdata')

plt.xlabel('Number of samples')
plt.ylabel('Duration (s)')
plt.yscale('log')
plt.xscale('log')
plt.legend()
plt.grid()

plt.savefig(args.out+'.png')

@pedroilidio (Author)

Based on SciPy's implementation, I've rewritten the ranking step in the last commit to avoid sorting twice, reusing the sort function already defined in the _splitter module. A further small performance improvement appears, mainly on smaller datasets, as expected.

Scipy code:
https://github.com/scipy/scipy/blob/651a9b717deb68adde9416072c1e1d5aa14a58a1/scipy/stats/_stats_py.py#L9128-L9134

New implementation:
38ec9d3#diff-e2cca285e1e883ab1d427120dfa974c1ba83eb6e2f5d5f416bbd99717ca5f5fcR282-R292
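
Conceptually, the ranking step now boils down to sorting each column once and inverting the resulting permutation. A NumPy sketch of that idea (the PR does the equivalent in Cython with the _splitter module's sort; column_ranks is just an illustrative name):

import numpy as np


def column_ranks(x):
    order = np.argsort(x, kind='stable')   # single sort: indices that would sort x
    ranks = np.empty_like(order)
    ranks[order] = np.arange(len(x))       # invert the permutation: rank of each value
    return ranks


x = np.array([0.3, 0.1, 0.4, 0.1, 0.5])
print(column_ranks(x))                     # [2 0 3 1 4]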

New benchmarks

argsort is the previous version using numpy.argsort twice; manual is this new implementation.

[figure: timing comparison, argsort vs manual ranking implementation]

@betatim (Member) commented Sep 13, 2022

This looks like an interesting idea and not too much code/new complexity.

What would be the next steps to move this forward?

I think, as pointed out, the extra memory consumption is worth thinking about. Can we measure this for the benchmarks you performed to get a feeling for the size of the increase?

I think I'd defer the work for sparse X to a new PR, same with the optimisation to pass in precomputed ranks.

WDYT?

@pedroilidio (Author)

Thanks for your input, @betatim. I agree that precomputed ranks and a sparse version can wait for a future PR. I see that passing X_ranks, especially, would need deeper code changes and thus more careful discussion.

Additionally, the radix sort implementation seems good enough to me, but it still may benefit from someone more experienced checking it out. Maybe an in-place version to drop the auxiliary arrays?

But I think the only point really preventing us from moving on is the increased memory load. As suggested, I ran a similar benchmark (below) so we can get a better idea of the impact. The increase was lower than I expected, given the arrays' relative sizes and the fact that X's dtype is np.float32 while X_ranks' dtype is np.intp (which is int64 on 64-bit machines such as the one I used, AFAIU). The smaller error bars on the n=50000 plot also seemed a little odd to me, but they consistently appeared in subsequent runs.

A safe way forward is to rethink this PR as an optional change. I noticed that node_split() is very similar across BestSplitter and RandomSplitter, so we could try sharing more code between the two and a possible new FasterBestSplitter. Again, this seems like material for another PR, and the direct solution here would be to simply replicate BestSplitter and change the necessary bits.

Anyway, everything depends on how much of a problem the extra memory consumption would be, and I would appreciate hearing some more thoughts on that before going ahead and making a separate Splitter.

Please also let me know if there are any further analyses and considerations I can help with.

Memory usage comparison (MB)

[figure: memory usage comparison across n_samples values]

Memory benchmark script
import argparse
import json
from statistics import mean, stdev
from memory_profiler import memory_usage

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from collections import defaultdict


N_SAMPLES = [1_000, 5_000, 10_000, 20_000, 50_000]
N_REPEATS = 20

parser = argparse.ArgumentParser()
parser.add_argument("results")
args = parser.parse_args()

results = defaultdict(list)


def train_tree(X, y, random_state):
    tree = DecisionTreeClassifier(random_state=random_state)
    tree.fit(X, y)


for n_samples in N_SAMPLES:
    for n_repeat in range(N_REPEATS):
        X, y = make_classification(
            random_state=n_repeat, n_samples=n_samples, n_features=100
        )
        memusage = memory_usage((train_tree, (X, y, n_repeat), {}))
        results[n_samples].append(max(memusage))

    results_mean, results_stdev = mean(results[n_samples]), stdev(results[n_samples])
    print(f"n_samples={n_samples} with {results_mean:.3f} +/- {results_stdev:.3f}")

with open(args.results, "w") as out:
    json.dump(results, out)

@betatim (Member) commented Sep 19, 2022

When I run your script I get the following results for the current main:

n_samples=1000 with 143.772 +/- 4.745
n_samples=5000 with 197.711 +/- 14.662
n_samples=10000 with 276.789 +/- 20.704
n_samples=20000 with 331.693 +/- 40.535
n_samples=50000 with 336.234 +/- 24.382

and these numbers when I run it for this PR:

n_samples=1000 with 134.100 +/- 4.162
n_samples=5000 with 183.102 +/- 10.466
n_samples=10000 with 258.835 +/- 22.544
n_samples=20000 with 352.887 +/- 3.416
n_samples=50000 with 529.395 +/- 10.184

I had to make one change to the script, but I think it doesn't change the results:

import argparse
import json
from statistics import mean, stdev
from memory_profiler import memory_usage

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from collections import defaultdict


N_SAMPLES = [1_000, 5_000, 10_000, 20_000, 50_000]
N_REPEATS = 20

parser = argparse.ArgumentParser()
parser.add_argument("results")
args = parser.parse_args()

results = defaultdict(list)


def train_tree(X, y, random_state):
    tree = DecisionTreeClassifier(random_state=random_state)
    tree.fit(X, y)


if __name__ == "__main__":
    for n_samples in N_SAMPLES:
        for n_repeat in range(N_REPEATS):
            X, y = make_classification(
                random_state=n_repeat, n_samples=n_samples, n_features=100
            )
            memusage = memory_usage((train_tree, (X, y, n_repeat), {}))
            results[n_samples].append(max(memusage))

        results_mean, results_stdev = mean(results[n_samples]), stdev(results[n_samples])
        print(f"n_samples={n_samples} with {results_mean:.3f} +/- {results_stdev:.3f}")

    with open(args.results, "w") as out:
        json.dump(results, out)

I also ran the script using RandomForestClassifier instead of a single decision tree. These are the results on main:

n_samples=1000 with 137.230 +/- 2.922
n_samples=5000 with 157.052 +/- 9.558
n_samples=10000 with 206.872 +/- 12.596
n_samples=20000 with 196.811 +/- 38.642
n_samples=50000 with 318.940 +/- 36.224

and on this PR the script has been running on the first value of n_samples for a long time (10 or 15 minutes) now, so I've stopped it.

I think something to investigate is what happens if you use the decision trees in something like a random forest (with different values of n_jobs).

The overall question is: how much extra memory are you allowed to use for a given speed up? I don't know the answer to that. Maybe we need a core dev to help with that.

@pedroilidio (Author) commented Oct 5, 2022

I've modified the memory benchmarking script to use filprofiler in order to track allocations instead of resident memory. The results are more consistent, showing the expected 3-fold increase in memory usage (we additionally store an np.intp for each np.float32-typed feature value).

We should be able to get a 2-fold instead of 3-fold increase by simply using np.uint32 instead of np.intp. I will soon work on that commit.
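
A back-of-the-envelope check of those factors, assuming one rank is stored per feature value alongside the float32 X that the splitter already holds:

import numpy as np

x_bytes = np.dtype(np.float32).itemsize            # 4 bytes per feature value
print(1 + np.dtype(np.intp).itemsize / x_bytes)    # 3.0 -> 3-fold total (64-bit intp)
print(1 + np.dtype(np.uint32).itemsize / x_bytes)  # 2.0 -> 2-fold with uint32 ranks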

[figure: peak memory allocation comparison, this PR vs 1.1]

Instructions

Set up the PR and 1.1 environments (named env and env_stable in my case) and run the following to reproduce:

$ python run_benchmark_across_envs.py --script memory_benchmark_fil.py

Notice that additional command-line options are possible, for example:

$ python run_benchmark_across_envs.py --script memory_benchmark_fil.py --classifier sklearn.ensemble.RandomForestClassifier --n-samples 1000 5000 10000  --envs env env_stable --env-names PR 1.1
memory_benchmark_fil.py
import argparse
import json
from importlib import import_module
from statistics import mean, stdev
from filprofiler.api import profile

from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from collections import defaultdict

N_SAMPLES = [1_000, 5_000, 10_000, 20_000, 50_000]
N_REPEATS = 20
CLASSIFIER = "sklearn.tree.DecisionTreeClassifier"
OUTPATH = "benchmark_results.json"


def get_total_mem_peak(prof_path):
    with open(prof_path) as prof_file:
        return sum(int(line.rsplit(maxsplit=1)[1]) for line in prof_file if '(fit)' in line)


def classifier_memory_benchmark(
    classifier=DecisionTreeClassifier,
    n_repeats=N_REPEATS,
    n_samples=N_SAMPLES,
):
    results = defaultdict(list)

    for n_sample in n_samples:
        for n_repeat in range(n_repeats):
            print(f"Repeat {n_repeat}/{n_repeats}...", end="\r")
            X, y = make_classification(
                random_state=n_repeat, n_samples=n_sample, n_features=100
            )
            clf = classifier(random_state=n_repeat)

            outdir = f"fil-result/{n_sample}"
            profile(lambda: clf.fit(X, y, n_repeat), outdir)
            results[n_sample].append(get_total_mem_peak(outdir+"/peak-memory.prof"))

        results_mean, results_stdev = mean(results[n_sample]), stdev(results[n_sample])
        print(f"n_sample={n_sample} with {results_mean:.3f} +/- {results_stdev:.3f}")
    
    return results


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", "-o", default=OUTPATH)
    parser.add_argument("--n-repeats", "-r", type=int, default=N_REPEATS)
    parser.add_argument("--n-samples", "-s", nargs="+", type=int, default=N_SAMPLES)
    parser.add_argument("--classifier", "-c", default=CLASSIFIER)
    args = parser.parse_args()

    module, clf = args.classifier.rsplit(".", maxsplit=1)
    classifier = getattr(import_module(module), clf)

    print("Starting.")

    results = classifier_memory_benchmark(
        classifier=classifier,
        n_samples=args.n_samples,
        n_repeats=args.n_repeats,
    )

    with open(args.out, "w+") as out:
        json.dump(results, out)
    
    print(f"Results written to {args.out}")


if __name__ == "__main__":
    main()
run_benchmark_across_envs.py
import argparse
import json
from pathlib import Path 
from statistics import mean, stdev
from subprocess import run

import matplotlib.pyplot as plt

OUTPATH = "combined_benchmark_results.json"
INDIVIDUAL_OUTPATH = "benchmark_results.json"
ENVS = ["env_stable", "env"][::-1]
ENV_NAMES = ["1.1", "PR"][::-1]
BENCHMARK_SCRIPT = "memory_benchmark.py"
PLOT_TITLE = "Peak memory allocation (bytes)"


def plot_results(data, out=OUTPATH+".png", title=PLOT_TITLE,
                 width=None, height=3):

    n_plots = len(next(iter(data.values())))
    width = width or len(data)*n_plots
    plt.figure(figsize=(width, height))
    plt.suptitle(title)

    for j, (name, times) in enumerate(data.items()):
        for i, (size, values) in enumerate(times.items()):
            plt.subplot(1, n_plots, i+1)
            plt.title('n = '+size)
            plt.bar(j, mean(values), yerr=stdev(values))
            plt.xticks(range(len(data)), rotation=45,
                       labels=(k.split('_')[0] for k in data.keys()))

    plt.tight_layout()
    plt.savefig(out)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--out", "-o", default=OUTPATH)
    parser.add_argument("--tmpout", default=INDIVIDUAL_OUTPATH)
    parser.add_argument("--envs", nargs="+", default=ENVS)
    parser.add_argument("--env-names", nargs="+", default=ENV_NAMES)
    parser.add_argument("--script", default=BENCHMARK_SCRIPT)
    parser.add_argument("--no-fil", action="store_true")
    parser.add_argument("--title", default=PLOT_TITLE)
    args, unknown_args = parser.parse_known_args()

    assert len(args.envs) == len(args.env_names)

    results = {}
    for env, env_name in zip(args.envs, args.env_names):
        python_command = [str(Path(env)/"bin/python")]
        if not args.no_fil:
            python_command += ["-m", "filprofiler", "python"]

        command = [
            *python_command, args.script, "-o", args.tmpout, *unknown_args,
        ]

        print(f"Running '{' '.join(command)}'")
        run(command, check=True)

        with open(args.tmpout, "r") as tmpout:
            results[env_name] = json.load(tmpout)

        with open(args.out, "w") as out:
            json.dump(results, out)
    
    print(f"Combined results written to {args.out}")

    plot_results(results, args.out+".png", args.title)
    print(f"Plot saved to {args.out+'.png'}")


if __name__ == "__main__":
    main()

I also ran into the same issue with RandomForestClassifier (thanks for pointing that out, @betatim): the benchmarks could not complete in a reasonable time. I suppose at least part of the problem comes from each tree allocating its own X_ranks matrix, so that, with the default 100 trees, we end up allocating ~200 times more memory than X itself (!) (see the rough estimate below the command). If that's the case, the second discussion topic becomes a blocker, and deeper changes may be needed for this to work. However, running with as few as 10 samples still triggers the issue. It can be reproduced by running

$ python memory_benchmark.py --classifier sklearn.ensemble.RandomForestClassifier --n-samples 10
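
As a rough illustration of the "~200 times" figure above (my own numbers, using the benchmark's n_samples=1000 and n_features=100; whether all trees hold their X_ranks at the same time depends on n_jobs):

n_samples, n_features, n_trees = 1_000, 100, 100
x_mb = n_samples * n_features * 4 / 1e6       # shared float32 X: ~0.4 MB
ranks_mb = n_samples * n_features * 8 / 1e6   # np.intp X_ranks per tree: ~0.8 MB
print(x_mb, ranks_mb * n_trees)               # ~0.4 MB for X vs ~80 MB across 100 trees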

@lorentzenchr (Member)

I'm closing this, as the PR seems stalled and a heavy performance regression remains unresolved.
@pedroilidio Thanks for trying this out. In case you solve the open performance issues, feel free to re-open.
