Version: 0.96

Supported modeling libraries

This page describes the three major modeling libraries that Snorkel Flow supports: Scikit-Learn, XGBoost, and HuggingFace’s Transformers. If you think that your application requires a custom library, please reach out to a Snorkel representative.

Scikit-Learn

Snorkel Flow supports the logistic regression and k-nearest neighbors classifiers from the Scikit-Learn library; you select one with the classifier_cls option. To configure the classifier, pass keyword arguments as a dictionary to the classifier_kwargs option, which is forwarded directly to the classifier initializer. See the Scikit-Learn logistic regression documentation and the Scikit-Learn k-nearest neighbors documentation for information about all supported keyword arguments.

To perform classification on text fields, we also support several Scikit-Learn vectorizers that transform the text into features. You select one with the vectorizer_cls option, which currently supports CountVectorizer, HashingVectorizer, and TfidfVectorizer. To configure the vectorizer, pass keyword arguments as a dictionary to the vectorizer_kwargs option, which is forwarded directly to the vectorizer initializer.

An example model configuration for Scikit-Learn is shown below:

{
  "classifier": {
    "classifier_cls": "LogisticRegression",
    "classifier_kwargs": {
      "C": 10,
      "penalty": "l2",
      "solver": "liblinear",
      "random_state": 123
    }
  },
  "text_vectorizer": {
    "vectorizer_cls": "HashingVectorizer",
    "vectorizer_kwargs": {
      "ngram_range": [1, 2],
      "n_features": 250000,
      "lowercase": false
    }
  }
}
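To make "forwarded directly to the initializer" concrete, here is a minimal local sketch that reproduces the configuration above with plain Scikit-Learn. It is not Snorkel Flow's implementation: the lookup tables, the name KNeighborsClassifier as a config value, and the helper build_pipeline are hypothetical. The sketch assumes one detail worth knowing: JSON has no tuple type, while recent Scikit-Learn versions validate ngram_range as a tuple, so the list is converted before it is passed on.

```python
import json

from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer, TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline

# Hypothetical name-to-class tables; the exact names Snorkel Flow accepts
# (especially for k-nearest neighbors) are an assumption.
VECTORIZERS = {
    "CountVectorizer": CountVectorizer,
    "HashingVectorizer": HashingVectorizer,
    "TfidfVectorizer": TfidfVectorizer,
}
CLASSIFIERS = {
    "LogisticRegression": LogisticRegression,
    "KNeighborsClassifier": KNeighborsClassifier,
}


def build_pipeline(config):
    """Hypothetical helper: build vectorizer + classifier, forwarding each *_kwargs dict."""
    vec_cfg = config["text_vectorizer"]
    clf_cfg = config["classifier"]
    vec_kwargs = dict(vec_cfg.get("vectorizer_kwargs", {}))
    # JSON stores ngram_range as a list; Scikit-Learn expects a tuple.
    if "ngram_range" in vec_kwargs:
        vec_kwargs["ngram_range"] = tuple(vec_kwargs["ngram_range"])
    vectorizer = VECTORIZERS[vec_cfg["vectorizer_cls"]](**vec_kwargs)
    classifier = CLASSIFIERS[clf_cfg["classifier_cls"]](**clf_cfg.get("classifier_kwargs", {}))
    return make_pipeline(vectorizer, classifier)


# The example configuration from this page.
CONFIG_JSON = """
{
  "classifier": {
    "classifier_cls": "LogisticRegression",
    "classifier_kwargs": {"C": 10, "penalty": "l2", "solver": "liblinear", "random_state": 123}
  },
  "text_vectorizer": {
    "vectorizer_cls": "HashingVectorizer",
    "vectorizer_kwargs": {"ngram_range": [1, 2], "n_features": 250000, "lowercase": false}
  }
}
"""

pipeline = build_pipeline(json.loads(CONFIG_JSON))

# Toy data for illustration only.
texts = ["great product", "terrible service", "great service", "terrible product"]
labels = [1, 0, 1, 0]
pipeline.fit(texts, labels)
print(pipeline.predict(["great value", "terrible value"]))
```

Because every key in classifier_kwargs and vectorizer_kwargs becomes a constructor argument, a misspelled key fails the same way it would in Scikit-Learn itself, with a TypeError about an unexpected keyword argument.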

XGBoost

Snorkel Flow supports XGBoost for gradient boosted trees through the XGBClassifier class from XGBoost's Scikit-Learn-compatible Python API. You select it with the classifier_cls option. As with the Scikit-Learn classifiers, you configure the XGBoost classifier by passing keyword arguments as a dictionary to the classifier_kwargs option. The dictionary is forwarded directly to the classifier initializer. See the XGBClassifier documentation for information about all supported keyword arguments.

To perform classification on text fields, the same vectorizers from the Scikit-Learn library are supported, as defined in the Scikit-Learn section.

An example model configuration for XGBoost is shown below:

{
  "classifier": {
    "classifier_cls": "XGBoostClassifier",
    "classifier_kwargs": {
      "n_estimators": 100,
      "max_depth": null
    }
  },
  "text_vectorizer": {
    "vectorizer_cls": "CountVectorizer",
    "vectorizer_kwargs": {
      "ngram_range": [1, 2],
      "max_features": 250000,
      "lowercase": false
    }
  }
}
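If you generate configurations like the one above from Python, the standard json module performs the value conversions the example relies on: None becomes null (so "max_depth": null corresponds to max_depth=None in XGBClassifier), False becomes false, and the tuple (1, 2) becomes the list [1, 2]. The following is a minimal sketch under that assumption; it only builds and serializes the dictionary, so XGBoost itself is not needed to run it.

```python
import json

# Keyword arguments written as you would pass them in Python.
classifier_kwargs = {"n_estimators": 100, "max_depth": None}  # None -> null
vectorizer_kwargs = {
    "ngram_range": (1, 2),   # tuple -> JSON list [1, 2]
    "max_features": 250000,
    "lowercase": False,      # False -> false
}

config = {
    "classifier": {
        "classifier_cls": "XGBoostClassifier",
        "classifier_kwargs": classifier_kwargs,
    },
    "text_vectorizer": {
        "vectorizer_cls": "CountVectorizer",
        "vectorizer_kwargs": vectorizer_kwargs,
    },
}

config_json = json.dumps(config, indent=2)
print(config_json)
```

Note that the conversion is one-way: reading the JSON back yields a list for ngram_range, not a tuple.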

HuggingFace Transformers

Snorkel Flow supports pre-trained BERT classification models through HuggingFace’s Transformers library. You can specify a pre-trained model with the pretrained_model_name option. The HuggingFace documentation lists the full set of available pre-trained models.

tip

We recommend using DistilBERT models because of their smaller model size and lower compute cost.

To configure the HuggingFace tokenizer, use the tokenizer_kwargs option, which is passed to the tokenizer initializer and supports the options listed in the HuggingFace pretrained tokenizer documentation. To configure the AdamW optimizer, use the adamw_kwargs option, which is passed to the optimizer initializer and supports the options listed in the PyTorch AdamW documentation.

tip

We recommend keeping weight_decay at 0.0 within adamw_kwargs, because weight decay set there applies to every parameter, including the bias and layer norm parameters, which should not be decayed. To apply weight decay, use the top-level weight_decay option instead, which correctly excludes these parameters.
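A common way to exclude bias and layer norm parameters from weight decay is to split the model's parameters into two AdamW parameter groups by name. The sketch below shows that pattern; it assumes, without confirmation from Snorkel Flow, that the top-level weight_decay option works similarly, and the helper name group_parameters and the name-matching rule are hypothetical. DistilBERT names its layer norms embeddings.LayerNorm, sa_layer_norm, and output_layer_norm, so the match is case-insensitive and covers both spellings.

```python
def group_parameters(named_parameters, weight_decay):
    """Hypothetical helper: split (name, param) pairs into decay / no-decay groups.

    The returned list uses the parameter-group format accepted by
    torch.optim.AdamW: dicts with "params" and a per-group "weight_decay".
    """
    decay, no_decay = [], []
    for name, param in named_parameters:
        lowered = name.lower()
        is_bias = lowered.rsplit(".", 1)[-1] == "bias"
        is_layer_norm = "layernorm" in lowered or "layer_norm" in lowered
        if is_bias or is_layer_norm:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# DistilBERT-style parameter names; strings stand in for real tensors here.
named = [
    ("distilbert.embeddings.word_embeddings.weight", "w_emb"),
    ("distilbert.embeddings.LayerNorm.weight", "emb_ln_w"),
    ("distilbert.transformer.layer.0.attention.q_lin.weight", "q_w"),
    ("distilbert.transformer.layer.0.attention.q_lin.bias", "q_b"),
    ("distilbert.transformer.layer.0.sa_layer_norm.weight", "sa_ln_w"),
    ("classifier.weight", "clf_w"),
]
groups = group_parameters(named, weight_decay=0.01)
print(groups)
```

With a real PyTorch model you would pass model.named_parameters() and hand the result to torch.optim.AdamW(groups, lr=...). Setting weight_decay inside adamw_kwargs instead would set a single default for every group that does not override it, which is why the tip recommends leaving it at 0.0.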

The remaining options are used to configure model training and inference, which we support via PyTorch. In particular, the freeze_bert_embeddings option enables you to train only the final classification layer. This option can yield good performance at lower computational cost.

An example model configuration for HuggingFace Transformers is shown below:

{
  "pretrained_model_name": "distilbert-base-uncased",
  "tokenizer_kwargs": {
    "do_lower_case": true
  },
  "adamw_kwargs": {
    "lr": 0.00005
  },
  "weight_decay": 0,
  "freeze_bert_embeddings": true,
  "max_sequence_length": 32,
  "num_train_epochs": 1,
  "train_batch_size": 64,
  "max_grad_norm": 1,
  "gradient_accumulation_steps": 1,
  "lr_warmup_steps": 0,
  "predict_batch_size": 64
}