Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gram-based CD/BCD/FISTA solvers for (group)Lasso when n_samples >> n_features #4

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions gram_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# data available at https://www.dropbox.com/sh/32b3mr3xghi496g/AACNRS_NOsUXU-hrSLixNg0ja?dl=0


import time
import numpy as np
from celer import GroupLasso
from skglm.solvers.gram import gram_fista_group_lasso, gram_group_lasso

# Benchmark script: times celer's GroupLasso against the Gram-based BCD and
# FISTA group-lasso solvers on a problem loaded from local .npy files.
X = np.load("design_matrix.npy")
y = np.load("target.npy")
groups = np.load("groups.npy")
weights = np.load("weights.npy")
# One list of feature indices per group label; labels 1..32 are hard-coded
# here (assumes groups.npy uses exactly those labels -- TODO confirm).
grps = [list(np.where(groups == i)[0]) for i in range(1, 33)]


alpha_ratio = 1e-2
n_alphas = 10
tol = 1e-8

# Case 1: slower runtime for small alphas
# alpha_max = 0.003471727067743962
# NOTE(review): the reshape assumes contiguous groups of 5 features and
# ignores `weights` -- confirm this matches the data set.
alpha_max = np.max(np.linalg.norm((X.T @ y).reshape(-1, 5), axis=1)) / len(y)
alpha = alpha_max / 100
clf = GroupLasso(fit_intercept=False, tol=tol,
                 groups=grps, weights=weights, alpha=alpha, verbose=1)

t0 = time.time()
clf.fit(X, y)
t1 = time.time()

print(f"Celer: {t1 - t0:.3f} s")

t0 = time.time()
res = gram_group_lasso(X, y, alpha, groups=grps, tol=tol, weights=weights,
                       max_iter=10_000, check_freq=50)
t1 = time.time()

print(f"skglm gram: {t1 - t0:.3f} s")


# FISTA Gram for very small alphas
# The regularization must be a small *fraction* of alpha_max; the previous
# `alpha_max / 1e-4` evaluated to 10_000 * alpha_max instead.
alpha = alpha_max * 1e-4
clf = GroupLasso(fit_intercept=False, tol=tol, groups=grps, weights=weights,
                 alpha=alpha, verbose=1)

t0 = time.time()
clf.fit(X, y)
t1 = time.time()

print(f"Celer: {t1 - t0:.3f} s")

t0 = time.time()
res = gram_fista_group_lasso(X, y, alpha, groups=grps, tol=tol,
                             weights=weights, max_iter=10_000, check_freq=50)
t1 = time.time()
print(f"skglm fista gram: {t1 - t0:.3f} s")
73 changes: 73 additions & 0 deletions skglm/gram_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from time import time
import numpy as np
from numpy.linalg import norm
from celer import Lasso, GroupLasso
from benchopt.datasets.simulated import make_correlated_data
from skglm.solvers.gram import gram_fista_group_lasso, gram_fista_lasso, gram_lasso, gram_group_lasso


# Sanity/benchmark script: solves a weighted Lasso and a weighted group Lasso
# with the Gram-based CD/BCD and FISTA solvers, checks the solutions against
# celer, and prints the wall-clock time of each solver.
n_samples, n_features = 100, 300
X, y, w_star = make_correlated_data(
    n_samples=n_samples, n_features=n_features, random_state=0)
alpha_max = norm(X.T @ y, ord=np.inf)

# Hyperparameters
max_iter = 10_000
tol = 1e-8
reg = 0.1
group_size = 3

alpha = alpha_max * reg / n_samples

