"""sparsereg.model.group_lasso: scikit-learn style sparse group lasso regressor."""

from sklearn.base import RegressorMixin
from sklearn.linear_model.base import LinearModel, _rescale_data
from sklearn.utils.validation import check_X_y


from ..vendor.group_lasso.group_lasso import sparse_group_lasso


# Mixins must precede the estimator base class: scikit-learn requires
# BaseEstimator (via LinearModel) to be rightmost so that RegressorMixin's
# tags/attributes are not shadowed by BaseEstimator's defaults in the MRO.
class SparseGroupLasso(RegressorMixin, LinearModel):
    """Linear regressor fitted with the vendored sparse group lasso solver.

    Thin scikit-learn wrapper: ``fit`` validates the inputs, applies the
    standard ``LinearModel`` preprocessing (centring / optional
    normalisation / sample weighting), delegates the coefficient estimate
    to ``sparse_group_lasso`` and then recovers the intercept.

    Parameters
    ----------
    groups : array-like
        Group specification passed unchanged to ``sparse_group_lasso``.
        Presumably one group label per feature (column of ``x``) -- confirm
        against the vendored solver.
    alpha : float, default=1.0
        Forwarded to the solver as its penalty-strength argument
        (presumably the overall regularisation weight -- TODO confirm).
    rho : float, default=0.5
        Forwarded to the solver; presumably the trade-off between the
        l1 and group penalties -- TODO confirm against the solver.
    max_iter : int, default=1000
        Maximum number of solver iterations (passed as ``max_iter``).
    tol : float, default=1e-4
        Convergence tolerance, passed to the solver as ``rtol``.
    normalize : bool, default=False
        Passed to ``LinearModel._preprocess_data``.
    fit_intercept : bool, default=True
        Passed to ``LinearModel._preprocess_data``; controls whether the
        data are centred and an intercept is estimated.
    copy_X : bool, default=True
        Passed to ``_preprocess_data`` as ``copy``; if False the input
        may be modified in place during preprocessing.

    Attributes
    ----------
    coef_ : ndarray
        Coefficients returned by ``sparse_group_lasso`` (then adjusted by
        ``_set_intercept``); presumably shape (n_features,) -- confirm.
    intercept_ : float
        Intercept set by ``LinearModel._set_intercept``.
    """

    def __init__(
        self,
        groups,
        alpha=1.0,
        rho=0.5,
        max_iter=1000,
        tol=1e-4,
        normalize=False,
        fit_intercept=True,
        copy_X=True,
    ):
        # scikit-learn convention: store constructor arguments verbatim so
        # get_params / set_params / clone work; no validation here.
        self.alpha = alpha
        self.rho = rho
        self.groups = groups
        self.max_iter = max_iter
        self.tol = tol
        self.normalize = normalize
        self.fit_intercept = fit_intercept
        self.copy_X = copy_X

    def fit(self, x, y, sample_weight=None):
        """Fit the sparse group lasso model.

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Training data. Sparse matrices are rejected (``accept_sparse=[]``).
        y : array-like of shape (n_samples,)
            Numeric target; multi-output targets are rejected.
        sample_weight : array-like of shape (n_samples,), optional
            Per-sample weights, used both in ``_preprocess_data`` and to
            rescale ``x`` and ``y`` before solving.

        Returns
        -------
        self : SparseGroupLasso
            The fitted estimator.
        """
        x, y = check_X_y(
            x, y, accept_sparse=[], y_numeric=True, multi_output=False
        )
        # Centre (and optionally normalise / weight) the data; the offsets
        # and scale are needed afterwards to recover the intercept.
        x, y, X_offset, y_offset, X_scale = self._preprocess_data(
            x,
            y,
            fit_intercept=self.fit_intercept,
            normalize=self.normalize,
            copy=self.copy_X,
            sample_weight=sample_weight,
        )
        if sample_weight is not None:
            # Weighted least squares reduces to an unweighted problem on
            # rescaled data, so the solver itself needs no weight support.
            x, y = _rescale_data(x, y, sample_weight)

        self.coef_ = sparse_group_lasso(
            x,
            y,
            self.alpha,
            self.rho,
            self.groups,
            max_iter=self.max_iter,
            rtol=self.tol,
        )
        # Sets intercept_ from the preprocessing offsets/scale.
        self._set_intercept(X_offset, y_offset, X_scale)
        return self