Version: 0.95

snorkelflow.sdk.ExternalModel

class snorkelflow.sdk.ExternalModel(external_model_name, column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)

Bases: object

__init__(external_model_name, column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)

Class representing a trained 3rd-party model. Returned from ExternalModelTrainer.finetune.

Parameters:
  • external_model_name (str) – The name generated by the external model finetuning job

  • column_mappings (Dict[str, FineTuningColumnType], default: {'instruction': FineTuningColumnType.INSTRUCTION, 'context': FineTuningColumnType.CONTEXT, 'response': FineTuningColumnType.RESPONSE, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX}) – The column mappings from the dataset’s columns to the standard finetuning columns

  • finetuning_provider_type (FinetuningProvider, default: FinetuningProvider.AWS_SAGEMAKER) – The 3rd-party training service
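
For example, a minimal construction sketch. The model name and the enum import path below are illustrative assumptions, not values taken from this page:

    from snorkelflow.sdk import ExternalModel
    # Assumed import path for the enums used in the signature above; adjust to
    # wherever FineTuningColumnType and FinetuningProvider live in your install.
    from snorkelflow.sdk import FineTuningColumnType, FinetuningProvider

    # "my-finetune-job-model" is a placeholder; in practice this name is
    # generated by the external finetuning job (see ExternalModelTrainer.finetune).
    model = ExternalModel(
        external_model_name="my-finetune-job-model",
        column_mappings={
            "instruction": FineTuningColumnType.INSTRUCTION,
            "response": FineTuningColumnType.RESPONSE,
        },
        finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER,
    )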

Methods

__init__(external_model_name[, ...]) – Class representing a trained 3rd-party model.

inference(datasource_uids[, x_uids, ...]) – Run 3rd-party inference using the finetuned model.

inference(datasource_uids, x_uids=None, generation_config=None, deployment_config=None, prompt_template=None, sync=True)

Run 3rd-party inference using the finetuned model.

Parameters:
  • datasource_uids (List[int]) – The datasource uids to use for inference

  • x_uids (Optional[List[str]], default: None) – Optional x_uids for filtering the datasources

  • generation_config (Optional[Dict[str, Any]], default: None) –

    Optional generation configuration, e.g.:

    {
        "max_new_tokens": 300,
        "top_k": 50,
        "top_p": 0.8,
        "do_sample": True,
        "temperature": 1,
    }

  • deployment_config (Optional[Dict[str, Any]], default: None) –

    Optional deployment configuration, e.g.:

    {
        "instance_type": "ml.g5.12xlarge",
        "instance_count": 1,
    }

  • prompt_template (Optional[str], default: None) – Optional prompt template for LLM inference. Defaults to {instruction}

  • sync (bool, default: True) – Whether to wait for the inference job to complete before returning

Returns:

If sync=True, returns the model source uid; otherwise, returns the job id.

Return type:

str
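
Example (hedged sketch): calling inference with the parameters documented above. The datasource UID is a placeholder, and model is an ExternalModel instance as constructed earlier:

    # Placeholder datasource UID; substitute the uids of your own datasources.
    model_source_uid = model.inference(
        datasource_uids=[123],
        generation_config={
            "max_new_tokens": 300,
            "top_k": 50,
            "top_p": 0.8,
            "do_sample": True,
            "temperature": 1,
        },
        deployment_config={
            "instance_type": "ml.g5.12xlarge",
            "instance_count": 1,
        },
        prompt_template="{instruction}",
        sync=True,  # block until the job completes; returns the model source uid
    )

    # With sync=False, the call returns a job id immediately instead.
    job_id = model.inference(datasource_uids=[123], sync=False)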