# Draw the penalty weights from a seeded generator so that runs are
# reproducible, like the data generated above with random_state=0.
rng = np.random.default_rng(0)
weights = rng.normal(2, 0.4, n_features)
weights_grp = rng.normal(2, 0.4, n_features // group_size)

# Lasso
print("#" * 15)
print("Lasso")
print("#" * 15)
start = time()
w = gram_lasso(X, y, alpha, max_iter, tol, weights=weights)
gram_lasso_time = time() - start
clf_sk = Lasso(alpha, weights=weights, tol=tol, fit_intercept=False)
start = time()
clf_sk.fit(X, y)
celer_lasso_time = time() - start
start = time()
w_fista = gram_fista_lasso(X, y, alpha, max_iter, tol, weights=weights)
gram_fista_lasso_time = time() - start
# All three solvers must agree on the solution.
np.testing.assert_allclose(w, clf_sk.coef_, rtol=1e-4)
np.testing.assert_allclose(w, w_fista, rtol=1e-4)

print("\n")
print("Celer: %.2f" % celer_lasso_time)
print("CD Gram: %.2f" % gram_lasso_time)
print("FISTA Gram: %.2f" % gram_fista_lasso_time)
print("\n")

# Group Lasso
print("#" * 15)
print("Group Lasso")
print("#" * 15)
start = time()
w = gram_group_lasso(X, y, alpha, group_size, max_iter, tol,
                     weights=weights_grp)
gram_group_lasso_time = time() - start
start = time()
w_fista = gram_fista_group_lasso(X, y, alpha, group_size, max_iter, tol,
                                 weights=weights_grp)
gram_fista_group_lasso_time = time() - start

np.testing.assert_allclose(w, w_fista, rtol=1e-4)

clf_celer = GroupLasso(group_size, alpha, tol=tol, weights=weights_grp,
                       fit_intercept=False)
start = time()
clf_celer.fit(X, y)
celer_group_lasso_time = time() - start
np.testing.assert_allclose(w, clf_celer.coef_, rtol=1e-4)

print("\n")
print("Celer: %.2f" % celer_group_lasso_time)
print("BCD Gram: %.2f" % gram_group_lasso_time)
print("FISTA Gram: %.2f" % gram_fista_group_lasso_time)
print("\n")
254 changes: 254 additions & 0 deletions skglm/solvers/gram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import numpy as np
from numba import njit
from numpy.linalg import norm
from celer.homotopy import _grp_converter
mathurinm marked this conversation as resolved.
Show resolved Hide resolved

from skglm.utils import BST, ST, ST_vec


@njit
def primal(alpha, r, w, weights):
    """Evaluate the weighted-L1 primal objective.

    Returns ``r @ r / (2 * len(r)) + alpha * sum_j |w[j] * weights[j]|``.
    Features whose weight equals ``np.inf`` are left out of the penalty sum.
    """
    datafit = (r @ r) / (2 * len(r))
    penalty = 0.
    for j in range(len(weights)):
        if weights[j] != np.inf:
            penalty += np.abs(w[j] * weights[j])
    return datafit + alpha * penalty


@njit
def primal_grp(alpha, norm_r2, r, w, grp_ptr, grp_indices, weights):
    """Evaluate the weighted group-lasso primal objective.

    Returns ``norm_r2 / (2 * len(r))`` plus ``alpha`` times the sum over
    groups of the Euclidean norm of ``w_g * weights[g]``. Group ``g`` covers
    the features ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``; groups whose
    weight equals ``np.inf`` are left out of the penalty.
    """
    n_groups = len(grp_ptr) - 1
    value = norm_r2 / (2 * len(r))
    for g in range(n_groups):
        if weights[g] != np.inf:
            members = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
            value += alpha * norm(w[members] * weights[g], ord=2)
    return value


@njit
def dual(alpha, norm_y2, theta, y):
    """Evaluate the dual objective at ``theta``.

    With ``n = len(y)`` this returns
    ``norm_y2 / (2 n) - (alpha ** 2 * n / 2) * sum((y / (alpha n) - theta) ** 2)``.
    """
    n_samples = len(y)
    shifted = y / (alpha * n_samples) - theta
    d_obj = - np.sum(shifted ** 2)
    d_obj *= 0.5 * alpha ** 2 * n_samples
    d_obj += norm_y2 / (2 * n_samples)
    return d_obj


@njit
def dnorm_l1(theta, X, weights):
    """Return the weighted dual norm ``max_j |X[:, j] @ theta| / weights[j]``.

    The absolute value is required: a strongly negative correlation violates
    the dual constraint just as much as a positive one, and dropping it lets
    an infeasible ``theta`` pass as feasible. Features with an infinite weight
    impose no constraint and are skipped, as in the other helpers of this
    module.
    """
    n_features = X.shape[1]
    scal = 0.
    for j in range(n_features):
        if weights[j] == np.inf:
            continue
        Xj_theta = X[:, j] @ theta
        scal = max(scal, np.abs(Xj_theta) / weights[j])
    return scal


@njit
def dnorm_l21(theta, grp_ptr, grp_indices, X, weights):
    """Return the weighted group dual norm of ``X.T @ theta``.

    For every group ``g`` with a finite weight, computes the Euclidean norm of
    the correlations ``X[:, j] @ theta`` over the group's features divided by
    ``weights[g]``, and returns the largest such value (0. if there is none).
    """
    largest = 0.
    for g in range(len(grp_ptr) - 1):
        if weights[g] == np.inf:
            continue
        sq_sum = 0.
        for k in range(grp_ptr[g], grp_ptr[g + 1]):
            corr = X[:, grp_indices[k]] @ theta
            sq_sum += corr ** 2
        largest = max(largest, np.sqrt(sq_sum) / weights[g])
    return largest


@njit
def create_dual_point(r, alpha, X, y, weights):
    """Build a dual point from the residual ``r``.

    Starts from ``r / (alpha * len(y))`` and, when its ``dnorm_l1`` value
    exceeds 1, divides it by that value.
    """
    theta = r / (alpha * len(y))
    dnorm = dnorm_l1(theta, X, weights)
    if not (dnorm > 1.):
        return theta
    theta /= dnorm
    return theta


@njit
def create_dual_point_grp(r, alpha, y, X, grp_ptr, grp_indices, weights):
    """Build a dual point for the group problem from the residual ``r``.

    Starts from ``r / (alpha * len(y))`` and, when its ``dnorm_l21`` value
    exceeds 1, divides it by that value.
    """
    theta = r / (alpha * len(y))
    dnorm = dnorm_l21(theta, grp_ptr, grp_indices, X, weights)
    if not (dnorm > 1.):
        return theta
    theta /= dnorm
    return theta


@njit
def dual_gap(alpha, norm_y2, y, X, w, weights):
    """Return ``(primal, dual, primal - dual)`` for the weighted Lasso at ``w``.

    The dual point is derived from the residual ``y - X @ w`` through
    ``create_dual_point``.
    """
    residual = y - X @ w
    theta = create_dual_point(residual, alpha, X, y, weights)
    p_obj = primal(alpha, residual, w, weights)
    d_obj = dual(alpha, norm_y2, theta, y)
    gap = p_obj - d_obj
    return p_obj, d_obj, gap


@njit
def dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr, grp_indices, weights):
    """Return ``(primal, dual, primal - dual)`` for the group Lasso at ``w``.

    The dual point is derived from the residual ``y - X @ w`` through
    ``create_dual_point_grp``.
    """
    residual = y - X @ w
    theta = create_dual_point_grp(residual, alpha, y, X, grp_ptr, grp_indices,
                                  weights)
    p_obj = primal_grp(alpha, residual @ residual, residual, w, grp_ptr,
                       grp_indices, weights)
    d_obj = dual(alpha, norm_y2, theta, y)
    gap = p_obj - d_obj
    return p_obj, d_obj, gap


