Version: 0.94

snorkelflow.sdk.ExternalModelTrainer

class snorkelflow.sdk.ExternalModelTrainer(column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)

Bases: object

The ExternalModelTrainer class provides methods to finetune a third-party model.

ExternalModelTrainer Quickstart

In this quickstart, we will create an ExternalModelTrainer object, finetune a third-party model, and run inference on that model, saving the results as new datasources.

We will need the following imports:

import pandas as pd

import snorkelflow.client as sf
from snorkelflow.sdk import Dataset
from snorkelflow.sdk.fine_tuning_app import (
    ExternalModel,
    ExternalModelTrainer,
)
from snorkelflow.types.finetuning import (
    FineTuningColumnType,
    FinetuningProvider,
)

First, set all your AWS secrets. Make sure you have created a SageMaker execution role with the AmazonSageMakerFullAccess policy.

>>> AWS_ACCESS_KEY_ID = "aws::finetuning::access_key_id"
>>> AWS_SECRET_ACCESS_KEY = "aws::finetuning::secret_access_key"
>>> SAGEMAKER_EXECUTION_ROLE = "aws::finetuning::sagemaker_execution_role"
>>> FINETUNING_AWS_REGION = "aws::finetuning::region"
>>> sf.set_secret(AWS_ACCESS_KEY_ID, "<YOUR_ACCESS_KEY>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(AWS_SECRET_ACCESS_KEY, "<YOUR_SECRET_KEY>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(SAGEMAKER_EXECUTION_ROLE, "arn:aws:iam::<YOUR_EXECUTION_ROLE>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(FINETUNING_AWS_REGION, "<YOUR_REGION>", secret_store='local_store', workspace_uid=1, kwargs=None)
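Rather than pasting credentials into a notebook, you may prefer to read them from environment variables. The sketch below assumes the environment variable names AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY (standard AWS SDK names) plus SAGEMAKER_EXECUTION_ROLE and AWS_REGION (chosen for this example). The helper collect_aws_secrets is hypothetical; it only builds the mapping from Snorkel Flow secret names to values and does not call Snorkel Flow itself.

```python
import os
from typing import Dict, Mapping, Optional

# Snorkel Flow secret names from the quickstart, mapped to the environment
# variables assumed to hold their values (the variable names are assumptions).
AWS_SECRET_NAMES = {
    "aws::finetuning::access_key_id": "AWS_ACCESS_KEY_ID",
    "aws::finetuning::secret_access_key": "AWS_SECRET_ACCESS_KEY",
    "aws::finetuning::sagemaker_execution_role": "SAGEMAKER_EXECUTION_ROLE",
    "aws::finetuning::region": "AWS_REGION",
}


def collect_aws_secrets(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Map each Snorkel Flow secret name to a value read from the environment.

    Raises KeyError listing every environment variable that is unset or empty.
    """
    env = os.environ if env is None else env
    missing = [var for var in AWS_SECRET_NAMES.values() if not env.get(var)]
    if missing:
        raise KeyError(f"missing environment variables: {', '.join(missing)}")
    return {secret: env[var] for secret, var in AWS_SECRET_NAMES.items()}
```

The resulting mapping can then be passed to sf.set_secret with the same arguments used above:

>>> for name, value in collect_aws_secrets().items():
...     sf.set_secret(name, value, secret_store='local_store', workspace_uid=1, kwargs=None)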

Define your data and training settings.

df = pd.DataFrame({
    "context_uid": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
    "context": ["", "", "", "", "", "", "", "", "", ""],
    "prompt": [
        "What is the capital of India?",
        "What is the capital of France?",
        "What is the capital of Nigeria?",
        "What is the capital of Germany?",
        "What is the capital of Japan?",
        "What is the capital of Italy?",
        "What is the capital of Brazil?",
        "What is the capital of Spain?",
        "What is the capital of Australia?",
        "What is the capital of Mozambique?",
    ],
    "response": [
        "New Delhi",
        "Paris",
        "Abuja",
        "Berlin",
        "Tokyo",
        "Rome",
        "Brasília",
        "Madrid",
        "Canberra",
        "Maputo",
    ],
})
training_configs = {
    "instance_type": "ml.g5.12xlarge",
    "instance_count": "1",
}
finetuning_configs = {
    "epoch": "1",
    "instruction_tuned": "True",
    "peft_type": "lora",
    "gradient_checkpointing": "False",
}
# Keys are column names in df; "prompt" holds the instructions.
column_mapping = {
    "prompt": FineTuningColumnType.INSTRUCTION,
    "response": FineTuningColumnType.RESPONSE,
    "context": FineTuningColumnType.CONTEXT,
}
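Because column_mappings is typed Dict[str, FineTuningColumnType], its keys are presumably dataset column names. A quick pre-flight check, sketched here with a hypothetical helper missing_mapped_columns, catches a mapping key that names no column in the DataFrame before any training job is submitted.

```python
from typing import Iterable, List, Mapping


def missing_mapped_columns(columns: Iterable[str], column_mapping: Mapping[str, object]) -> List[str]:
    """Return the column_mapping keys that are not dataset columns, in mapping order."""
    available = set(columns)
    return [name for name in column_mapping if name not in available]
```

With pandas, df.columns can be passed directly:

>>> assert not missing_mapped_columns(df.columns, column_mapping)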

Create a new Dataset and Datasource.

>>> ft_dataset = Dataset.create("ExternalModelTrainer-Quickstart")
Successfully created dataset ExternalModelTrainer-Quickstart with UID 0.
>>> datasource_uid = int(ft_dataset.create_datasource(df, uid_col="context_uid", split="train"))
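The uid_col argument names the column that identifies each row. Assuming those identifiers should be unique (this reference does not say so explicitly), a hypothetical duplicate_uids helper can flag repeated values before the datasource is created.

```python
from collections import Counter
from typing import Iterable, List


def duplicate_uids(uids: Iterable[str]) -> List[str]:
    """Return UID values that occur more than once, sorted for stable output."""
    counts = Counter(uids)
    return sorted(uid for uid, n in counts.items() if n > 1)
```

Run it on the UID column before calling create_datasource:

>>> assert not duplicate_uids(df["context_uid"])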

Now it’s time to check our AWS permissions and kick off a fine-tuning job.

>>> external_model_trainer = ExternalModelTrainer(
...     column_mappings=column_mapping,
...     finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER,
... )
>>> external_model_trainer.check_finetuning_provider_authentication()
True
>>> external_model = external_model_trainer.finetune(
...     base_model_id="meta-textgeneration-llama-3-8b",
...     base_model_version="2.0.1",
...     finetuning_configs=finetuning_configs,
...     training_configs=training_configs,
...     datasource_uids=[datasource_uid],
...     sync=True,
... )
>>> assert isinstance(external_model, ExternalModel)  # narrows Union[ExternalModel, str] for type checkers

Now that we have a finetuned model, we can run inference on it.

>>> model_source_uid = external_model.inference(
...     datasource_uids=[datasource_uid],
...     sync=True,
... )

Once the inference job completes, we can inspect the new datasources.

>>> datasources = sf.datasources.get_datasources(ft_dataset.uid)
>>> new_datasources = [ds for ds in datasources if ds['source_uid'] == model_source_uid]
>>> new_df = ft_dataset.get_dataframe(datasource_uid=new_datasources[0]['datasource_uid'])
>>> new_df.head()
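The list comprehension above leads to a bare IndexError if inference produced no datasource for model_source_uid. A slightly more defensive variant, assuming each record returned by sf.datasources.get_datasources is a dict with the 'source_uid' and 'datasource_uid' keys used above, might look like this (datasource_uids_for_source is a hypothetical helper):

```python
from typing import Any, Dict, List


def datasource_uids_for_source(datasources: List[Dict[str, Any]], source_uid: Any) -> List[Any]:
    """Return the datasource_uid of every datasource produced by source_uid.

    Raises ValueError when nothing matches, instead of the IndexError that
    indexing an empty list with [0] would produce.
    """
    matches = [ds["datasource_uid"] for ds in datasources if ds["source_uid"] == source_uid]
    if not matches:
        raise ValueError(f"no datasources found for source_uid={source_uid!r}")
    return matches
```

>>> uids = datasource_uids_for_source(datasources, model_source_uid)
>>> new_df = ft_dataset.get_dataframe(datasource_uid=uids[0])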

__init__(column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)

Class to fine-tune an external model with SnorkelFlow data.

Parameters:
  • column_mappings (Dict[str, FineTuningColumnType], default: {'instruction': <FineTuningColumnType.INSTRUCTION: 'instruction'>, 'context': <FineTuningColumnType.CONTEXT: 'context'>, 'response': <FineTuningColumnType.RESPONSE: 'response'>, 'prompt_prefix': <FineTuningColumnType.PROMPT_PREFIX: 'prompt_prefix'>}) – Mapping from dataset column names to their FineTuningColumnType

  • finetuning_provider_type (FinetuningProvider, default: <FinetuningProvider.AWS_SAGEMAKER: 'sagemaker'>) – The external model provider. Defaults to FinetuningProvider.AWS_SAGEMAKER.

Methods

__init__([column_mappings, ...])

Class to fine-tune an external model with SnorkelFlow data.

check_finetuning_provider_authentication()

Check if the user has the necessary permissions to use the finetuning provider.

finetune(base_model_id, finetuning_configs, ...)

Finetune a third-party model.

check_finetuning_provider_authentication()

Check if the user has the necessary permissions to use the finetuning provider.

Returns:

True if the user has the necessary permissions; otherwise, raises an error listing the missing permissions.

Return type:

bool
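Since check_finetuning_provider_authentication() either returns True or raises, and the exception type is not documented, a setup script may want to turn the outcome into a plain (ok, message) pair. The wrapper describe_auth_check below is hypothetical and catches any Exception for that reason.

```python
from typing import Callable, Tuple


def describe_auth_check(check: Callable[[], bool]) -> Tuple[bool, str]:
    """Run a provider authentication check and report (ok, message).

    The check is expected to return True on success or raise an exception
    whose message lists the missing permissions.
    """
    try:
        ok = bool(check())
    except Exception as exc:  # the concrete exception type is not documented
        return False, f"authentication check failed: {exc}"
    return ok, ("permissions OK" if ok else "check returned False")
```

>>> ok, message = describe_auth_check(external_model_trainer.check_finetuning_provider_authentication)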

finetune(base_model_id, finetuning_configs, training_configs, datasource_uids, base_model_version='*', x_uids=None, sync=True)

Finetune a third-party model.

Parameters:
  • base_model_id (str) – The ID of the SageMaker JumpStart base model, e.g. "meta-textgeneration-llama-3-8b". See available models: https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-latest.html

  • finetuning_configs (Dict[str, Any]) –

    The finetuning hyperparameters, e.g.:

    {
        "epoch": "1",
        "instruction_tuned": "True",
        "validation_split_ratio": "0.1",
        "max_input_length": "1024"
    }

  • training_configs (Dict[str, Any]) –

    The training infrastructure configurations, e.g.:

    {
        "instance_type": "ml.g5.12xlarge"
    }

  • datasource_uids (List[int]) – The UIDs of the datasources to use for finetuning

  • base_model_version (str, default: '*') – The version of the base model. Defaults to “*”, the latest version. See the AWS documentation linked above for available versions.

  • x_uids (Optional[List[str]], default: None) – Optional x_uids for filtering the datasources

  • sync (bool, default: True) – Whether to wait for the finetuning job to complete before returning
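Every value in the finetuning_configs and training_configs examples is a string, which matches how SageMaker training jobs receive hyperparameters (a string-to-string map). This reference does not state whether finetune() requires strings, so treat the hypothetical helper to_string_hyperparameters below as a convenience under that assumption: it converts typed values into the same string form shown in the examples.

```python
from typing import Any, Dict, Mapping


def to_string_hyperparameters(params: Mapping[str, Any]) -> Dict[str, str]:
    """Convert typed hyperparameter values to strings, e.g. True -> "True", 1 -> "1".

    None values are rejected rather than silently becoming the string "None".
    """
    out: Dict[str, str] = {}
    for key, value in params.items():
        if value is None:
            raise ValueError(f"hyperparameter {key!r} has no value")
        out[key] = str(value)
    return out
```

>>> finetuning_configs = to_string_hyperparameters({"epoch": 1, "instruction_tuned": True})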

Return type:

Union[ExternalModel, str]

Returns:

  • ExternalModel – if sync=True, an ExternalModel object for running inference

  • str – if sync=False, the job ID, returned without waiting for the job to finish
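With sync=False, finetune() returns a job ID immediately. This reference does not describe how to query that job's status, so the polling loop below takes a hypothetical get_status callable that you would supply, and its terminal status names are likewise assumptions. The sleep function is injectable so the loop can be exercised without real waiting.

```python
import time
from typing import Callable, Collection, Optional

# Hypothetical terminal status names; adjust to whatever your status source reports.
TERMINAL_STATES = frozenset({"completed", "failed", "cancelled"})


def wait_for_job(
    job_id: str,
    get_status: Callable[[str], str],
    poll_interval: float = 30.0,
    timeout: Optional[float] = None,
    terminal_states: Collection[str] = TERMINAL_STATES,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll get_status(job_id) until it reports a terminal state, then return that state.

    Raises TimeoutError if more than `timeout` seconds elapse first.
    """
    start = time.monotonic()
    while True:
        status = get_status(job_id)
        if status.lower() in terminal_states:
            return status
        if timeout is not None and time.monotonic() - start > timeout:
            raise TimeoutError(f"job {job_id} still {status!r} after {timeout} s")
        sleep(poll_interval)
```

Call finetune() with the same arguments as in the quickstart but sync=False, then pass the returned job ID and your own status lookup, for example wait_for_job(job_id, get_status=my_status_lookup), where my_status_lookup is hypothetical.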