Version: 0.93

snorkelflow.client.models.train_custom_model

snorkelflow.client.models.train_custom_model(node, custom_cls, feature_fields, model_description, name, training_set=None, filter_unlabeled=True, filter_uncertain_labels=True, load_ground_truth=True, train_on_dev_data=False, predict_on_dev_data=False, discretize_labels=True, max_runs=8, sampler_config=None, tune_threshold_on_valid=False, use_lf_labels=False, lf_uids_to_use=None, predict_for_all_train=False, apply_postprocessors=True, sync=True, scheduler=None)

Train a custom machine learning model.

Parameters

| Name | Type | Default | Info |
|---|---|---|---|
| node | int | | UID of the node to train the model for. |
| custom_cls | Any | | The custom model class. It must define train, predict, load, and save. |
| feature_fields | List[str] | | Names of the fields to use as features. |
| model_description | str | | Description of the model, shown on the Models page. |
| name | str | | Name of the model. |
| training_set | Optional[int] | None | UID of the training set to train on. If None, the latest training set is used. |
| filter_unlabeled | bool | True | Whether to remove data points not covered by any LF when training. |
| filter_uncertain_labels | bool | True | Whether to remove low-confidence data points when training. |
| load_ground_truth | bool | True | Whether to use ground truth labels when training the model. |
| train_on_dev_data | bool | False | Whether to include the dev set when training the model. |
| predict_on_dev_data | bool | False | Whether to predict on the dev set only. |
| discretize_labels | bool | True | Whether to round probabilistic training labels to the maximum-probability class. |
| max_runs | int | 8 | For hyperparameter search, the maximum number of model configurations to sample. |
| sampler_config | Optional[Dict[str, Any]] | None | A sampler configuration: a dictionary with the fields "strategy" (required), "params" (optional), and "class_counts" (optional). For details, see Sampler configs for data loading and model training. Only supported when discretize_labels is True. |
| tune_threshold_on_valid | bool | False | For F1-tuned extraction models, whether to tune the prediction threshold on the valid set. |
| use_lf_labels | bool | False | Whether to use LF labels as features. |
| lf_uids_to_use | Optional[List[int]] | None | If set and use_lf_labels is True, only these LFs are used for model training. |
| predict_for_all_train | bool | False | Whether to predict for all data points in the train set (True) or only for those that can be loaded into the dev set (False). |
| apply_postprocessors | bool | True | If True, postprocessors are applied. |
| sync | bool | True | Whether to poll the job status and block until the job completes. |
| scheduler | Optional[str] | None | Dask scheduler to use: threads, client, or group. |
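To make the label-handling flags and the sampler_config constraint concrete, here is a small NumPy sketch. It is not Snorkel Flow's implementation: the helper names prepare_training_labels and validate_sampler_config, the 0.6 confidence cut-off, and the order in which the steps run are all assumptions for illustration. Only the meaning of filter_unlabeled, filter_uncertain_labels, discretize_labels, and the documented sampler_config fields come from the table above.

```python
import numpy as np

# Hypothetical cut-off for "low-confidence" points; the value Snorkel Flow
# actually uses is not documented on this page.
MIN_CONFIDENCE = 0.6


def prepare_training_labels(probs, covered, *, filter_unlabeled=True,
                            filter_uncertain_labels=True,
                            discretize_labels=True,
                            min_confidence=MIN_CONFIDENCE):
    """Illustrative only: apply the documented label-handling flags.

    probs:   (n, k) probabilistic labels, one row per data point.
    covered: (n,) booleans, True if at least one LF labeled the point.
    Returns (kept_row_indices, labels).
    """
    probs = np.asarray(probs, dtype=float)
    keep = np.ones(len(probs), dtype=bool)
    if filter_unlabeled:
        # Drop points that no LF covers.
        keep &= np.asarray(covered, dtype=bool)
    if filter_uncertain_labels:
        # Drop points whose most likely class falls below the cut-off.
        keep &= probs.max(axis=1) >= min_confidence
    idx = np.flatnonzero(keep)
    if discretize_labels:
        # Round each kept row to its maximum-probability class.
        return idx, probs[idx].argmax(axis=1)
    # Otherwise keep the probabilistic (soft) labels.
    return idx, probs[idx]


def validate_sampler_config(sampler_config, discretize_labels):
    """Illustrative check of the documented sampler_config constraints."""
    if sampler_config is None:
        return
    if not discretize_labels:
        raise ValueError("sampler_config requires discretize_labels=True")
    if "strategy" not in sampler_config:
        raise ValueError("sampler_config must include a 'strategy' field")
```

With probs = [[0.9, 0.1], [0.55, 0.45], [0.2, 0.8], [0.5, 0.5]] and covered = [True, True, True, False], the defaults keep rows 0 and 2 (row 1 is below the assumed cut-off, row 3 is uncovered) and return the hard labels [0, 1].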

Returns

A dictionary of the form {"job_id": "rq-XX-YYY", "model_uid": ZZ}. The value of model_uid is None if sync=False.

Return type

Dict[str, Any]
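Because model_uid is None when the job is submitted with sync=False, calling code should not assume it is set. A minimal sketch, assuming only the dictionary shape documented above; require_model_uid is a hypothetical helper, not part of the SDK.

```python
from typing import Any, Dict


def require_model_uid(result: Dict[str, Any]) -> int:
    """Return model_uid from a train_custom_model result dict.

    Hypothetical helper: raises if model_uid is None, which per the docs
    happens when the job was started with sync=False.
    """
    uid = result.get("model_uid")
    if uid is None:
        raise RuntimeError(
            f"Job {result['job_id']} has no model_uid yet; "
            "it was likely submitted with sync=False."
        )
    return uid
```

With sync=True the helper simply returns the UID; with sync=False the error message carries the job_id so the job can be tracked separately.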

Examples

>>> import snorkelflow.client as sf
>>> from snorkelflow.models.cls_model import TrainedClassificationModelV2
>>> class CustomModel(TrainedClassificationModelV2):
...     # Define train, predict, load, and save here.
...     ...
...
>>> sf.train_custom_model(
...     node=node_uid,
...     custom_cls=CustomModel,
...     feature_fields=["feature_1", "feature_2"],
...     model_description="trainable model",
...     name="CustomModel",
...     training_set=training_set_uid,
... )
{'job_id': <job_id>, 'model_uid': <model_uid>}
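The example elides the body of CustomModel. The toy class below illustrates the documented requirement that custom_cls define train, predict, load, and save. It deliberately does not subclass TrainedClassificationModelV2, and every method signature here (plus the missing_methods helper and MajorityClassModel itself) is hypothetical; consult the base class for the signatures Snorkel Flow actually calls.

```python
import json
from collections import Counter

REQUIRED_METHODS = ("train", "predict", "load", "save")


def missing_methods(cls):
    """Names from REQUIRED_METHODS that cls does not define as callables."""
    return [name for name in REQUIRED_METHODS
            if not callable(getattr(cls, name, None))]


class MajorityClassModel:
    """Toy model that always predicts the most common training label.

    Hypothetical signatures; the real base class may expect different ones.
    """

    def __init__(self):
        self.majority = None

    def train(self, features, labels):
        # Remember the most frequent label; features are ignored.
        self.majority = Counter(labels).most_common(1)[0][0]
        return self

    def predict(self, features):
        # One prediction per input row.
        return [self.majority for _ in features]

    def save(self, path):
        with open(path, "w") as f:
            json.dump({"majority": self.majority}, f)

    @classmethod
    def load(cls, path):
        model = cls()
        with open(path) as f:
            model.majority = json.load(f)["majority"]
        return model
```

A check like missing_methods(CustomModel) == [] is a cheap way to catch an incomplete class locally before submitting a training job.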