from sklearn.base import RegressorMixin
from sklearn.linear_model.base import LinearModel, _rescale_data
from sklearn.utils.validation import check_X_y
from ..vendor.group_lasso.group_lasso import sparse_group_lasso
# NOTE: scikit-learn requires mixins to precede the base estimator in the
# bases list. With ``LinearModel`` listed first, ``BaseEstimator``'s default
# tags (reached through ``LinearModel``) override the ones contributed by
# ``RegressorMixin``.
class SparseGroupLasso(RegressorMixin, LinearModel):
    """Linear regression fitted with a sparse group lasso penalty.

    Thin scikit-learn wrapper around the vendored ``sparse_group_lasso``
    solver. The data are preprocessed with ``LinearModel._preprocess_data``
    and optionally reweighted by ``sample_weight``. The solver is then run,
    and the intercept is recovered with ``LinearModel._set_intercept``.

    Parameters
    ----------
    groups : array-like
        Group specification forwarded unchanged to ``sparse_group_lasso``
        (presumably one group label per feature -- confirm against the
        vendored solver).
    alpha : float, default=1.0
        Regularisation strength forwarded to the solver.
    rho : float, default=0.5
        Forwarded to the solver. Presumably the mix between the L1 and
        group penalties -- confirm against the vendored solver.
    max_iter : int, default=1000
        Maximum number of solver iterations, forwarded as ``max_iter``.
    tol : float, default=1e-4
        Forwarded to the solver as its ``rtol`` argument.
    normalize : bool, default=False
        Passed through to ``_preprocess_data``.
    fit_intercept : bool, default=True
        Passed through to ``_preprocess_data`` and used when setting the
        intercept after solving.
    copy_X : bool, default=True
        Passed to ``_preprocess_data`` as ``copy``. If False, preprocessing
        may work on the input arrays in place.

    Attributes
    ----------
    coef_ : ndarray
        Coefficients returned by ``sparse_group_lasso``.
    intercept_ : float
        Intercept computed by ``LinearModel._set_intercept``.
    """

    def __init__(
        self,
        groups,
        alpha=1.0,
        rho=0.5,
        max_iter=1000,
        tol=1e-4,
        normalize=False,
        fit_intercept=True,
        copy_X=True,
    ):
        # scikit-learn convention: store constructor arguments unchanged so
        # that get_params/set_params/clone work. Validation happens in fit.
        self.alpha = alpha
        self.rho = rho
        self.groups = groups
        self.max_iter = max_iter
        self.tol = tol
        self.normalize = normalize
        self.fit_intercept = fit_intercept
        self.copy_X = copy_X

    def fit(self, x, y, sample_weight=None):
        """Fit the sparse group lasso model.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Training data. Sparse matrices are not accepted.
        y : array-like of shape (n_samples,)
            Numeric target. Only a single output is supported.
        sample_weight : array-like of shape (n_samples,), optional
            Per-sample weights. If None, all samples are weighted equally.

        Returns
        -------
        self : SparseGroupLasso
            The fitted estimator.
        """
        x, y = check_X_y(x, y, accept_sparse=[], y_numeric=True, multi_output=False)

        # Centre (and optionally scale) the data. The offsets and scale are
        # kept so that _set_intercept can map the solution back to the
        # original feature space.
        x, y, X_offset, y_offset, X_scale = self._preprocess_data(
            x,
            y,
            fit_intercept=self.fit_intercept,
            normalize=self.normalize,
            copy=self.copy_X,
            sample_weight=sample_weight,
        )

        # Fold the weights into the data, because the solver takes no
        # sample_weight argument of its own.
        if sample_weight is not None:
            x, y = _rescale_data(x, y, sample_weight)

        self.coef_ = sparse_group_lasso(
            x, y, self.alpha, self.rho, self.groups, max_iter=self.max_iter, rtol=self.tol
        )
        self._set_intercept(X_offset, y_offset, X_scale)
        return self