@njit
def compute_lipschitz(X, y):
    """Return, per feature ``j``, ``sum(X[:, j] ** 2) / len(y)``.

    The result has one entry per column of ``X`` and the dtype of ``X``.
    """
    n_samples = len(y)
    n_cols = X.shape[1]
    out = np.zeros(n_cols, dtype=X.dtype)
    for j in range(n_cols):
        column = X[:, j]
        out[j] = (column ** 2).sum() / n_samples
    return out


@njit
def prox_l21(w, u, weights, grp_ptr, grp_indices):
    """Apply group-wise shrinkage to ``w`` and return the result as a copy.

    Each group ``g`` (features ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``) is
    multiplied by ``max(1 - u * weights[g] / ||w_g||_2, 0)``. ``w`` itself is
    not modified.
    """
    n_groups = len(grp_ptr) - 1
    out = w.copy()
    for g in range(n_groups):
        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
        grp_nrm = norm(w[idx], ord=2)
        if grp_nrm == 0.:
            # The block is already zero and stays zero; dividing by its norm
            # would raise (numba) or produce NaNs (plain numpy).
            continue
        scaling = np.maximum(1 - u / grp_nrm * weights[g], 0)
        out[idx] *= scaling
    return out


def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq=100):
    """Solve a weighted Lasso with coordinate descent on the Gram matrix.

    Parameters
    ----------
    X, y : design matrix and target.
    alpha : regularization strength.
    max_iter : maximum number of coordinate-descent epochs.
    tol : the loop stops once the duality gap falls below this value.
    w_init : optional starting point (copied, never modified); zeros if None.
    weights : optional per-feature penalty weights; ones if None.
    check_freq : the duality gap is computed and printed every ``check_freq``
        epochs.

    Returns
    -------
    w : the coefficient vector after the last epoch.
    """
    n_features = X.shape[1]
    norm_y2 = y @ y
    G = X.T @ X
    Xty = X.T @ y
    lipschitz = compute_lipschitz(X, y)
    if w_init is None:
        w = np.zeros(n_features)
        grads = Xty / len(y)
    else:
        w = w_init.copy()
        # `grads` holds X.T @ (y - X @ w) / n_samples and is updated
        # incrementally by cd_epoch, so it must match the starting point;
        # using X.T @ y / n_samples here is only valid for w = 0.
        grads = (Xty - G @ w) / len(y)
    weights = weights if weights is not None else np.ones(n_features)
    for n_iter in range(max_iter):
        cd_epoch(X, G, grads, w, alpha, lipschitz, weights)
        if n_iter % check_freq == 0:
            p_obj, d_obj, d_gap = dual_gap(alpha, norm_y2, y, X, w, weights)
            print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f}" +
                  f" :: gap {d_gap:.5f}")
            if d_gap < tol:
                print("Convergence reached!")
                break
    return w


