snorkelflow.client.models.train_model
snorkelflow.client.models.train_model(node, feature_fields, model_description, model_config, name=None, training_set=None, filter_unlabeled=True, filter_uncertain_labels=True, load_ground_truth=True, train_on_dev_data=False, predict_on_dev_data=False, discretize_labels=True, max_runs=8, sampler_config=None, tune_threshold_on_valid=False, use_lf_labels=False, lf_uids_to_use=None, predict_for_all_train=False, apply_postprocessors=True, sync=True, scheduler=None)
Train a machine learning model.
Parameters
node (int): UID of the node for which we're training a model.
feature_fields (List[str]): List of field names to use as features.
model_description (str): Description of the model shown on the Models page.
model_config (Dict[str, Any]): JSON configuration describing the model and its hyperparameters.
name (Optional[str], default None): Name of the model.
training_set (Optional[int], default None): Training set to train on. If None, the latest training set is used.
filter_unlabeled (bool, default True): Whether to remove data points not covered by any LF when training.
filter_uncertain_labels (bool, default True): Whether to remove low-confidence data points when training.
load_ground_truth (bool, default True): Whether to use ground truth labels when training the model.
train_on_dev_data (bool, default False): Whether to include the dev set when training the model.
predict_on_dev_data (bool, default False): Whether to predict on the dev set only.
discretize_labels (bool, default True): Whether to round probabilistic training labels to the maximum-probability class.
max_runs (int, default 8): For hyperparameter search, the maximum number of model configurations to sample.
sampler_config (Optional[Dict[str, Any]], default None): Sampler configuration, given as a dictionary with the fields "strategy" (required), "params" (optional), and "class_counts" (optional). For details, see Sampler configs for data loading and model training. Only supported if discretize_labels is True.
tune_threshold_on_valid (bool, default False): For F1-tuned extraction models, whether to tune the prediction threshold on the valid set.
use_lf_labels (bool, default False): Whether to use LF labels as features.
lf_uids_to_use (Optional[List[int]], default None): If set and use_lf_labels is True, only these LFs are used for model training.
predict_for_all_train (bool, default False): Whether to predict for all data points in the train set (True) or only for those that can be loaded into the dev set (False).
apply_postprocessors (bool, default True): If True, postprocessors are applied.
sync (bool, default True): Whether to poll the job status and block until the job completes.
scheduler (Optional[str], default None): Dask scheduler to use (threads, client, or group).
Returns
A dictionary of the form {"job_id": "rq-XX-YYY", "model_uid": ZZ}. The value of model_uid is None if sync=False.
Return type
Dict[str, Any]
Examples
>>> from snorkelflow.models.model_configs import SKLEARN_LOGISTIC_REGRESSION_CONFIG
>>> import snorkelflow.client as sf
>>> node = 1  # hypothetical: replace with the UID of your model node
>>> sf.train_model(
...     node=node,
...     feature_fields=["text"],
...     model_description='Logistic Regression',
...     model_config=SKLEARN_LOGISTIC_REGRESSION_CONFIG.dict(),
... )
{'job_id': <job_id>, 'model_uid': <model_uid>}
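The filter_unlabeled, filter_uncertain_labels, and discretize_labels parameters all act on the probabilistic training labels before the model sees them. The following is a minimal, hypothetical sketch of that preprocessing in plain Python. It is not Snorkel Flow's implementation. It assumes that each data point carries either None (no LF coverage) or a list of class probabilities. The preprocess_training_labels name and the 0.6 confidence cutoff are illustrative assumptions.

```python
from typing import Any, List, Optional, Sequence, Tuple

# Hypothetical confidence cutoff; the value Snorkel Flow uses is not documented here.
CONFIDENCE_THRESHOLD = 0.6


def preprocess_training_labels(
    probs: Sequence[Optional[Sequence[float]]],
    filter_unlabeled: bool = True,
    filter_uncertain_labels: bool = True,
    discretize_labels: bool = True,
    threshold: float = CONFIDENCE_THRESHOLD,
) -> List[Tuple[int, Any]]:
    """Return (index, training_label) pairs for the points kept for training.

    probs[i] is None when no LF covers point i; otherwise it is a list of
    class probabilities (assumed to sum to 1).
    """
    kept: List[Tuple[int, Any]] = []
    for i, p in enumerate(probs):
        if p is None:
            if not filter_unlabeled:
                # Placeholder: how uncovered points are labeled is not specified.
                kept.append((i, None))
            continue
        if filter_uncertain_labels and max(p) < threshold:
            continue  # drop low-confidence points
        if discretize_labels:
            # Round to the maximum-probability class (argmax).
            kept.append((i, max(range(len(p)), key=lambda k: p[k])))
        else:
            kept.append((i, list(p)))  # keep the soft label unchanged
    return kept


probs = [[0.9, 0.1], None, [0.55, 0.45], [0.2, 0.8]]
print(preprocess_training_labels(probs))  # [(0, 0), (3, 1)]
```

With the defaults, the uncovered point (index 1) and the low-confidence point (index 2) are dropped, and the remaining soft labels are rounded to class indices. Passing discretize_labels=False keeps the probability lists instead. This is also why a sampler that needs per-class counts is only supported on discretized labels.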
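The parameter list states three constraints on sampler_config:

- "strategy" is required.
- "params" and "class_counts" are optional.
- The option is only supported when discretize_labels is True.

Below is a hypothetical client-side helper that checks those constraints before a training job is submitted. The helper, check_sampler_config, is not part of snorkelflow. It assumes that no other keys are accepted, and it does not validate the strategy value itself.

```python
from typing import Any, Dict, Optional

# Documented fields: "strategy" is required; "params" and "class_counts" are optional.
REQUIRED_KEYS = {"strategy"}
ALLOWED_KEYS = {"strategy", "params", "class_counts"}


def check_sampler_config(
    sampler_config: Optional[Dict[str, Any]], discretize_labels: bool
) -> None:
    """Hypothetical pre-submit check; raises ValueError on an invalid combination."""
    if sampler_config is None:
        return  # no sampler requested, nothing to check
    keys = set(sampler_config)
    missing = REQUIRED_KEYS - keys
    if missing:
        raise ValueError(f"sampler_config is missing required keys: {sorted(missing)}")
    unknown = keys - ALLOWED_KEYS  # assumption: any other key is rejected
    if unknown:
        raise ValueError(f"sampler_config has unexpected keys: {sorted(unknown)}")
    if not discretize_labels:
        raise ValueError("sampler_config is only supported when discretize_labels=True")


# Passes silently: "strategy" is present and labels are discretized.
check_sampler_config({"strategy": "<strategy-name>"}, discretize_labels=True)
```

The same call with discretize_labels=False raises ValueError. This mirrors the documented restriction, so the mistake is caught locally rather than after a job has been queued.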