snorkelflow.client.models.train_custom_model
- snorkelflow.client.models.train_custom_model(node, custom_cls, feature_fields, model_description, name, training_set=None, filter_unlabeled=True, filter_uncertain_labels=True, load_ground_truth=True, train_on_dev_data=False, predict_on_dev_data=False, discretize_labels=True, max_runs=8, sampler_config=None, tune_threshold_on_valid=False, use_lf_labels=False, lf_uids_to_use=None, predict_for_all_train=False, apply_postprocessors=True, sync=True, scheduler=None)
Train a custom machine learning model.
- Parameters:
node (int) – UID of the node for which we’re training a model
custom_cls (Any) – Class of the custom defined model with train, predict, load, and save defined.
feature_fields (List[str]) – List of field names to use as features
model_description (str) – Description of the model shown on the Models page
name (str) – The name of the model
training_set (Optional[int], default: None) – Training set to train on. If None, defaults to latest.
filter_unlabeled (bool, default: True) – Remove data points not covered by LFs when training?
filter_uncertain_labels (bool, default: True) – Remove low-confidence data points when training?
load_ground_truth (bool, default: True) – Use ground truth labels when training model?
train_on_dev_data (bool, default: False) – Include the dev set when training model?
predict_on_dev_data (bool, default: False) – Predict on dev set only?
discretize_labels (bool, default: True) – Round probabilistic training labels to maximum probability class?
max_runs (int, default: 8) – For hyperparameter search, maximum number of model configurations to sample
sampler_config (Optional[Dict[str, Any]], default: None) – A dictionary with fields “strategy” (required), “params” (optional), and “class_counts” (optional) representing a sampler configuration. For details, see Sampler configs for data loading and model training. This option is only supported if discretize_labels is True.
tune_threshold_on_valid (bool, default: False) – For F1-tuned extraction models, tune the prediction threshold on valid set?
use_lf_labels (bool, default: False) – Use LF labels as features?
lf_uids_to_use (Optional[List[int]], default: None) – If set, and use_lf_labels is True, use only these LFs for model training.
predict_for_all_train (bool, default: False) – Predict for all data points in the train set (True), or just ones that can be loaded into the dev set (False)?
apply_postprocessors (bool, default: True) – If True, post processors are applied.
sync (bool, default: True) – Poll job status and block until complete?
scheduler (Optional[str], default: None) – Dask scheduler (threads, client, or group) to use.
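The documented shape of sampler_config (a required "strategy" field, optional "params" and "class_counts" fields, and the dependency on discretize_labels) can be checked on the client side before submitting a job. The sketch below is a hypothetical helper, not part of snorkelflow; the function name, the error messages, and the choice to reject any other field are assumptions made for illustration. Valid values for "strategy" are not listed here, so the helper does not check them.

```python
from typing import Any, Dict, Optional

# Field names taken from the sampler_config description above.
ALLOWED_FIELDS = {"strategy", "params", "class_counts"}


def check_sampler_config(
    sampler_config: Optional[Dict[str, Any]],
    discretize_labels: bool = True,
) -> None:
    """Hypothetical pre-check: raise ValueError if the config does not match the documented shape."""
    if sampler_config is None:
        return  # None is the default, so there is nothing to check
    if not discretize_labels:
        raise ValueError("sampler_config is only supported if discretize_labels is True")
    if "strategy" not in sampler_config:
        raise ValueError('sampler_config requires a "strategy" field')
    # Assumption: fields other than the three documented ones are treated as mistakes.
    unknown = set(sampler_config) - ALLOWED_FIELDS
    if unknown:
        raise ValueError(f"unexpected sampler_config fields: {sorted(unknown)}")
```

Calling check_sampler_config(config, discretize_labels=...) with the same values that will be passed to train_custom_model surfaces a malformed dictionary before a job is queued.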
- Returns:
A dictionary of the form {"job_id": "rq-XX-YYY", "model_uid": ZZ}. The value of model_uid is None if sync=False.
- Return type:
Dict[str, Any]
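Because model_uid is None when sync=False, code that consumes the return value needs to handle both cases. The following is a minimal sketch that works on plain dictionaries of the documented shape; the helper name and the message wording are hypothetical and not part of snorkelflow.

```python
from typing import Any, Dict


def summarize_training_result(result: Dict[str, Any]) -> str:
    """Hypothetical helper: describe a train_custom_model return value."""
    job_id = result["job_id"]
    model_uid = result.get("model_uid")
    if model_uid is None:
        # Documented behavior: model_uid is None when the call was made with sync=False.
        return f"job {job_id} submitted; model_uid not available"
    return f"job {job_id} complete; model_uid={model_uid}"
```

For example, summarize_training_result({"job_id": "rq-XX-YYY", "model_uid": None}) takes the first branch, which is the case a non-blocking caller has to plan for.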
Examples
>>> from snorkelflow.models.cls_model import TrainedClassificationModelV2
>>> class CustomModel(TrainedClassificationModelV2):
...     ...
...
>>> sf.train_custom_model(
...     node=node_uid,
...     custom_cls=CustomModel,
...     feature_fields=["feature_1", "feature_2"],
...     model_description="trainable model",
...     name="CustomModel",
...     training_set=training_set_uid,
... )
{'job_id': <job_id>, 'model_uid': <model_uid>}
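The custom_cls parameter is described as a class with train, predict, load, and save defined. A quick local check for those four names can catch an incomplete class before a training job is submitted. The helper below is a hypothetical sketch, not part of snorkelflow; it only confirms that each attribute exists and is callable, and it makes no assumption about the method signatures the platform expects.

```python
from typing import Any, List

# Method names taken from the custom_cls parameter description.
REQUIRED_METHODS = ("train", "predict", "load", "save")


def missing_methods(custom_cls: Any) -> List[str]:
    """Hypothetical helper: list required methods that custom_cls does not define as callables."""
    return [
        name
        for name in REQUIRED_METHODS
        if not callable(getattr(custom_cls, name, None))
    ]
```

An empty list from missing_methods(CustomModel) means all four names resolve to callables, either defined on the class itself or inherited from a base class.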