def gram_fista_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None,
                     check_freq=100):
    """Solve a weighted Lasso with accelerated proximal gradient on the Gram matrix.

    Each iteration takes a gradient step of size ``1 / L`` on the
    extrapolated point ``z`` (with ``L = ||X||_2 ** 2 / n_samples``), applies
    ``ST_vec`` with threshold ``alpha / L * weights`` (presumably element-wise
    soft-thresholding -- see ``skglm.utils``), then extrapolates with the
    ``(t_old - 1) / t_new`` momentum coefficient. Every ``check_freq``
    iterations the duality gap is printed and the loop stops once it is below
    ``tol``.

    ``w_init`` (copied, zeros if None) is the starting point and ``weights``
    (ones if None) are the per-feature penalty weights. Returns the last
    iterate ``w``.
    """
    n_samples, n_features = X.shape
    if w_init is None:
        w = np.zeros(n_features)
        z = np.zeros(n_features)
    else:
        w = w_init.copy()
        z = w_init.copy()
    if weights is None:
        weights = np.ones(n_features)

    norm_y2 = y @ y
    Xty = X.T @ y
    G = X.T @ X
    L = np.linalg.norm(X, ord=2) ** 2 / n_samples

    t_new = 1
    for n_iter in range(max_iter):
        t_old, w_old = t_new, w.copy()
        t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
        # Gradient step on the extrapolated point, then the prox.
        z -= (G @ z - Xty) / L / n_samples
        w = ST_vec(z, alpha / L * weights)
        z = w + (t_old - 1.) / t_new * (w - w_old)
        if n_iter % check_freq != 0:
            continue
        p_obj, d_obj, d_gap = dual_gap(alpha, norm_y2, y, X, w, weights)
        print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " +
              f":: gap {d_gap:.5f}")
        if d_gap < tol:
            print("Convergence reached!")
            break
    return w


def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=None,
                     check_freq=100):
    """Solve a weighted group Lasso with block coordinate descent on the Gram matrix.

    Parameters
    ----------
    X, y : design matrix and target.
    alpha : regularization strength.
    groups : group specification, passed as-is to celer's ``_grp_converter``
        together with the number of features.
    max_iter : maximum number of block-coordinate-descent epochs.
    tol : the loop stops once the duality gap falls below this value.
    w_init : optional starting point (copied, never modified); zeros if None.
    weights : optional per-group penalty weights; ones if None.
    check_freq : the duality gap is computed and printed every ``check_freq``
        epochs.

    Returns
    -------
    w : the coefficient vector after the last epoch.
    """
    n_features = X.shape[1]
    grp_ptr, grp_indices = _grp_converter(groups, X.shape[1])
    n_groups = len(grp_ptr) - 1
    norm_y2 = y @ y
    G = X.T @ X
    Xty = X.T @ y
    # Per-group step constant: squared spectral norm of the group's columns
    # divided by the number of samples.
    lipschitz = np.zeros(n_groups, dtype=X.dtype)
    for g in range(n_groups):
        X_g = X[:, grp_indices[grp_ptr[g]:grp_ptr[g + 1]]]
        lipschitz[g] = norm(X_g, ord=2) ** 2 / len(y)
    if w_init is None:
        w = np.zeros(n_features)
        grads = Xty / len(y)
    else:
        w = w_init.copy()
        # `grads` holds X.T @ (y - X @ w) / n_samples and is updated
        # incrementally by bcd_epoch, so it must match the starting point;
        # using X.T @ y / n_samples here is only valid for w = 0.
        grads = (Xty - G @ w) / len(y)
    weights = weights if weights is not None else np.ones(n_groups)
    for n_iter in range(max_iter):
        bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights)
        if n_iter % check_freq == 0:
            p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr,
                                               grp_indices, weights)
            print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " +
                  f":: gap {d_gap:.5f}")
            if d_gap < tol:
                print("Convergence reached!")
                break
    return w


