Version: 0.93

snorkelflow.client.models.train_custom_model

snorkelflow.client.models.train_custom_model(node, custom_cls, feature_fields, model_description, name, training_set=None, filter_unlabeled=True, filter_uncertain_labels=True, load_ground_truth=True, train_on_dev_data=False, predict_on_dev_data=False, discretize_labels=True, max_runs=8, sampler_config=None, tune_threshold_on_valid=False, use_lf_labels=False, lf_uids_to_use=None, predict_for_all_train=False, apply_postprocessors=True, sync=True, scheduler=None)

Train a custom machine learning model.

Parameters:
  • node (int) – UID of the node for which we’re training a model

  • custom_cls (Any) – Class of the custom model, with train, predict, load, and save methods defined.

  • feature_fields (List[str]) – List of field names to use as features

  • model_description (str) – Description of the model shown on the Models page

  • name (str) – The name of the model

  • training_set (Optional[int], default: None) – UID of the training set to train on. If None, the latest training set is used.

  • filter_unlabeled (bool, default: True) – Remove data points not covered by LFs when training?

  • filter_uncertain_labels (bool, default: True) – Remove low-confidence data points when training?

  • load_ground_truth (bool, default: True) – Use ground truth labels when training model?

  • train_on_dev_data (bool, default: False) – Include the dev set when training model?

  • predict_on_dev_data (bool, default: False) – Predict on dev set only?

  • discretize_labels (bool, default: True) – Round probabilistic training labels to maximum probability class?

  • max_runs (int, default: 8) – For hyperparameter search, maximum number of model configurations to sample

  • sampler_config (Optional[Dict[str, Any]], default: None) – A dictionary with fields “strategy” (required), “params” (optional), and “class_counts” (optional) representing a sampler configuration. For details, see Sampler configs for data loading and model training. This option is only supported if discretize_labels is True.

  • tune_threshold_on_valid (bool, default: False) – For F1-tuned extraction models, tune the prediction threshold on valid set?

  • use_lf_labels (bool, default: False) – Use LF labels as features?

  • lf_uids_to_use (Optional[List[int]], default: None) – If set, and use_lf_labels is True, use only these LFs for model training.

  • predict_for_all_train (bool, default: False) – Predict for all data points in the train set (True), or just ones that can be loaded into the dev set (False)?

  • apply_postprocessors (bool, default: True) – Apply postprocessors to the model predictions?

  • sync (bool, default: True) – Poll job status and block until complete?

  • scheduler (Optional[str], default: None) – Dask scheduler (threads, client, or group) to use.
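
The constraints stated for sampler_config above ("strategy" is required, and the option is only supported when discretize_labels is True) can be checked locally before a job is submitted. The following is a minimal sketch, not part of the snorkelflow client: check_sampler_config is a hypothetical helper, the strategy value in the usage lines is a placeholder rather than a known strategy name, and rejecting fields other than the three documented ones is an assumption.

```python
from typing import Any, Dict, Optional

# The three fields named in the parameter description above.
DOCUMENTED_FIELDS = {"strategy", "params", "class_counts"}


def check_sampler_config(
    sampler_config: Optional[Dict[str, Any]],
    discretize_labels: bool = True,
) -> None:
    """Hypothetical pre-flight check; raises ValueError on a documented conflict."""
    if sampler_config is None:
        return  # No sampler requested, nothing to check.
    if "strategy" not in sampler_config:
        raise ValueError('sampler_config requires a "strategy" field')
    # Assumption: any field outside the documented three is treated as a typo.
    unknown = set(sampler_config) - DOCUMENTED_FIELDS
    if unknown:
        raise ValueError(f"unexpected sampler_config fields: {sorted(unknown)}")
    if not discretize_labels:
        raise ValueError(
            "sampler_config is only supported if discretize_labels is True"
        )


# Placeholder strategy name; consult the sampler config docs for real values.
check_sampler_config({"strategy": "placeholder-strategy", "params": {}})
```

Running the check on the client side only surfaces the conflicts that this page documents; the server may apply further validation.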

Returns:

A dictionary that looks like {"job_id": "rq-XX-YYY", "model_uid": ZZ}. The value for model_uid will be None if sync=False.

Return type:

Dict[str, Any]
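
Because model_uid is None when the call is made with sync=False, code that consumes the result may want to handle both cases explicitly. A minimal sketch, assuming only the dictionary shape described above; model_uid_from_result is a hypothetical helper and the sample dictionaries are hand-written stand-ins for real results.

```python
from typing import Any, Dict, Optional


def model_uid_from_result(result: Dict[str, Any]) -> Optional[Any]:
    """Return the model UID, or None if the job was submitted without blocking."""
    job_id = result["job_id"]
    model_uid = result.get("model_uid")
    if model_uid is None:
        # sync=False: only the job ID is known at this point.
        print(f"Job {job_id} submitted; model UID is not available yet")
    return model_uid


# Hand-written stand-ins for the two documented result shapes.
finished = {"job_id": "rq-XX-YYY", "model_uid": 12}
pending = {"job_id": "rq-XX-YYY", "model_uid": None}

finished_uid = model_uid_from_result(finished)
pending_uid = model_uid_from_result(pending)
```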

Examples

>>> from snorkelflow.models.cls_model import TrainedClassificationModelV2
>>> class CustomModel(TrainedClassificationModelV2):
...     ...
...
>>> sf.train_custom_model(
...     node=node_uid,
...     custom_cls=CustomModel,
...     feature_fields=["feature_1", "feature_2"],
...     model_description="trainable model",
...     name="CustomModel",
...     training_set=training_set_uid,
... )
{'job_id': <job_id>, 'model_uid': <model_uid>}
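
Since custom_cls is expected to have train, predict, load, and save defined, a quick local check can catch a missing method before a training job is submitted. This is a sketch under that assumption only: missing_methods and the two toy classes are hypothetical, and the check says nothing about the method signatures the platform expects.

```python
from typing import List

# Method names listed in the custom_cls parameter description.
REQUIRED_METHODS = ("train", "predict", "load", "save")


def missing_methods(custom_cls: type) -> List[str]:
    """Return the required method names that custom_cls does not define as callables."""
    return [
        name
        for name in REQUIRED_METHODS
        if not callable(getattr(custom_cls, name, None))
    ]


class CompleteToyModel:
    """Toy stand-in with all four methods; bodies are intentionally empty."""

    def train(self, *args, **kwargs): ...
    def predict(self, *args, **kwargs): ...
    def load(self, *args, **kwargs): ...
    def save(self, *args, **kwargs): ...


class IncompleteToyModel:
    """Toy stand-in that lacks load and save."""

    def train(self, *args, **kwargs): ...
    def predict(self, *args, **kwargs): ...


complete_gaps = missing_methods(CompleteToyModel)
incomplete_gaps = missing_methods(IncompleteToyModel)
```

A class that subclasses a base model may inherit some of these methods; getattr follows inheritance, so inherited methods count as defined.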