[WIP] PERF Optimize weighted percentile (from n log n to n)#32288
[WIP] PERF Optimize weighted percentile (from n log n to n)#32288cakedev0 wants to merge 23 commits into
Conversation
❌ Linting issuesThis PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling You can see the details of the linting issues under the
|
…nt have a nvidia-GPU
|
Even in draft mode, could you fill in a few details in the top comment already? In particular referencing the relevant other PRs/issues. I think it helps keep track of things. No full explanation, proof, etc needed. |
|
@cakedev0 if you want to run the tests on CUDA interactively, you can use https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c |
| x = x[mask_nz] | ||
| # Recursively compute weighted percentiles using partitioning | ||
| w_sorted = False | ||
| if not hasattr(xp, "argpartition"): |
There was a problem hiding this comment.
| if not hasattr(xp, "argpartition"): | |
| # XXX: update this once argpartition or equivalent is officially part of the | |
| # array API spec: | |
| # https://github.com/data-apis/array-api/issues/629 | |
| if not hasattr(xp, "argpartition"): |
There was a problem hiding this comment.
About that: see my PR data-apis/array-api-extra#449
There was a problem hiding this comment.
(I surely should read carefully all the discussions in the issue you've linked and the related PRs, but that sounds a bit daunting for today 😅)
There was a problem hiding this comment.
Once array-api-extra 0.9.1 is merged with data-apis/array-api-extra#449 we need to revendor it in scikit-learn and update this draft PR to leverage it.
There was a problem hiding this comment.
Is it worth waiting for that before making this PR ready for review?
If not, I'll finish polishing it this week (TODO: write tests, publish some benchmarks, fix CUDA).
There was a problem hiding this comment.
I think the CUDA problem can be fixed independently (see below), but we sure need benchmarks, both with numpy and torch CPU and with torch on CUDA. Ideally both for the current state of this and the future code path with the xpx.argpartition.
There was a problem hiding this comment.
Ideally both for the current state of this and the future code path with the xpx.argpartition
I guess I'll just wait for this future to be the present before doing the benchmarks then. Would that be ok? This mean not touching this PR before array-api-extra 0.9.1 is out and revendored.
There was a problem hiding this comment.
Or maybe it's still better to move forward here now? As this PR is kinda blocking data-apis/array-api-extra#340 (comment) and the equivalent one in scipy.
I'll let you decide, both options work for me.
There was a problem hiding this comment.
Let's benchmark the current state. If this is good, we can proceed with the review now that the tests are green.
It would be great to do benchmarks for 1, 5, 10 quantiles for |
…t-learn into optim-weighted-percentile
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
|
I made some benchmarks with varying shapes For d=1 and nq=1, the gain is clear for numpy (3x) and the loss is very limited for torch. Conclusion: maybe let's just go with the simple loop over quantiles inside the current implementation? This would be a clear and easy gain for some of the current use-cases in sklearn. The algorithm I propose here might have its place to be implemented in Cython/C/C++ somewhere in numpy and/or scipy. Such an implementation would make the gain significant for any d and nq= 1 (or nq=2, 3 too). nq=1 being a common use-case, that might be interesting. |
…t-learn into optim-weighted-percentile
|
Note: I found a way to mitigate the perf loss with d>>1 (I just pushed it), I'll give proper detailed benchmark results if you think it's still worth pursuing. |
|
I give up for now on the O(n) algorithm, so I'm closing this PR. |
I agree, this is a good first step with a net improvement. Once merged we, can always reexplore later to compare this new stronger baseline. But maybe it's not worth investing too much effort if this function is not reported as the computational bottleneck of any user actual workload. |
WIP
Reference Issues/PRs
Follow-up from the proof section of this PR: #32285
What does this implement/fix? Explain your changes.
The algorithm implemented here is basically the one described here: Find a weighted median for unsorted array in linear time adapted to handle several quantiles in the same recursive call.
Complexity is O(n) for one quantile, and approaches O(n log n) for many quantiles (in practice, 10 seems to be many already).
Benchmarks
In numpy:
For one quantile, it's ~3x faster than the unstable-sort version of the current code.
The current code doesn't handle multiple quantiles, so you have to loop on each quantile, which is much slower than my function that compute the 10 deciles in less time than computing one quantile with the current code.