snorkelflow.client.models.train_model
snorkelflow.client.models.train_model(node, feature_fields, model_description, model_config, name=None, training_set=None, filter_unlabeled=True, filter_uncertain_labels=True, load_ground_truth=True, train_on_dev_data=False, predict_on_dev_data=False, discretize_labels=True, max_runs=8, sampler_config=None, tune_threshold_on_valid=False, use_lf_labels=False, lf_uids_to_use=None, predict_for_all_train=False, apply_postprocessors=True, sync=True, scheduler=None)
Train a machine learning model.
Parameters
node (int): UID of the node for which we’re training a model.
feature_fields (List[str]): List of field names to use as features.
model_description (str): Description of the model shown on the Models page.
model_config (Dict[str, Any]): JSON configuration describing the model and hyperparameters.
name (Optional[str], default None): The name of the model.
training_set (Optional[int], default None): Training set to train on. If None, defaults to the latest training set.
filter_unlabeled (bool, default True): If True, remove data points not covered by LFs when training.
filter_uncertain_labels (bool, default True): If True, remove low-confidence data points when training.
load_ground_truth (bool, default True): If True, use ground truth labels when training the model.
train_on_dev_data (bool, default False): If True, include the dev set when training the model.
predict_on_dev_data (bool, default False): If True, predict on the dev set only.
discretize_labels (bool, default True): If True, round probabilistic training labels to the maximum-probability class.
max_runs (int, default 8): For hyperparameter search, the maximum number of model configurations to sample.
sampler_config (Optional[Dict[str, Any]], default None): A dictionary with the fields "strategy" (required), "params" (optional), and "class_counts" (optional) describing a sampler configuration. For details, see sampler-config. This option is only supported if discretize_labels is True.
tune_threshold_on_valid (bool, default False): For F1-tuned extraction models, if True, tune the prediction threshold on the valid set.
use_lf_labels (bool, default False): If True, use LF labels as features.
lf_uids_to_use (Optional[List[int]], default None): If set and use_lf_labels is True, use only these LFs for model training.
predict_for_all_train (bool, default False): If True, predict for all data points in the train set; if False, predict only for those that can be loaded into the dev set.
apply_postprocessors (bool, default True): If True, postprocessors are applied.
sync (bool, default True): If True, poll the job status and block until the job completes.
scheduler (Optional[str], default None): Dask scheduler to use (threads, client, or group).
Returns
A dictionary of the form {"job_id": "rq-XX-YYY", "model_uid": ZZ}. The value of model_uid is None if sync=False.
Return type
Dict[str, Any]
Examples
>>> from snorkelflow.models.model_configs import SKLEARN_LOGISTIC_REGRESSION_CONFIG
>>> import snorkelflow.client as sf
>>> node = 1  # hypothetical value: replace with the UID of your node
>>> sf.train_model(
...     node=node,
...     feature_fields=["text"],
...     model_description="Logistic Regression",
...     model_config=SKLEARN_LOGISTIC_REGRESSION_CONFIG.dict(),
... )
{'job_id': <job_id>, 'model_uid': <model_uid>}
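The parameter descriptions above state a few interactions between arguments: sampler_config is only supported when discretize_labels is True, sampler_config requires a "strategy" field, and lf_uids_to_use only takes effect when use_lf_labels is True. The following is a minimal sketch of a client-side pre-flight check for those documented rules. The helper build_train_model_kwargs is hypothetical and is not part of snorkelflow.client; it does not import Snorkel Flow, and the server may enforce additional rules it does not know about. Treating lf_uids_to_use without use_lf_labels as an error (rather than silently ignoring it) is a deliberate strictness choice of this sketch.

```python
from typing import Any, Dict, List, Optional


def build_train_model_kwargs(
    node: int,
    feature_fields: List[str],
    model_description: str,
    model_config: Dict[str, Any],
    discretize_labels: bool = True,
    sampler_config: Optional[Dict[str, Any]] = None,
    use_lf_labels: bool = False,
    lf_uids_to_use: Optional[List[int]] = None,
    **other_options: Any,
) -> Dict[str, Any]:
    """Hypothetical helper: check documented constraints, then return kwargs for sf.train_model."""
    if not feature_fields:
        raise ValueError("feature_fields must name at least one field")
    if sampler_config is not None:
        # Documented: sampler_config is only supported if discretize_labels is True.
        if not discretize_labels:
            raise ValueError("sampler_config requires discretize_labels=True")
        # Documented: "strategy" is required ("params" and "class_counts" are optional).
        if "strategy" not in sampler_config:
            raise ValueError('sampler_config must contain a "strategy" field')
    if lf_uids_to_use is not None and not use_lf_labels:
        # Documented: lf_uids_to_use only applies when use_lf_labels is True.
        # Raising here is a choice of this sketch, not documented behavior.
        raise ValueError("lf_uids_to_use has no effect unless use_lf_labels=True")
    # Any other documented option (max_runs, sync, ...) is passed through unchanged.
    return dict(
        node=node,
        feature_fields=feature_fields,
        model_description=model_description,
        model_config=model_config,
        discretize_labels=discretize_labels,
        sampler_config=sampler_config,
        use_lf_labels=use_lf_labels,
        lf_uids_to_use=lf_uids_to_use,
        **other_options,
    )
```

With the names from the example above, a call could then look like sf.train_model(**build_train_model_kwargs(node=node, feature_fields=["text"], model_description="Logistic Regression", model_config=SKLEARN_LOGISTIC_REGRESSION_CONFIG.dict(), max_runs=4)).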
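With discretize_labels=True, probabilistic training labels are rounded to the maximum-probability class. The sketch below illustrates what that rounding means on plain Python lists. It is an illustration of the concept only, not Snorkel Flow's implementation; in particular, tie-breaking is not documented above, and this sketch assumes ties resolve to the lowest class index.

```python
from typing import List, Sequence


def discretize(prob_labels: Sequence[Sequence[float]]) -> List[int]:
    """Map each row of class probabilities to the index of its most probable class."""
    hard_labels = []
    for row in prob_labels:
        if not row:
            raise ValueError("each row needs at least one class probability")
        # max() returns the first maximal index, so ties go to the lowest class index
        # (an assumption of this sketch).
        hard_labels.append(max(range(len(row)), key=lambda k: row[k]))
    return hard_labels
```

For example, discretize([[0.2, 0.8], [0.6, 0.4], [0.5, 0.5]]) returns [1, 0, 0].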