snorkelflow.client.studio.add_model_based_lfs
snorkelflow.client.studio.add_model_based_lfs(node, model, model_name, feature_fields, **model_kwargs)
Performs the following operations in sequence: (a) trains the model on the dev set, (b) saves the model to MinIO at the specified path, (c) applies the model to predict over the entire dataset, and (d) generates labeling functions for all model output labels. Useful for generating model-based labeling functions.
Note: this method currently only supports classification end-models.
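Step (d) can be pictured with a small conceptual sketch. This is not Snorkel Flow's implementation: the helper name votes_per_label and the abstain value of -1 are assumptions made for illustration. The idea shown is one labeling function per predicted label, voting for that label where the model predicted it and abstaining elsewhere.

```python
import numpy as np

ABSTAIN = -1  # assumed abstain value for this sketch


def votes_per_label(predictions: np.ndarray, abstain: int = ABSTAIN) -> dict:
    """Map each predicted label to a vote vector (hypothetical helper)."""
    preds = np.asarray(predictions).ravel()
    return {
        int(label): np.where(preds == label, label, abstain)
        for label in np.unique(preds)
    }


# Model output of shape (n_samples, 1), as the predict method is expected to return
preds = np.array([[0], [1], [1], [0], [2]])
lf_votes = votes_per_label(preds)
```

Each entry of lf_votes plays the role of one generated labeling function's output over the dataset.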
Note: the provided model must be a class instance that has the following four methods:

    from typing import Any

    import numpy as np

    class CustomModel:
        def fit(self, X: np.ndarray, y: np.ndarray, **model_kwargs: Any) -> None:
            # Fits the model on (X, y)
            pass

        def predict(self, X: np.ndarray) -> np.ndarray:
            # Predicts y_hat (shape: (X.shape[0], 1)) from X
            pass

        def save(self, dirpath: str) -> None:
            # Saves necessary model information at dirpath
            pass

        @classmethod
        def load(cls, dirpath: str) -> Any:
            # Loads necessary model information from dirpath
            pass
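As an illustration of this interface, here is a minimal sketch of a concrete model. The class name NearestCentroidModel, the nearest-centroid technique, and the model.npz file name are all assumptions chosen for the example; any classifier exposing the same four methods should fit the contract described above.

```python
import os
import tempfile
from typing import Any

import numpy as np


class NearestCentroidModel:
    """Hypothetical classifier exposing fit, predict, save, and load."""

    def __init__(self) -> None:
        self.classes_ = None
        self.centroids_ = None

    def fit(self, X: np.ndarray, y: np.ndarray, **model_kwargs: Any) -> None:
        # Store one mean feature vector (centroid) per class
        X = np.asarray(X, dtype=float)
        y = np.asarray(y).ravel()
        self.classes_ = np.unique(y)
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Assign each row to the class of its closest centroid
        X = np.asarray(X, dtype=float)
        dists = np.linalg.norm(X[:, None, :] - self.centroids_[None, :, :], axis=2)
        # Return shape (X.shape[0], 1), as the interface requires
        return self.classes_[dists.argmin(axis=1)].reshape(-1, 1)

    def save(self, dirpath: str) -> None:
        # Persist the fitted arrays under dirpath
        os.makedirs(dirpath, exist_ok=True)
        np.savez(
            os.path.join(dirpath, "model.npz"),
            classes=self.classes_,
            centroids=self.centroids_,
        )

    @classmethod
    def load(cls, dirpath: str) -> "NearestCentroidModel":
        # Rebuild a model from the arrays written by save
        model = cls()
        with np.load(os.path.join(dirpath, "model.npz")) as data:
            model.classes_ = data["classes"]
            model.centroids_ = data["centroids"]
        return model


# Local round trip: fit, save, load, predict
X = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [0.9, 1.0]])
y = np.array([0, 0, 1, 1])
model = NearestCentroidModel()
model.fit(X, y)
with tempfile.TemporaryDirectory() as tmp:
    model.save(tmp)
    restored = NearestCentroidModel.load(tmp)
predictions = restored.predict(X)
```

Running the save and load round trip locally before handing the instance to add_model_based_lfs is a cheap way to confirm that a saved model reproduces its own predictions.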
Parameters
node (int) – UID of the node.
model (Any) – The model to train. Must have methods fit, predict, save, and load.
model_name (str) – The name of the model, used to name the prediction column and associated generated LFs.
feature_fields (Union[str, List[str]]) – The names of the input columns to train the model on.
model_kwargs (Any) – Any other keyword arguments for model training.
Raises
ValueError – If model does not have the desired methods.
ValueError – If the associated application is not a classification application.
Returns
The trained model.
Return type
Any
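Because a model lacking any of the four methods is rejected with a ValueError, it can be convenient to run the same kind of check locally before making the call. The sketch below is one way to do that. The names check_model_interface and train_and_add_lfs are hypothetical helpers, not part of the SDK, and the import path is taken from the fully qualified name at the top of this page.

```python
from typing import Any, List, Union

REQUIRED_METHODS = ("fit", "predict", "save", "load")


def check_model_interface(model: Any) -> None:
    """Raise ValueError if model lacks any required method (hypothetical helper)."""
    missing = [
        name for name in REQUIRED_METHODS
        if not callable(getattr(model, name, None))
    ]
    if missing:
        raise ValueError(f"model is missing required methods: {', '.join(missing)}")


def train_and_add_lfs(
    node: int,
    model: Any,
    model_name: str,
    feature_fields: Union[str, List[str]],
    **model_kwargs: Any,
) -> Any:
    """Hypothetical wrapper: validate locally, then call the SDK function."""
    check_model_interface(model)
    # Imported lazily so the local check works without the SDK installed
    from snorkelflow.client.studio import add_model_based_lfs

    return add_model_based_lfs(node, model, model_name, feature_fields, **model_kwargs)
```

A call might then look like train_and_add_lfs(node_uid, model, "my_model", ["feature_column"]), where node_uid, "my_model", and "feature_column" are placeholders for a real node UID, model name, and input column.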