snorkelflow.sdk.ExternalModelTrainer
- class snorkelflow.sdk.ExternalModelTrainer(column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)
Bases: object
The ExternalModelTrainer class provides methods to fine-tune a third-party model.
ExternalModelTrainer Quickstart
In this quickstart, we will create an ExternalModelTrainer object, fine-tune a third-party model, and run inference on that model, saving the results as new datasources.
We will need the following imports:
import pandas as pd
import snorkelflow.client as sf
from snorkelflow.sdk import Dataset
from snorkelflow.sdk.fine_tuning_app import (
    ExternalModel,
    ExternalModelTrainer,
)
from snorkelflow.types.finetuning import (
    FineTuningColumnType,
    FinetuningProvider,
)
First, set all your AWS secrets, making sure you have created a SageMaker execution role with the AmazonSageMakerFullAccess policy attached.
>>> AWS_ACCESS_KEY_ID = "aws::finetuning::access_key_id"
>>> AWS_SECRET_ACCESS_KEY = "aws::finetuning::secret_access_key"
>>> SAGEMAKER_EXECUTION_ROLE = "aws::finetuning::sagemaker_execution_role"
>>> FINETUNING_AWS_REGION = "aws::finetuning::region"
>>> sf.set_secret(AWS_ACCESS_KEY_ID, "<YOUR_ACCESS_KEY>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(AWS_SECRET_ACCESS_KEY, "<YOUR_SECRET_KEY>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(SAGEMAKER_EXECUTION_ROLE, "arn:aws:iam::<YOUR_EXECUTION_ROLE>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(FINETUNING_AWS_REGION, "<YOUR_REGION>", secret_store='local_store', workspace_uid=1, kwargs=None)
Define your data and training settings.
df = pd.DataFrame({
    "context_uid": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
    "context": ["", "", "", "", "", "", "", "", "", ""],
    "prompt": [
        "What is the capital of India?",
        "What is the capital of France?",
        "What is the capital of Nigeria?",
        "What is the capital of Germany?",
        "What is the capital of Japan?",
        "What is the capital of Italy?",
        "What is the capital of Brazil?",
        "What is the capital of Spain?",
        "What is the capital of Australia?",
        "What is the capital of Mozambique?",
    ],
    "response": [
        "New Delhi",
        "Paris",
        "Abuja",
        "Berlin",
        "Tokyo",
        "Rome",
        "Brasília",
        "Madrid",
        "Canberra",
        "Maputo",
    ],
})
training_configs = {
    "instance_type": "ml.g5.12xlarge",
    "instance_count": "1",
}
finetuning_configs = {
    "epoch": "1",
    "instruction_tuned": "True",
    "peft_type": "lora",
    "gradient_checkpointing": "False",
}
column_mapping = {
    "prompt": FineTuningColumnType.INSTRUCTION,  # the DataFrame stores the instructions in its "prompt" column
    "response": FineTuningColumnType.RESPONSE,
    "context": FineTuningColumnType.CONTEXT,
}
Create a new Dataset and add the DataFrame to it as a training datasource.
>>> ft_dataset = Dataset.create("ExternalModelTrainer-Quickstart")
Successfully created dataset ExternalModelTrainer-Quickstart with UID 0.
>>> datasource_uid = int(ft_dataset.create_datasource(df, uid_col="context_uid", split="train"))
Now it’s time to check our AWS permissions and kick off a fine-tuning job:
>>> external_model_trainer = ExternalModelTrainer(
...     column_mappings=column_mapping,
...     finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER,
... )
>>> external_model_trainer.check_finetuning_provider_authentication()
True
>>> external_model = external_model_trainer.finetune(
...     base_model_id="meta-textgeneration-llama-3-8b",
...     base_model_version="2.0.1",
...     finetuning_configs=finetuning_configs,
...     training_configs=training_configs,
...     datasource_uids=[datasource_uid],
...     sync=True,
... )
>>> assert isinstance(external_model, ExternalModel)  # sync=True returns an ExternalModel; the assert also narrows the type for mypy
Now that we have a fine-tuned model, we can run inference on it.
>>> model_source_uid = external_model.inference(
...     datasource_uids=[datasource_uid],
...     sync=True,
... )
Once the inference job completes, we can inspect the new datasources.
>>> datasources = sf.datasources.get_datasources(ft_dataset.uid)
>>> new_datasources = [ds for ds in datasources if ds['source_uid'] == model_source_uid]
>>> new_df = ft_dataset.get_dataframe(datasource_uid=new_datasources[0]['datasource_uid'])
>>> new_df.head()
- __init__(column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)
Class to fine-tune an external model with SnorkelFlow data.
- Parameters:
column_mappings (Dict[str, FineTuningColumnType], default: {'instruction': <FineTuningColumnType.INSTRUCTION: 'instruction'>, 'context': <FineTuningColumnType.CONTEXT: 'context'>, 'response': <FineTuningColumnType.RESPONSE: 'response'>, 'prompt_prefix': <FineTuningColumnType.PROMPT_PREFIX: 'prompt_prefix'>}) – Mapping from dataset column names to the FineTuningColumnType each column represents.
finetuning_provider_type (FinetuningProvider, default: <FinetuningProvider.AWS_SAGEMAKER: 'sagemaker'>) – The external model provider. Defaults to FinetuningProvider.AWS_SAGEMAKER.
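Because column_mappings is keyed by dataset column name, every key should name a column that actually exists in the datasource. The sketch below is a local pre-flight check using pandas, written under stated assumptions: ColumnRole is a hypothetical stand-in enum that mirrors the FineTuningColumnType values listed in the default above (it is not the SDK class), and missing_mapped_columns is likewise a hypothetical helper, not part of the SDK.

```python
from enum import Enum
from typing import Dict, List

import pandas as pd


class ColumnRole(str, Enum):
    """Hypothetical stand-in mirroring the FineTuningColumnType values (not the SDK class)."""

    INSTRUCTION = "instruction"
    CONTEXT = "context"
    RESPONSE = "response"
    PROMPT_PREFIX = "prompt_prefix"


def missing_mapped_columns(df: pd.DataFrame, mapping: Dict[str, ColumnRole]) -> List[str]:
    """Return the mapping keys (dataset column names) that are absent from df."""
    return [column for column in mapping if column not in df.columns]


df = pd.DataFrame(
    {
        "context_uid": ["1", "2"],
        "context": ["", ""],
        "prompt": ["What is the capital of France?", "What is the capital of Japan?"],
        "response": ["Paris", "Tokyo"],
    }
)

# Keying on "instruction" fails here because the instruction text lives in "prompt".
bad_mapping = {"instruction": ColumnRole.INSTRUCTION, "response": ColumnRole.RESPONSE}
good_mapping = {
    "prompt": ColumnRole.INSTRUCTION,
    "response": ColumnRole.RESPONSE,
    "context": ColumnRole.CONTEXT,
}

print(missing_mapped_columns(df, bad_mapping))   # ['instruction']
print(missing_mapped_columns(df, good_mapping))  # []
```

Running a check like this before creating the datasource surfaces a key/column mismatch locally instead of after a training job has been submitted.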
Methods
__init__([column_mappings, ...]) – Class to fine-tune an external model with SnorkelFlow data.
check_finetuning_provider_authentication() – Check if the user has the necessary permissions to use the fine-tuning provider.
finetune(base_model_id, finetuning_configs, ...) – Fine-tune a third-party model.
- check_finetuning_provider_authentication()
Check if the user has the necessary permissions to use the finetuning provider.
- Returns:
True if the user has the necessary permissions; otherwise, raises an error listing the missing permissions.
- Return type:
bool
- finetune(base_model_id, finetuning_configs, training_configs, datasource_uids, base_model_version='*', x_uids=None, sync=True)
Fine-tune a third-party model.
- Parameters:
base_model_id (str) – The ID of the base model. See the available models: https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-latest.html
finetuning_configs (Dict[str, Any]) – The fine-tuning hyperparameters, e.g.:
{
    "epoch": "1",
    "instruction_tuned": "True",
    "validation_split_ratio": "0.1",
    "max_input_length": "1024"
}
training_configs (Dict[str, Any]) – The training infrastructure configuration, e.g.:
{
    "instance_type": "ml.g5.12xlarge",
}
datasource_uids (List[int]) – The datasource UIDs to use for fine-tuning.
base_model_version (str, default: '*') – The version of the base model. Defaults to "*" for the latest version. See the AWS docs linked above for available versions.
x_uids (Optional[List[str]], default: None) – Optional x_uids for filtering the datasources.
sync (bool, default: True) – Whether to wait for the fine-tuning job to complete before returning.
- Return type:
Union[ExternalModel, str]
- Returns:
ExternalModel – if sync=True, an ExternalModel object for running inference.
str – if sync=False, the job ID, returned without waiting for the job to finish.
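The finetuning_configs and training_configs examples above pass every value as a string (for example "1" and "True"), which is consistent with SageMaker training jobs taking hyperparameters as string key-value pairs. If you prefer to build these dicts from native Python values, a small conversion step keeps them in that form. This is a minimal sketch: to_string_config is a hypothetical helper, not part of the SDK, and rendering booleans as "True"/"False" simply follows the examples above.

```python
from typing import Any, Dict

# Scalar types with an unambiguous string form (bool is a subclass of int).
_SCALAR_TYPES = (str, bool, int, float)


def to_string_config(config: Dict[str, Any]) -> Dict[str, str]:
    """Render scalar config values as strings, e.g. 1 -> "1" and True -> "True"."""
    converted: Dict[str, str] = {}
    for key, value in config.items():
        if not isinstance(value, _SCALAR_TYPES):
            # None or nested structures have no obvious string form, so fail early.
            raise TypeError(f"{key!r} has unsupported type {type(value).__name__}")
        converted[key] = str(value)
    return converted


finetuning_configs = to_string_config(
    {"epoch": 1, "instruction_tuned": True, "validation_split_ratio": 0.1, "max_input_length": 1024}
)
print(finetuning_configs)
# {'epoch': '1', 'instruction_tuned': 'True', 'validation_split_ratio': '0.1', 'max_input_length': '1024'}
```

Failing on unsupported types, rather than calling str() on everything, avoids silently sending values such as "None" to the training job.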