Version: 0.93

snorkelflow.sdk.ExternalModel

class snorkelflow.sdk.ExternalModel(external_model_name, column_mappings, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)

Bases: object

__init__

__init__(external_model_name, column_mappings, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)

Class representing a trained 3rd-party model. Returned by ExternalModelTrainer.finetune.

Parameters

| Name | Type | Default | Info |
| --- | --- | --- | --- |
| `external_model_name` | `str` | | The name generated by the external model finetuning job. |
| `column_mappings` | `Dict[FineTuningColumnType, str]` | | The column mappings from `FineTuningColumnType` to the specific column names in the dataset. |
| `finetuning_provider_type` | `FinetuningProvider` | `FinetuningProvider.AWS_SAGEMAKER` | The 3rd-party training service. Defaults to `FinetuningProvider.AWS_SAGEMAKER`. |
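
A minimal construction sketch follows. The import path for the enums, the members `PROMPT` and `RESPONSE`, and the column names used here are illustrative assumptions, not values documented on this page.

```python
from snorkelflow.sdk import ExternalModel
# NOTE: this import path is an assumption for illustration; check your SDK
# version for the actual location of FineTuningColumnType and FinetuningProvider.
from snorkelflow.types.finetuning import FineTuningColumnType, FinetuningProvider

# Hypothetical mapping from fine-tuning roles to the dataset's column names.
column_mappings = {
    FineTuningColumnType.PROMPT: "instruction",  # assumed enum member
    FineTuningColumnType.RESPONSE: "response",   # assumed enum member
}

model = ExternalModel(
    external_model_name="my-sagemaker-finetuned-model",  # name produced by the finetuning job
    column_mappings=column_mappings,
    finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER,
)
```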

Methods

| Method | Description |
| --- | --- |
| `__init__(external_model_name, column_mappings[, finetuning_provider_type])` | Class representing a trained 3rd-party model. |
| `inference(datasource_uids[, x_uids, ...])` | Run 3rd-party inference using the finetuned model. |

inference

inference(datasource_uids, x_uids=None, generation_config=None, deployment_config=None, prompt_template=None, sync=True)

Run 3rd-party inference using the finetuned model.

Parameters

| Name | Type | Default | Info |
| --- | --- | --- | --- |
| `datasource_uids` | `List[int]` | | The datasource UIDs to use for inference. |
| `x_uids` | `Optional[List[str]]` | `None` | Optional x_uids for filtering the datasources. |
| `generation_config` | `Optional[Dict[str, Any]]` | `None` | Optional generation configuration; see the example below. |
| `deployment_config` | `Optional[Dict[str, Any]]` | `None` | Optional deployment configuration; see the example below. |
| `prompt_template` | `Optional[str]` | `None` | Optional prompt template for LLM inference. Defaults to `{instruction}`. |
| `sync` | `bool` | `True` | Whether to wait for the inference job to complete before returning. |

Example `generation_config`:

```python
{
    "max_new_tokens": 300,
    "top_k": 50,
    "top_p": 0.8,
    "do_sample": True,
    "temperature": 1,
}
```

Example `deployment_config`:

```python
{
    "instance_type": "ml.g5.12xlarge",
    "instance_count": 1,
}
```

Returns

If `sync=True`, returns the model source UID; otherwise, returns the job ID.

Return type

str
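
A minimal usage sketch, assuming an `ExternalModel` instance named `model` (e.g., returned by `ExternalModelTrainer.finetune`); the datasource UID and config values are illustrative.

```python
# Run synchronous inference on datasource 123 (illustrative UID).
model_source_uid = model.inference(
    datasource_uids=[123],
    generation_config={
        "max_new_tokens": 300,
        "top_k": 50,
        "top_p": 0.8,
        "do_sample": True,
        "temperature": 1,
    },
    deployment_config={
        "instance_type": "ml.g5.12xlarge",
        "instance_count": 1,
    },
    prompt_template="{instruction}",  # the documented default template
    sync=True,  # block until the job completes
)

# With sync=False the call returns a job ID immediately instead of the
# model source UID, which can be used to track the inference job.
job_id = model.inference(datasource_uids=[123], sync=False)
```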