snorkelflow.client.models.train_custom_model
snorkelflow.client.models.train_custom_model(node, custom_cls, feature_fields, model_description, name, training_set=None, filter_unlabeled=True, filter_uncertain_labels=True, load_ground_truth=True, train_on_dev_data=False, predict_on_dev_data=False, discretize_labels=True, max_runs=8, sampler_config=None, tune_threshold_on_valid=False, use_lf_labels=False, lf_uids_to_use=None, predict_for_all_train=False, apply_postprocessors=True, sync=True, scheduler=None)
Train a custom machine learning model.
Returns
A dictionary of the form {"job_id": "rq-XX-YYY", "model_uid": ZZ}. The value for model_uid is None if sync=False.
Return type
Dict[str, Any]
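The returned dictionary can be unpacked as in the minimal Python sketch below. The helper name get_model_uid is hypothetical and not part of the Snorkel Flow SDK; it only encodes the contract stated for the return value: model_uid is populated once a synchronous job completes and is None when sync=False, in which case job_id is the only handle on the background job.

```python
from typing import Any, Dict, Optional


def get_model_uid(result: Dict[str, Any]) -> Optional[int]:
    """Extract the model UID from a train_custom_model result dictionary.

    Hypothetical helper, not an SDK function. With sync=True the job has
    finished and model_uid is set; with sync=False the call returns right
    away, so model_uid is None and only job_id identifies the job.
    """
    if "job_id" not in result:
        raise KeyError("result has no 'job_id'; is this a train_custom_model result?")
    return result.get("model_uid")
```

A None return therefore means "training was submitted asynchronously", not "training failed"; the job_id value is what to track in that case.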
Parameters
| Name | Type | Default | Info |
|---|---|---|---|
| node | int | (required) | UID of the node for which the model is trained. |
| custom_cls | Any | (required) | Custom model class that defines train, predict, load, and save. |
| feature_fields | List[str] | (required) | Names of the fields to use as features. |
| model_description | str | (required) | Description of the model shown on the Models page. |
| name | str | (required) | Name of the model. |
| training_set | Optional[int] | None | Training set to train on. If None, the latest training set is used. |
| filter_unlabeled | bool | True | Whether to remove data points not covered by any LF when training. |
| filter_uncertain_labels | bool | True | Whether to remove low-confidence data points when training. |
| load_ground_truth | bool | True | Whether to use ground truth labels when training the model. |
| train_on_dev_data | bool | False | Whether to include the dev set when training the model. |
| predict_on_dev_data | bool | False | Whether to predict on the dev set only. |
| discretize_labels | bool | True | Whether to round probabilistic training labels to the maximum-probability class. |
| max_runs | int | 8 | For hyperparameter search, the maximum number of model configurations to sample. |
| sampler_config | Optional[Dict[str, Any]] | None | A dictionary representing a sampler configuration, with fields "strategy" (required), "params" (optional), and "class_counts" (optional). For details, see Sampler configs for data loading and model training. Only supported if discretize_labels is True. |
| tune_threshold_on_valid | bool | False | For F1-tuned extraction models, whether to tune the prediction threshold on the valid set. |
| use_lf_labels | bool | False | Whether to use LF labels as features. |
| lf_uids_to_use | Optional[List[int]] | None | If set and use_lf_labels is True, only these LFs are used for model training. |
| predict_for_all_train | bool | False | Whether to predict for all data points in the train set (True) or only for those that can be loaded into the dev set (False). |
| apply_postprocessors | bool | True | If True, postprocessors are applied. |
| sync | bool | True | Whether to poll the job status and block until the job completes. |
| scheduler | Optional[str] | None | Dask scheduler to use: threads, client, or group. |

Examples
>>> import snorkelflow.client as sf
>>> from snorkelflow.models.cls_model import TrainedClassificationModelV2
>>> class CustomModel(TrainedClassificationModelV2):
...     ...  # implement train, predict, load, and save here
...
>>> sf.train_custom_model(
...     node=node_uid,
...     custom_cls=CustomModel,
...     feature_fields=["feature_1", "feature_2"],
...     model_description="trainable model",
...     name="CustomModel",
...     training_set=training_set_uid,
... )
{'job_id': <job_id>, 'model_uid': <model_uid>}
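The custom_cls parameter expects a class that defines train, predict, load, and save. The standalone Python sketch below illustrates only that shape, using a trivial majority-class baseline. It deliberately does not subclass TrainedClassificationModelV2, and every method signature in it (what train receives, load being a classmethod, JSON as the on-disk format) is an assumption made for illustration; a real model should follow the signatures of the SDK base class it extends.

```python
import json
from collections import Counter
from typing import Any, List, Sequence


class MajorityClassModel:
    """Hypothetical model showing the four methods custom_cls must define.

    Signatures are illustrative assumptions, not the
    TrainedClassificationModelV2 interface.
    """

    def __init__(self) -> None:
        self.majority_label: Any = None

    def train(self, X: Sequence[Any], y: Sequence[int]) -> None:
        # Ignore the features and remember the most frequent training label.
        self.majority_label = Counter(y).most_common(1)[0][0]

    def predict(self, X: Sequence[Any]) -> List[Any]:
        # Predict the stored majority label for every input row.
        return [self.majority_label] * len(X)

    def save(self, path: str) -> None:
        # Persist the only learned state as JSON.
        with open(path, "w") as f:
            json.dump({"majority_label": self.majority_label}, f)

    @classmethod
    def load(cls, path: str) -> "MajorityClassModel":
        # Rebuild a model instance from the JSON written by save().
        model = cls()
        with open(path) as f:
            model.majority_label = json.load(f)["majority_label"]
        return model
```

The point of the sketch is the save/load symmetry: whatever train learns must be fully recoverable by load, since a trained model is typically persisted and reloaded before it is used for prediction.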
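Several parameters only make sense together: sampler_config needs a "strategy" field and is only supported when discretize_labels is True, and lf_uids_to_use is only used when use_lf_labels is True. The hypothetical check_training_options helper below (not an SDK function) is a minimal sketch of a pre-flight check for these documented constraints; treating keys other than "strategy", "params", and "class_counts" as unexpected is an added assumption.

```python
from typing import Any, Dict, List, Optional

# Fields documented for sampler_config; anything else is flagged (assumption).
ALLOWED_SAMPLER_KEYS = {"strategy", "params", "class_counts"}


def check_training_options(
    discretize_labels: bool = True,
    sampler_config: Optional[Dict[str, Any]] = None,
    use_lf_labels: bool = False,
    lf_uids_to_use: Optional[List[int]] = None,
) -> List[str]:
    """Return a list of problems; an empty list means the options look consistent."""
    problems: List[str] = []
    if sampler_config is not None:
        if "strategy" not in sampler_config:
            problems.append('sampler_config is missing the required "strategy" field')
        unknown = set(sampler_config) - ALLOWED_SAMPLER_KEYS
        if unknown:
            problems.append(f"sampler_config has unexpected fields: {sorted(unknown)}")
        if not discretize_labels:
            problems.append("sampler_config is only supported when discretize_labels=True")
    if lf_uids_to_use is not None and not use_lf_labels:
        problems.append("lf_uids_to_use is only used when use_lf_labels=True")
    return problems
```

Calling it with the same keyword arguments you plan to pass to train_custom_model, and stopping on a non-empty result, surfaces these mistakes before a training job is queued.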