import numbers

from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
class ColumnSelector(TransformerMixin, BaseEstimator):
    """Select a subset of input columns and report their feature names.

    Parameters
    ----------
    index : slice, int, sequence of int, or boolean mask, default=slice(None)
        Columns to select. ``slice(None)`` (the default) keeps every column.
        Integer entries are interpreted as column positions; a sequence of
        booleans is interpreted as a per-column keep/drop mask.
    """

    # NOTE(review): no ``transform`` method is defined in this class, but
    # ``TransformerMixin.fit_transform`` calls ``self.transform``; confirm one
    # is provided, otherwise ``fit_transform`` will fail.

    def __init__(self, index=slice(None)):
        self.index = index
        # Number of input columns; populated by ``fit``.
        self.n_features = None

    def fit(self, x, y=None):
        """Record the number of input columns of ``x``.

        Parameters
        ----------
        x : array-like with a ``shape`` attribute
            For 2-D input the second dimension is taken as the column count;
            for any other rank the first dimension is used.
        y : ignored
            Accepted for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        if len(x.shape) == 2:
            _, self.n_features = x.shape
        else:
            self.n_features = x.shape[0]
        return self

    def get_feature_names(self, input_features=None):
        """Return the names of the columns selected by ``self.index``.

        Parameters
        ----------
        input_features : sequence of str, optional
            Names of all input columns. When ``None`` or empty, names of the
            form ``"x_0", "x_1", ...`` are generated from ``self.n_features``
            (so ``fit`` must have been called first).

        Returns
        -------
        list of str
            The selected names. With the default ``slice(None)`` index,
            ``input_features`` itself is returned unchanged.
        """

        def is_integer(value):
            # ``bool`` subclasses ``int``; exclude it so boolean masks are not
            # mistaken for integer positions. ``numpy.bool_`` is not
            # registered as ``numbers.Integral``, so it is excluded as well.
            return isinstance(value, numbers.Integral) and not isinstance(value, bool)

        # Use ``len`` rather than truthiness: ``array or [...]`` raises for
        # NumPy arrays of names ("truth value ... is ambiguous").
        if input_features is None or len(input_features) == 0:
            input_features = ["x_{}".format(i) for i in range(self.n_features)]

        index = self.index
        if isinstance(index, slice):
            # Check the type first: comparing an array index with
            # ``slice(None)`` would be element-wise and ambiguous in ``if``.
            if index == slice(None):
                return input_features
            # Other slices are not iterable, so apply them to the names.
            return list(input_features)[index]

        names = list(input_features)
        if is_integer(index):
            # A single column position.
            return [names[index]]

        index = list(index)
        if any(is_integer(i) for i in index):
            # Integer column positions, e.g. ``[0, 2]`` -> ``["x_0", "x_2"]``.
            return [names[i] for i in index]
        # Boolean mask: keep names whose mask entry is truthy.
        return [name for keep, name in zip(index, names) if keep]