def gram_fista_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None,
                           weights=None, check_freq=100):
    """Solve a weighted group Lasso with accelerated proximal gradient on the Gram matrix.

    Each iteration takes a gradient step of size ``1 / L`` on the
    extrapolated point ``z`` (with ``L = ||X||_2 ** 2 / len(y)``), applies
    ``prox_l21`` with level ``alpha / L``, then extrapolates with the
    ``(t_old - 1) / t_new`` momentum coefficient. Every ``check_freq``
    iterations the duality gap is printed and the loop stops once it is below
    ``tol``.

    ``groups`` is passed as-is to celer's ``_grp_converter``; ``w_init``
    (copied, zeros if None) is the starting point and ``weights`` (ones if
    None) are the per-group penalty weights. Returns the last iterate ``w``.
    """
    n_samples = len(y)
    n_features = X.shape[1]
    grp_ptr, grp_indices = _grp_converter(groups, X.shape[1])
    n_groups = len(grp_ptr) - 1
    if w_init is None:
        w = np.zeros(n_features)
        z = np.zeros(n_features)
    else:
        w = w_init.copy()
        z = w_init.copy()
    if weights is None:
        weights = np.ones(n_groups)

    norm_y2 = y @ y
    Xty = X.T @ y
    G = X.T @ X
    L = np.linalg.norm(X, ord=2) ** 2 / n_samples

    t_new = 1
    for n_iter in range(max_iter):
        t_old, w_old = t_new, w.copy()
        t_new = (1 + np.sqrt(1 + 4 * t_old ** 2)) / 2
        # Gradient step on the extrapolated point, then the group prox.
        z -= (G @ z - Xty) / L / n_samples
        w = prox_l21(z, alpha / L, weights, grp_ptr, grp_indices)
        z = w + (t_old - 1.) / t_new * (w - w_old)
        if n_iter % check_freq != 0:
            continue
        p_obj, d_obj, d_gap = dual_gap_grp(y, X, w, alpha, norm_y2, grp_ptr,
                                           grp_indices, weights)
        print(f"iter {n_iter} :: p_obj {p_obj:.5f} :: d_obj {d_obj:.5f} " +
              f":: gap {d_gap:.5f}")
        if d_gap < tol:
            print("Convergence reached!")
            break
    return w


@njit
def cd_epoch(X, G, grads, w, alpha, lipschitz, weights):
    """Run one in-place pass of coordinate descent over all features.

    For each feature ``j`` (skipped when ``lipschitz[j]`` is zero or
    ``weights[j]`` is infinite), ``w[j]`` is replaced by
    ``ST(w[j] + grads[j] / lipschitz[j], alpha / lipschitz[j] * weights[j])``
    (``ST`` comes from ``skglm.utils``; presumably scalar soft-thresholding).
    When the coefficient changes, ``grads`` is updated in place with row
    ``j`` of ``G``. Nothing is returned.
    """
    n_samples = len(X)
    for j in range(X.shape[1]):
        lc_j = lipschitz[j]
        if lc_j == 0. or weights[j] == np.inf:
            continue
        previous = w[j]
        w[j] = ST(previous + grads[j] / lc_j, alpha / lc_j * weights[j])
        if w[j] != previous:
            grads += G[j, :] * (previous - w[j]) / n_samples


@njit
def bcd_epoch(X, G, grads, w, alpha, lipschitz, grp_indices, grp_ptr, weights):
    """Run one in-place pass of block coordinate descent over all groups.

    For each group ``g`` (features ``grp_indices[grp_ptr[g]:grp_ptr[g + 1]]``)
    the block ``w_g`` is replaced by
    ``BST(w_g + grads_g / lipschitz[g], alpha / lipschitz[g] * weights[g])``
    (``BST`` comes from ``skglm.utils``; presumably block soft-thresholding).
    When the block changes, ``grads`` is updated in place with the matching
    rows of ``G``. Nothing is returned.
    """
    n_groups = len(grp_ptr) - 1
    for g in range(n_groups):
        # Either condition alone must skip the group (as cd_epoch does): a
        # zero Lipschitz constant would divide by zero below, and an infinite
        # weight marks a group that is excluded from the problem.
        if lipschitz[g] == 0. or weights[g] == np.inf:
            continue
        idx = grp_indices[grp_ptr[g]:grp_ptr[g + 1]]
        old_w_g = w[idx].copy()
        w[idx] = BST(w[idx] + grads[idx] / lipschitz[g], alpha / lipschitz[g]
                     * weights[g])
        diff = old_w_g - w[idx]
        if np.any(diff != 0.):
            grads += diff @ G[idx, :] / len(X)