Skip to content

[WIP] PERF Optimize weighted percentile (from n log n to n)#32288

Closed
cakedev0 wants to merge 23 commits into
scikit-learn:mainfrom
cakedev0:optim-weighted-percentile
Closed

[WIP] PERF Optimize weighted percentile (from n log n to n)#32288
cakedev0 wants to merge 23 commits into
scikit-learn:mainfrom
cakedev0:optim-weighted-percentile

Conversation

@cakedev0

@cakedev0 cakedev0 commented Sep 28, 2025

Copy link
Copy Markdown
Contributor

WIP

Reference Issues/PRs

Follow-up from the proof section of this PR: #32285

What does this implement/fix? Explain your changes.

The algorithm implemented here is basically the one described here: Find a weighted median for unsorted array in linear time adapted to handle several quantiles in the same recursive call.

Complexity is O(n) for one quantile, and approaches O(n log n) for many quantiles (in practice, 10 seems to be many already).

Benchmarks

In numpy:
For one quantile, it's ~3x faster than the unstable-sort version of the current code.
The current code doesn't handle multiple quantiles, so you have to loop on each quantile, which is much slower than my function that compute the 10 deciles in less time than computing one quantile with the current code.

@github-actions

github-actions Bot commented Sep 28, 2025

Copy link
Copy Markdown

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


cython-lint

cython-lint detected issues. Please fix them locally and push the changes. Here you can see the detected issues. Note that the installed cython-lint version is cython-lint=0.18.0.

Details

/home/runner/work/scikit-learn/scikit-learn/sklearn/cluster/_hdbscan/_tree.pyx:786:19: unnecessary set + generator (just use a set comprehension)

Generated for commit: d59abf5. Link to the linter CI: here

@betatim

betatim commented Sep 30, 2025

Copy link
Copy Markdown
Member

Even in draft mode, could you fill in a few details in the top comment already? In particular referencing the relevant other PRs/issues. I think it helps keep track of things. No full explanation, proof, etc needed.

@ogrisel ogrisel added the CUDA CI label Oct 2, 2025
@github-actions github-actions Bot removed the CUDA CI label Oct 2, 2025
@ogrisel

ogrisel commented Oct 2, 2025

Copy link
Copy Markdown
Member

@cakedev0 if you want to run the tests on CUDA interactively, you can use https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c

Comment thread sklearn/utils/stats.py Outdated
x = x[mask_nz]
# Recursively compute weighted percentiles using partitioning
w_sorted = False
if not hasattr(xp, "argpartition"):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not hasattr(xp, "argpartition"):
# XXX: update this once argpartition or equivalent is officially part of the
# array API spec:
# https://github.com/data-apis/array-api/issues/629
if not hasattr(xp, "argpartition"):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

About that: see my PR data-apis/array-api-extra#449

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I surely should read carefully all the discussions in the issue you've linked and the related PRs, but that sounds a bit daunting for today 😅)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once array-api-extra 0.9.1 is merged with data-apis/array-api-extra#449 we need to revendor it in scikit-learn and update this draft PR to leverage it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth waiting for that before making this PR ready for review?

If not, I'll finish polishing it this week (TODO: write tests, publish some benchmarks, fix CUDA).

@ogrisel ogrisel Oct 6, 2025

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the CUDA problem can be fixed independently (see below), but we sure need benchmarks, both with numpy and torch CPU and with torch on CUDA. Ideally both for the current state of this and the future code path with the xpx.argpartition.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally both for the current state of this and the future code path with the xpx.argpartition

I guess I'll just wait for this future to be the present before doing the benchmarks then. Would that be ok? This mean not touching this PR before array-api-extra 0.9.1 is out and revendored.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe it's still better to move forward here now? As this PR is kinda blocking data-apis/array-api-extra#340 (comment) and the equivalent one in scipy.

I'll let you decide, both options work for me.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's benchmark the current state. If this is good, we can proceed with the review now that the tests are green.

Comment thread sklearn/utils/stats.py Outdated
Comment thread sklearn/utils/stats.py Outdated
@ogrisel

ogrisel commented Oct 9, 2025

Copy link
Copy Markdown
Member

Benchmarks

For one quantile, it's ~3x faster than the unstable-sort version of the current code.

It would be great to do benchmarks for 1, 5, 10 quantiles for int(1e5), int(1e6), int(1e7) data points with uniform and heavy tailed data and weights, both on numpy CPU and torch CUDA.

@cakedev0

Copy link
Copy Markdown
Contributor Author

I made some benchmarks with varying shapes (n, d) for the input array and with nq quantiles (1, 3, 9).

For d=1 and nq=1, the gain is clear for numpy (3x) and the loss is very limited for torch.
For nq > 1, the gain is clear in most cases, but significant gains could also be reached by just looping over quantiles inside the function (once the sort is done), and not outside.
For d > 1, and esp. d >> 1 (like 100), it's not great. I think it's because of the the over-head of looping over the dimensions, and making a lot of calls xp.some_func(...).

Conclusion: maybe let's just go with the simple loop over quantiles inside the current implementation? This would be a clear and easy gain for some of the current use-cases in sklearn.

The algorithm I propose here might have its place to be implemented in Cython/C/C++ somewhere in numpy and/or scipy. Such an implementation would make the gain significant for any d and nq= 1 (or nq=2, 3 too). nq=1 being a common use-case, that might be interesting.
I'm not familiar with those ecosystems so any advice/insight will be appreciated!

@cakedev0

Copy link
Copy Markdown
Contributor Author

Note: I found a way to mitigate the perf loss with d>>1 (I just pushed it), I'll give proper detailed benchmark results if you think it's still worth pursuing.

@cakedev0

Copy link
Copy Markdown
Contributor Author

I give up for now on the O(n) algorithm, so I'm closing this PR.

@ogrisel

ogrisel commented Oct 20, 2025

Copy link
Copy Markdown
Member

Conclusion: maybe let's just go with the simple loop over quantiles inside the current implementation? This would be a clear and easy gain for some of the current use-cases in sklearn.

I agree, this is a good first step with a net improvement. Once merged we, can always reexplore later to compare this new stronger baseline. But maybe it's not worth investing too much effort if this function is not reported as the computational bottleneck of any user actual workload.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants