snorkelflow.client.studio.add_model_based_lfs
- snorkelflow.client.studio.add_model_based_lfs(node, model, model_name, feature_fields, **model_kwargs)
Performs the following operations in sequence: (a) trains the model on the dev set, (b) saves the model to MinIO, (c) applies the trained model to generate predictions over the entire dataset, and (d) generates labeling functions for all labels output by the model. Use it to turn a custom model's predictions into model-based labeling functions.
Note: This method currently supports only classification end models.
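Steps (c) and (d) can be pictured with a small local sketch. It assumes, hypothetically, that each generated LF votes for one label on rows where the model's prediction column holds that label and abstains elsewhere, using -1 as the abstain value (the open-source Snorkel convention). The names `make_model_lfs`, `ABSTAIN`, and the `<model_name>_prediction` column naming are illustrative only and are not part of the Snorkel Flow SDK.

```python
from typing import Callable, Dict, List

import numpy as np

ABSTAIN = -1  # hypothetical abstain value, following the open-source Snorkel convention


def make_model_lfs(
    model_name: str, labels: List[int]
) -> Dict[str, Callable[[Dict[str, int]], int]]:
    """Build one hypothetical LF per label from a model's prediction column.

    Each LF reads the column named after the model and votes its label
    when the prediction matches; otherwise it abstains.
    """
    pred_col = f"{model_name}_prediction"  # hypothetical column naming scheme

    def make_lf(label: int) -> Callable[[Dict[str, int]], int]:
        # A factory binds `label` per LF, avoiding Python's late-binding closure pitfall.
        def lf(row: Dict[str, int]) -> int:
            return label if row[pred_col] == label else ABSTAIN

        return lf

    return {f"{model_name}_lf_{label}": make_lf(label) for label in labels}


# Step (c): predictions over the full dataset, shape (n, 1) as `predict` returns.
preds = np.array([[0], [1], [1], [2]])
rows = [{"my_model_prediction": int(p)} for p in preds[:, 0]]

# Step (d): one LF per output label, applied to every row.
lfs = make_model_lfs("my_model", labels=[0, 1, 2])
votes = {name: [lf(r) for r in rows] for name, lf in lfs.items()}
print(votes["my_model_lf_1"])  # [-1, 1, 1, -1]
```

Each LF is only as accurate as the model on the rows where it votes, which is why the model is trained on the dev set before its predictions are turned into LFs.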
Note: The provided model must be an instance of a class that implements the following four methods:

from typing import Any

import numpy as np


class CustomModel:
    def fit(self, X: np.ndarray, y: np.ndarray, **model_kwargs: Any) -> None:
        # Fits the model on (X, y)
        pass

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Predicts y_hat (shape: (X.shape[0], 1)) from X
        pass

    def save(self, dirpath: str) -> None:
        # Saves the information needed to restore the model to dirpath
        pass

    @classmethod
    def load(cls, dirpath: str) -> Any:
        # Restores the model from the information saved at dirpath
        pass

- Parameters:
node (int) – UID of the node.
model (Any) – The model to train. Must have the methods fit, predict, save, and load.
model_name (str) – The name of the model, used to name the prediction column and the associated generated LFs.
feature_fields (Union[str, List[str]]) – The names of the input columns to train the model on.
model_kwargs (Any) – Any other keyword arguments for model training.
- Raises:
ValueError – If model does not have the required methods (fit, predict, save, and load).
ValueError – If the associated application is not a classification application.
- Returns:
The trained model.
- Return type:
Any
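As a concrete illustration of the four-method interface in the note above, the sketch below implements a toy nearest-centroid classifier with NumPy. The class name `NearestCentroidModel`, the `model.npz` file layout, and the `has_required_methods` helper are hypothetical; the helper only approximates the documented ValueError check, and the SDK's actual validation may differ. `predict` returns an array of shape `(n, 1)` to match the shape stated in the interface comment.

```python
import os
import tempfile
from typing import Any

import numpy as np


class NearestCentroidModel:
    """Toy classifier implementing the fit/predict/save/load interface (hypothetical)."""

    def __init__(self) -> None:
        self.classes_: np.ndarray = np.empty(0)
        self.centroids_: np.ndarray = np.empty((0, 0))

    def fit(self, X: np.ndarray, y: np.ndarray, **model_kwargs: Any) -> None:
        # One centroid (mean feature vector) per class; model_kwargs is accepted but unused.
        X = np.asarray(X, dtype=float)
        y = np.asarray(y).ravel()
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Euclidean distance from every row to every centroid; pick the nearest class.
        X = np.asarray(X, dtype=float)
        dists = np.linalg.norm(X[:, None, :] - self.centroids_[None, :, :], axis=2)
        return self.classes_[dists.argmin(axis=1)].reshape(-1, 1)

    def save(self, dirpath: str) -> None:
        # Persist only plain arrays, so loading needs no pickling.
        os.makedirs(dirpath, exist_ok=True)
        np.savez(
            os.path.join(dirpath, "model.npz"),
            classes=self.classes_,
            centroids=self.centroids_,
        )

    @classmethod
    def load(cls, dirpath: str) -> "NearestCentroidModel":
        # Rebuild a model instance from the arrays written by save().
        model = cls()
        with np.load(os.path.join(dirpath, "model.npz")) as data:
            model.classes_ = data["classes"]
            model.centroids_ = data["centroids"]
        return model


def has_required_methods(model: Any) -> bool:
    # Hypothetical pre-flight check mirroring the documented ValueError.
    return all(callable(getattr(model, m, None)) for m in ("fit", "predict", "save", "load"))


# Usage: fit on a tiny "dev set", round-trip through save/load, then predict.
X_dev = np.array([[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]])
y_dev = np.array([0, 0, 1, 1])
model = NearestCentroidModel()
assert has_required_methods(model)
model.fit(X_dev, y_dev)
with tempfile.TemporaryDirectory() as d:
    model.save(d)
    restored = NearestCentroidModel.load(d)
print(restored.predict(np.array([[0.1, 0.0], [4.8, 5.2]])).ravel())  # [0 1]
```

Storing only the class labels and centroids keeps `save` and `load` free of pickled objects, so `load` can rebuild an equivalent model from plain arrays in any directory it is given.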