
Warm start (SDK)

Warm start is the Snorkel Flow tool for getting your first training labels using state-of-the-art foundation model (FM)-based techniques. You can create labeling functions (LFs) by applying zero-shot learning (ZSL) and few-shot learning (FSL) to your data. This provides baseline LF results to feed into a machine learning model.

Most customers kick off FM warm start models from the UI. However, warm start can also be triggered from a notebook. The SDK provides additional functionality and more control over the ZSL/FSL process, enabling you to specify the following (see the example call after this list):

  • Which data features to include in the model (default: all text columns)
  • Which data splits to run inference over (default: all active datasources)
  • Whether to output multiple unipolar LFs or a single multi-polar LF (default: multi-polar LF)
  • The base foundation model to use (default: our recommended model)
  • Whether to hide the warm start LFs and register the votes as a Snorkel Flow model (default: no)
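
For instance, a call that exercises several of these options might look like the following sketch. The column names and LF name are placeholders, and the full signature and defaults are documented under Running warm start below.

import snorkelflow.client_v3 as sf
ctx = sf.SnorkelFlowContext.from_kwargs()

NODE_UID = sf.get_model_node("APPLICATION_NAME")

# Hypothetical example: restrict warm start to two placeholder text columns
# and emit one unipolar LF per class instead of a single multi-polar LF.
sf.run_warm_start(
    NODE_UID,
    "zsl_prompt",                         # warm start method
    "google/flan-t5-base",                # base foundation model
    lf_name="Warm Start (subject/body)",  # placeholder LF name
    columns=["subject", "body"],          # placeholder column names
    one_lf_per_class=True,
)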
note

FM warm start supports the following application types: single-label text classification, multi-label text classification, and single-entity sequence tagging.

Running warm start

To begin, import our SDK and retrieve the relevant model node for your application:

import snorkelflow.client_v3 as sf
ctx = sf.SnorkelFlowContext.from_kwargs() # Import context to our current application

APP_NAME = "APPLICATION_NAME" # Name of your application
NODE_UID = sf.get_model_node(APP_NAME) # Model node used in the application

To retrieve the available warm start methods for your node, along with their descriptions, use the get_available_warm_start_methods() function. An example of this is shown below:

>>> results = sf.get_available_warm_start_methods(NODE_UID)
>>> print(results)
Note: All possible warm start methods for this node type are supported.
Methods:
zsl_prompt: A zero-shot learning method that uses the class names to
automatically generate a prompt template which is used to predict the label
of each datapoint. For example, if the class names are ['cat', 'dog', 'bird'],
we may ask the model `Is this a cat, dog, or bird?`, the model's output is
then used automatically to make a prediction. Recommended models include:
['google/flan-t5-large', 'google/flan-t5-base']

...

>>> print(results.methods["zsl_prompt"])
Description:
A zero-shot learning method that uses the class names to
automatically generate a prompt template which is used to predict the label
of each datapoint. For example, if the class names are ['cat', 'dog', 'bird'],
we may ask the model `Is this a cat, dog, or bird?`, the model's output is
then used automatically to make a prediction.
Any generic sequence-to-sequence Hugging Face model can be used for ZSL Prompt; however, it's recommended that the model's training included classification tasks.
Recommended Models:
google/flan-t5-large: FLAN (Fine-tuned LAnguage Net) T5 (Text-To-Text Transfer Transformer) is a transformer-based architecture where all NLP training tasks are reframed into a unified text-to-text format and learned jointly without the need for separate task-specific heads, as in BERT-style models. The large variant has 780M parameters.
google/flan-t5-base: FLAN (Fine-tuned LAnguage Net) T5 (Text-To-Text Transfer Transformer) is a transformer-based architecture where all NLP training tasks are reframed into a unified text-to-text format and learned jointly without the need for separate task-specific heads, as in BERT-style models. The base variant has 250M parameters.

This returns the list of methods available for warm start. We recommend initially running warm start over just the dev and valid splits to gauge how well it performs before running on all of your data. To do this, pass ["dev", "valid"] as the splits argument to run_warm_start, along with your node uid, the warm start method name, and the base foundation model. The available parameters and an example are shown below:

Signature:
sf.run_warm_start(
    node: int,
    warm_start_method: str,
    foundation_model: str,
    lf_name: str = 'Warm Start SDK',
    splits: Union[List[str], NoneType] = None,
    columns: Union[List[str], NoneType] = None,
    one_lf_per_class: bool = False,
    sync: bool = True,
    allow_unsupported_foundation_model: bool = False,
    **additional_model_kwargs: Any,
) -> str
Docstring:
Kick off an FM Warm Start job on a given node.

Parameters
----------
node
    Node uid to run Warm Start on.
warm_start_method
    The string identifier of the Warm Start method to use. If you are
    unsure, try running sf.get_available_warm_start_methods(NODE_UID)
    first to see which methods are available for this node.
foundation_model
    The Foundation Model to use within the Warm Start method. For
    example, openai/gpt-4 or bigscience/T0pp.
lf_name
    Name of the created Labeling Function.
splits
    Splits to run Warm Start inference over. Defaults to all, ["train",
    "dev", "valid", "test"].
columns
    Text fields to use in Warm Start. Defaults to all text columns
    except in Information Extraction nodes, where it defaults to
    [LEFT_CONTEXT, SPAN_PREVIEW, RIGHT_CONTEXT].
one_lf_per_class
    If True, a separate LF will be created for each class. This allows
    for a finer level of control over the Warm Start output.
sync
    If True, the method will block until the Warm Start job is complete.
    Note the job progress can always be monitored within the 'In
    Progress' LFs table on Studio or via sf.poll_job_status(job_uid).
allow_unsupported_foundation_model
    If True, run Warm Start for the specified foundation_model even if
    the model is not supported. This allows additional control to
    bypass foundation model arg validation.
additional_model_kwargs
    Additional kwargs to pass to the Foundation Model. For example, to
    increase the context window size specify the `max_input_tokens`
    kwarg.

Returns
-------
job_uid : str
    The uid of the Warm Start job, which can be used to monitor progress
    with sf.poll_job_status(job_uid).

Example
-------
>>> sf.run_warm_start(NODE_UID, "zsl_prompt", "google/flan-t5-base")
Note the job progress can always be monitored within the 'In Progress' LFs table on Studio or via sf.poll_job_status('XXX')
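
Putting this together, a first pass over only the dev and valid splits might look like the following sketch. The LF name is a placeholder, and sync=False is used so that the call returns immediately and the job can be polled:

job_uid = sf.run_warm_start(
    NODE_UID,
    "zsl_prompt",                          # method from get_available_warm_start_methods
    "google/flan-t5-base",                 # base foundation model
    lf_name="Warm Start ZSL (dev/valid)",  # placeholder LF name
    splits=["dev", "valid"],               # start small before running on all splits
    sync=False,                            # return immediately instead of blocking
)
sf.poll_job_status(job_uid)                # check progress on the running job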

You can then monitor the progress within Develop (Studio) in the Labeling Functions tab, under In Progress. If the LF is not showing, try clicking the refresh button.

Running inference on new data

You can run an existing warm start model on new data through the SDK. To do so, the data source must be active in your application and included within a split. You can then use the following snippet to kick off warm start inference on the specified splits:

>>> LF_UID = sf.get_lf_uid(NODE_UID, "Warm Start SDK")
>>> results = sf.run_lf_inference(NODE_UID, LF_UID, ["train", "test"])
Note the job progress can be monitored with sf.poll_job_status('XXX')
note

Type the SDK method name followed by a question mark to see its documentation. For example: sf.run_lf_inference?

Any data within the split that this warm start LF has already processed will be retrieved automatically from the cache. This saves compute time and allows the method to be used for failure retries, for example, if Hugging Face went down while the LF was processing.
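
Because of this caching, retrying after a failure is just a matter of re-issuing the same call; previously processed datapoints are served from the cache and only the remainder is recomputed:

# Re-running the identical call resumes the job, skipping cached datapoints.
results = sf.run_lf_inference(NODE_UID, LF_UID, ["train", "test"])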

Registering a warm start LF as a Snorkel Flow model

To help with error analysis, you may want to register the warm start LF as a model:

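# This sketch assumes node_uid, warm_start_lf_uid, user_uid, and app_uid have
# already been set for your application; for example, warm_start_lf_uid can be
# retrieved with sf.get_lf_uid as shown above.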
# set the only active LF to be the warm start LF
sf.archive_lfs(node_uid)
ctx.tdm_client.put(f"/nodes/{node_uid}/active-lfs/{warm_start_lf_uid}", json={"state": "active"})

# create a training set out of the LF
lf_pkg = sf.package_lfs(node_uid, "Warm Start LF Package")
rsp = sf.add_training_set(node_uid, lf_package_version=lf_pkg)

# register the warm start model
model_uid = sf.register_model(
    node_uid,
    description="Warm Start model predictions",
    training_set=rsp["training_set_uid"],
    name="Warm Start",
)

# add the LF predictions to the model
import pandas as pd
predictions = pd.concat([
    sf.get_training_set(node_uid, rsp["training_set_uid"], split, include_data_columns=[])
    for split in ["train", "dev", "valid", "test"]
])
sf.add_predictions(
    node_uid,
    x_uids=list(predictions.index),
    predicted_labels=list(predictions["training_set_labels"]),
    model_uid=model_uid,
)

# hide the warm start LF (this can also be done via the UI)
ctx.tdm_client.post("/user-settings", json={
    "user_uid": user_uid,
    "application_uid": app_uid,
    "node_uid": node_uid,
    "settings": {"global_preferences": {"should_show_warm_start_lfs": False}}
})

Remote warm start / GPT baseline

If you don't have access to local GPUs, you can utilize a remote model with an automatically generated ZSL prompt to initiate warm start for your problem. Please note that this requires setting up a third-party integration within Snorkel Flow. For more information, see Utilizing external models.

Given that this approach may involve third-party expenses when employing models from providers such as OpenAI, we recommend estimating the cost of running warm start on a portion or the entirety of your dataset. You can use the get_warm_start_cost command to provide an approximate USD total for the operation:

>>> results = sf.get_warm_start_cost(
...     node=NODE_UID,
...     configs=[{"method": "zsl_prompt_remote", "provider": "openai", "model": "openai/gpt-4"}],
...     splits=["dev", "valid"]
... )
>>> results[0].additional_cost
63.20
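
The same call can be used to estimate the cost of a full run before committing to it. This sketch assumes, as in the example above, that the returned list is aligned with the configs passed in:

full_run_cost = sf.get_warm_start_cost(
    node=NODE_UID,
    configs=[{"method": "zsl_prompt_remote", "provider": "openai", "model": "openai/gpt-4"}],
    splits=["train", "dev", "valid", "test"],  # estimate over every split before a full run
)
print(full_run_cost[0].additional_cost)  # approximate USD total for the whole dataset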

Once you have linked your OpenAI account and are comfortable with incurring the specified third-party cost, you can initiate the remote warm start job using the same run_warm_start method that was explained at the beginning of this page. Be sure to specify the zsl_prompt_remote warm start method when doing so.

sf.run_warm_start(NODE_UID, "zsl_prompt_remote", "openai/gpt-4", splits=["dev"])

Progress can be tracked via the In Progress LFs table, or with sf.poll_job_status using the returned job uid.