Version: 0.91

snorkelflow.client.models.train_model

snorkelflow.client.models.train_model(node, feature_fields, model_description, model_config, name=None, training_set=None, filter_unlabeled=True, filter_uncertain_labels=True, load_ground_truth=True, train_on_dev_data=False, predict_on_dev_data=False, discretize_labels=True, max_runs=8, sampler_config=None, tune_threshold_on_valid=False, use_lf_labels=False, lf_uids_to_use=None, predict_for_all_train=False, apply_postprocessors=True, sync=True, scheduler=None)

Train a machine learning model.

Parameters

| Name | Type | Default | Info |
| --- | --- | --- | --- |
| node | int | required | UID of the node to train a model for. |
| feature_fields | List[str] | required | List of field names to use as features. |
| model_description | str | required | Description of the model shown on the Models page. |
| model_config | Dict[str, Any] | required | JSON configuration describing the model and its hyperparameters. |
| name | Optional[str] | None | The name of the model. |
| training_set | Optional[int] | None | Training set to train on. If None, the latest training set is used. |
| filter_unlabeled | bool | True | If True, remove data points not covered by any LF before training. |
| filter_uncertain_labels | bool | True | If True, remove low-confidence data points before training. |
| load_ground_truth | bool | True | If True, use ground truth labels when training the model. |
| train_on_dev_data | bool | False | If True, include the dev set in the training data. |
| predict_on_dev_data | bool | False | If True, predict on the dev set only. |
| discretize_labels | bool | True | If True, round probabilistic training labels to the maximum-probability class. |
| max_runs | int | 8 | Maximum number of model configurations to sample during hyperparameter search. |
| sampler_config | Optional[Dict[str, Any]] | None | A dictionary with the fields "strategy" (required), "params" (optional), and "class_counts" (optional) that defines a sampler configuration. For details, see Sampler configs for data loading and model training. Supported only if discretize_labels is True. |
| tune_threshold_on_valid | bool | False | For F1-tuned extraction models: if True, tune the prediction threshold on the valid set. |
| use_lf_labels | bool | False | If True, use LF labels as features. |
| lf_uids_to_use | Optional[List[int]] | None | If set and use_lf_labels is True, use only these LFs for model training. |
| predict_for_all_train | bool | False | If True, predict for all data points in the train set; if False, predict only for those that can be loaded into the dev set. |
| apply_postprocessors | bool | True | If True, apply postprocessors. |
| sync | bool | True | If True, poll the job status and block until the job completes. |
| scheduler | Optional[str] | None | Dask scheduler to use: threads, client, or group. |
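The table implies a few dependencies between arguments: sampler_config needs a "strategy" field and is supported only when discretize_labels is True, and lf_uids_to_use only has an effect when use_lf_labels is True. The sketch below checks those documented constraints before a call. It is a hypothetical helper (the name check_train_model_kwargs is an assumption), not part of the snorkelflow SDK, and it covers only these four arguments.

```python
from typing import Any, Dict, List, Optional


def check_train_model_kwargs(
    discretize_labels: bool = True,
    sampler_config: Optional[Dict[str, Any]] = None,
    use_lf_labels: bool = False,
    lf_uids_to_use: Optional[List[int]] = None,
) -> List[str]:
    """Return a list of problems with a subset of train_model arguments.

    Hypothetical helper, not part of the snorkelflow SDK. It only encodes
    constraints stated in the parameter table; an empty list means no
    documented constraint is violated.
    """
    problems: List[str] = []
    if sampler_config is not None:
        # "strategy" is the only required field of a sampler config.
        if "strategy" not in sampler_config:
            problems.append('sampler_config is missing required field "strategy"')
        # Sampler configs are only supported with discretized labels.
        if not discretize_labels:
            problems.append("sampler_config requires discretize_labels=True")
    # lf_uids_to_use has no effect unless LF labels are used as features.
    if lf_uids_to_use is not None and not use_lf_labels:
        problems.append("lf_uids_to_use has no effect unless use_lf_labels=True")
    return problems
```

For example, check_train_model_kwargs(discretize_labels=False, sampler_config={"strategy": "<strategy-name>"}) reports that the sampler config needs discretized labels. The strategy value here is a placeholder; see Sampler configs for data loading and model training for the real options.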

Returns

A dictionary of the form {"job_id": "rq-XX-YYY", "model_uid": ZZ}. If sync=False, the value of model_uid is None.
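Because model_uid is None when sync=False, code that consumes the result should not assume a model UID is present. A minimal sketch, assuming only the documented dictionary shape; require_model_uid is a hypothetical helper, not part of the snorkelflow SDK.

```python
from typing import Any, Dict


def require_model_uid(result: Dict[str, Any]) -> int:
    """Return model_uid from a train_model result, or raise if it is None.

    Hypothetical helper, not part of the snorkelflow SDK. With sync=False the
    documented result carries model_uid=None, so only job_id is usable.
    """
    model_uid = result.get("model_uid")
    if model_uid is None:
        raise RuntimeError(
            f"no model_uid for job {result['job_id']}; "
            "was train_model called with sync=False?"
        )
    return model_uid
```

With a synchronous call the helper simply returns the UID; with an asynchronous call it fails loudly and names the job ID.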

Return type

Dict[str, Any]

Examples

>>> from snorkelflow.models.model_configs import SKLEARN_LOGISTIC_REGRESSION_CONFIG
>>> import snorkelflow.client as sf
>>> sf.train_model(
...     node=node,  # UID of the node to train on
...     feature_fields=["text"],
...     model_description="Logistic Regression",
...     model_config=SKLEARN_LOGISTIC_REGRESSION_CONFIG.dict(),
... )
{'job_id': <job_id>, 'model_uid': <model_uid>}