Source code for sparsereg.util.pipeline

from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin


class ColumnSelector(TransformerMixin, BaseEstimator):
    """Transformer that keeps only the columns of ``x`` picked out by ``index``.

    ``index`` is applied to the last axis of the input as ``x[..., index]``,
    so it may be anything the input type accepts there. The default,
    ``slice(None)``, keeps every column.
    """

    def __init__(self, index=slice(None)):
        self.index = index
        # Set by fit(); stays None until then.
        self.n_features = None

    def fit(self, x, y=None):
        """Record the number of input features and return ``self``.

        For a two-dimensional ``x`` the feature count is the second entry
        of ``x.shape``; for any other rank it is ``x.shape[0]``. ``y`` is
        ignored.
        """
        if len(x.shape) == 2:
            _, self.n_features = x.shape
        else:
            self.n_features = x.shape[0]
        return self

    def transform(self, x, y=None):
        """Return ``x[..., self.index]`` as a two-dimensional result.

        When the selection is not two-dimensional (for example a single
        column picked with a scalar index) it is reshaped to ``(-1, 1)``.
        ``y`` is ignored.
        """
        selected = x[..., self.index]
        if len(selected.shape) == 2:
            return selected
        return selected.reshape(-1, 1)

    def get_feature_names(self, input_features=None):
        """Return the names of the columns that ``transform`` keeps.

        Parameters
        ----------
        input_features : sequence of str, optional
            Names of the input columns. When ``None`` or empty, the names
            ``x_0 .. x_{n_features - 1}`` are generated, which needs the
            ``n_features`` value stored by ``fit``.

        Returns
        -------
        list
            The selected names, in selection order.
        """
        # Explicit None/empty test: ``input_features or ...`` raises on a
        # multi-element array because its truth value is ambiguous.
        if input_features is None or len(input_features) == 0:
            input_features = ["x_{}".format(i) for i in range(self.n_features)]
        names = list(input_features)

        index = self.index
        # Any slice (not only slice(None)) maps directly onto the name list.
        if isinstance(index, slice):
            return names[index]

        # Array-likes and array scalars expose tolist(); converting gives
        # plain Python bools/ints that can be inspected without the array
        # library being imported here.
        if hasattr(index, "tolist"):
            index = index.tolist()

        # A scalar integer selects exactly one column.
        if isinstance(index, int) and not isinstance(index, bool):
            return [names[index]]

        entries = list(index)
        # Boolean mask: keep the names whose flag is set.
        if entries and all(isinstance(flag, bool) for flag in entries):
            return [name for flag, name in zip(entries, names) if flag]
        # Otherwise the entries are integer positions; look each one up so
        # the names line up with the columns transform() selects.
        return [names[position] for position in entries]