snorkelflow.client.studio.add_model_based_lfs
- snorkelflow.client.studio.add_model_based_lfs(node, model, model_name, feature_fields, **model_kwargs)
Performs the following operations in sequence: (a) trains the model on the dev set, (b) saves the model to MinIO at the specified path, (c) applies the model to predict over the entire dataset, and (d) generates labeling functions for all model output labels. Useful for generating model-based labeling functions.
Note: this method currently only supports classification end-models.
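The four steps above can be pictured with a small local sketch. This is not Snorkel Flow's implementation: `ThresholdModel`, `make_label_lf`, `run_pipeline`, and the `ABSTAIN` value are hypothetical names chosen for illustration, a temporary directory stands in for MinIO, and the "one labeling function per predicted label" behavior is an assumption based on the description.

```python
import os
import pickle
import tempfile
from typing import Any, Callable, Dict, Tuple

import numpy as np

ABSTAIN = -1  # assumed "no vote" value, for illustration only


class ThresholdModel:
    """Toy binary classifier (hypothetical): thresholds the first feature column."""

    def fit(self, X: np.ndarray, y: np.ndarray, **model_kwargs: Any) -> None:
        # Midpoint between the two class means of the first feature.
        self.threshold = (X[y == 0, 0].mean() + X[y == 1, 0].mean()) / 2

    def predict(self, X: np.ndarray) -> np.ndarray:
        return (X[:, 0] > self.threshold).astype(int).reshape(-1, 1)

    def save(self, dirpath: str) -> None:
        os.makedirs(dirpath, exist_ok=True)
        with open(os.path.join(dirpath, "model.pkl"), "wb") as f:
            pickle.dump(float(self.threshold), f)

    @classmethod
    def load(cls, dirpath: str) -> "ThresholdModel":
        model = cls()
        with open(os.path.join(dirpath, "model.pkl"), "rb") as f:
            model.threshold = pickle.load(f)
        return model


def make_label_lf(label: int) -> Callable[[int], int]:
    # Vote `label` when the model predicted it, otherwise abstain.
    return lambda prediction: label if prediction == label else ABSTAIN


def run_pipeline(
    model: Any, X_dev: np.ndarray, y_dev: np.ndarray, X_all: np.ndarray, dirpath: str
) -> Tuple[np.ndarray, Dict[int, Callable[[int], int]]]:
    model.fit(X_dev, y_dev)                     # (a) train on the dev set
    model.save(dirpath)                         # (b) persist (local dir stands in for MinIO)
    predictions = model.predict(X_all).ravel()  # (c) predict over the entire dataset
    # (d) one labeling function per label the model produced
    lfs = {int(label): make_label_lf(int(label)) for label in np.unique(predictions)}
    return predictions, lfs


X_dev = np.array([[0.0], [1.0], [9.0], [10.0]])
y_dev = np.array([0, 0, 1, 1])
X_all = np.array([[0.0], [1.0], [9.0], [10.0], [4.0], [7.0]])

with tempfile.TemporaryDirectory() as tmp:
    predictions, lfs = run_pipeline(ThresholdModel(), X_dev, y_dev, X_all, tmp)

# Votes of the labeling function generated for label 1.
votes_for_1 = [lfs[1](int(p)) for p in predictions]
```

In this sketch each generated function votes only on rows where the model predicted its label and abstains elsewhere, which is one plausible reading of "labeling functions for all model output labels".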
Note: The provided model must be a class instance that has the following four methods:

from typing import Any

import numpy as np


class CustomModel:
    def fit(self, X: np.ndarray, y: np.ndarray, **model_kwargs: Any) -> None:
        # Fits the model on (X, y)
        pass

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Predicts y_hat (shape: (X.shape[0], 1)) from X
        pass

    def save(self, dirpath: str) -> None:
        # Saves necessary model information at dirpath
        pass

    @classmethod
    def load(cls, dirpath: str) -> Any:
        # Loads necessary model information from dirpath
        pass
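As a concrete illustration of that interface, here is a minimal sketch of a model that fills in all four methods. `CentroidModel` and the file name `centroid_model.npz` are hypothetical and not part of the SDK; the classifier itself (nearest class centroid) was picked only because it fits in a few lines of NumPy.

```python
import os
import tempfile
from typing import Any, Optional

import numpy as np


class CentroidModel:
    """Hypothetical nearest-centroid classifier exposing fit/predict/save/load."""

    def __init__(self) -> None:
        self.classes_: Optional[np.ndarray] = None
        self.centroids_: Optional[np.ndarray] = None

    def fit(self, X: np.ndarray, y: np.ndarray, **model_kwargs: Any) -> None:
        # Store one mean feature vector per class label.
        y = np.asarray(y).ravel()
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Distance from every row to every centroid, then pick the closest class.
        dists = np.linalg.norm(X[:, None, :] - self.centroids_[None, :, :], axis=2)
        # Return a column vector of shape (X.shape[0], 1), as the interface asks.
        return self.classes_[dists.argmin(axis=1)].reshape(-1, 1)

    def save(self, dirpath: str) -> None:
        # Write everything needed to rebuild the model into dirpath.
        os.makedirs(dirpath, exist_ok=True)
        np.savez(
            os.path.join(dirpath, "centroid_model.npz"),
            classes=self.classes_,
            centroids=self.centroids_,
        )

    @classmethod
    def load(cls, dirpath: str) -> "CentroidModel":
        # Rebuild a model from the arrays written by save().
        model = cls()
        with np.load(os.path.join(dirpath, "centroid_model.npz")) as data:
            model.classes_ = data["classes"]
            model.centroids_ = data["centroids"]
        return model


# Round trip: fit, save to a temporary directory, reload, and predict.
X = np.array([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.2, 4.9]])
y = np.array([0, 0, 1, 1])
model = CentroidModel()
model.fit(X, y)
with tempfile.TemporaryDirectory() as tmp:
    model.save(tmp)
    restored = CentroidModel.load(tmp)
preds = restored.predict(np.array([[0.05, 0.1], [4.9, 5.1]]))
```

An instance of a class shaped like this is what the `model` argument expects. The save and load round trip matters because the model is persisted after training, so everything `predict` needs has to survive it.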
Parameters
| Name | Type | Default | Info |
| --- | --- | --- | --- |
| node | int | | UID of the node. |
| model | Any | | The model to train. Must have methods fit, predict, save, and load. |
| model_name | str | | The name of the model, used to name the prediction column and associated generated LFs. |
| feature_fields | Union[str, List[str]] | | The names of the input columns to train the model on. |
| model_kwargs | Any | | Any other keyword arguments for model training. |
Raises
ValueError – If model does not have the desired methods.
ValueError – If the associated application is not a classification application.
Returns
The trained model.
Return type
Any
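Because a model that lacks any of fit, predict, save, or load is rejected with a ValueError, it can help to check the interface locally before making the call. The sketch below is a hypothetical pre-check, not the library's own validation: `REQUIRED_METHODS`, `missing_methods`, and `check_model_interface` are illustrative names, and the error message is invented.

```python
from typing import Any, List

# The four method names the documentation requires on the model.
REQUIRED_METHODS = ("fit", "predict", "save", "load")


def missing_methods(model: Any) -> List[str]:
    """Return the required method names that `model` lacks or that are not callable."""
    return [
        name for name in REQUIRED_METHODS
        if not callable(getattr(model, name, None))
    ]


def check_model_interface(model: Any) -> None:
    # Hypothetical local pre-check; the library's own validation may differ.
    missing = missing_methods(model)
    if missing:
        raise ValueError(f"model is missing required methods: {', '.join(missing)}")


class Incomplete:
    """Example class that only implements half of the interface."""

    def fit(self, X, y, **model_kwargs):
        pass

    def predict(self, X):
        pass


try:
    check_model_interface(Incomplete())
    error_message = ""
except ValueError as exc:
    error_message = str(exc)
```

Running a check like this first surfaces a missing `save` or `load` before any training work starts.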