Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Add FISTA solver #91

Merged
merged 31 commits into from
Oct 22, 2022
Merged

ENH Add FISTA solver #91

merged 31 commits into from
Oct 22, 2022

Conversation

PABannier
Copy link
Collaborator

Closes #89

A few points to discuss:

  1. Currently the FISTA solver uses Gram updates (as per Gram-based CD/BCD/FISTA solvers for (group)Lasso when n_samples >> n_features #4 ). Question: do we want to keep it this way? Implement without Gram update? Or have two options?

  2. For non-coordinate-wise updates, we run into the issue of not having a prox_vec method in the BasePenalty class. If we want to support a larger class of penalties for FISTA (e.g.: L1, WeightedL1, SLOPE, ...), we need a prox_vec method.

@mathurinm
Copy link
Collaborator

  • Remove Gram to handle the generic case (+ gram is only suited to quadratics)
  • To keep the API simple, can you do a for loop over coordinates to compute the prox, calling prox_1D ? We'll lose a bit of time but the gradient computation should be the dominating cost

skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
@PABannier PABannier mentioned this pull request Oct 14, 2022
Copy link
Collaborator

@Badr-MOUFAD Badr-MOUFAD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the hard work @PABannier!

Below, some minor remarks.

Besides, I have one concern: I don't think it's a good idea to add support for FISTA to all the datafits. We can limit ourselves to one of them just for testing purposes.
Indeed, AndersonCD and ProxNewton are much faster for separable problems. We better keep FISTA for particular cases (e.g SLOPE #92)

WDYT?

skglm/datafits/single_task.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Show resolved Hide resolved
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Show resolved Hide resolved
@PABannier
Copy link
Collaborator Author

@Badr-MOUFAD totally agree. FISTA would be for a subset of penalties where PN or AndersonCD are not available.

for j in range(n_features):
Xj = X_data[X_indptr[j]:X_indptr[j+1]]
self.lipschitz[j] = (Xj ** 2).sum() / (len(y) * 4)
self.global_lipschitz += (Xj ** 2).sum() / (len(y) * 4)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that will yield a very crude bound, potentially with a loss or the order of n_features.

Use a few iterations of the power method instead to approximate the lipschitz constant of the sparse matrix (there's also the Lanczos iteration but it's more complicated, let's implement the easy one first)

@Badr-MOUFAD Badr-MOUFAD self-assigned this Oct 20, 2022
skglm/solvers/fista.py Outdated Show resolved Hide resolved
skglm/solvers/fista.py Show resolved Hide resolved
skglm/tests/test_fista.py Outdated Show resolved Hide resolved
skglm/tests/test_fista.py Outdated Show resolved Hide resolved
skglm/utils.py Outdated Show resolved Hide resolved
skglm/utils.py Outdated Show resolved Hide resolved
skglm/utils.py Outdated Show resolved Hide resolved
skglm/utils.py Outdated Show resolved Hide resolved
@Badr-MOUFAD
Copy link
Collaborator

Badr-MOUFAD commented Oct 21, 2022

I am unable to find the root cause of the problem at the uniitest.

Here is a small script to reproduce

import numpy as np
from  scipy.sparse import random
from skglm.estimators import LinearSVC

n_samples, n_features = 20, 30
X_sparse = random(n_samples, n_features, density=0.5, format='csc', random_state=0)
y = np.ones(n_samples)

LinearSVC(C=1., tol=1e-9).fit(X_sparse, y)

Output (it depends)

Segmentation fault (core dumped)

or

corrupted size vs. prev_size
Aborted (core dumped)

or

python3: malloc.c:3852: _int_malloc: Assertion `chunk_main_arena (fwd)' failed.
Aborted (core dumped)

@mathurinm, @PABannier, any thoughts?

@PABannier
Copy link
Collaborator Author

A segfault is usually thrown by Numba when it can't access something it should (e.g. missing initialization of datafit). Have you tried setting breakpoints at various places of the code to see which line is causing the issue?

@Badr-MOUFAD
Copy link
Collaborator

Yes absolutely, I tried that. it breaks down in the initialization of datafit. Yet, I can't figure out why.
What is surprising is that it works for some X sizes,

@PABannier
Copy link
Collaborator Author

PABannier commented Oct 21, 2022

Weird, that works for me on this branch. Can you try reinstalling numba and skglm? I've run into similar issues with celer where I had segfaults on my machine that disappeared when I reinstalled the package.

@Badr-MOUFAD
Copy link
Collaborator

Badr-MOUFAD commented Oct 21, 2022

I found the bug.
In the case of SVC, the design matrix is yXT which has particularly n_rows=n_features instead of n_samples.
This caused an index out of range in spectral_norm.

Thanks @PABannier for your help!

@mathurinm
Copy link
Collaborator

Wow, good catch @Badr-MOUFAD

For a more robust design replace n_samples by n_rows_X and pass X.shape[0] explicitly

@mathurinm mathurinm merged commit 359f4da into scikit-learn-contrib:main Oct 22, 2022
@mathurinm
Copy link
Collaborator

Thanks @PABannier and @Badr-MOUFAD

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

FEAT add FISTA solver
3 participants