snorkelflow.sdk.ExternalModelTrainer
- class snorkelflow.sdk.ExternalModelTrainer(column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)
Bases: object
The ExternalModelTrainer class provides methods to finetune a 3rd party model.
ExternalModelTrainer Quickstart
In this quickstart, we will create an ExternalModelTrainer object, finetune a 3rd party model, and run inference on that model, saving the results as new datasources.
We will need the following imports:
import pandas as pd
import snorkelflow.client as sf
from snorkelflow.sdk import Dataset
from snorkelflow.sdk.fine_tuning_app import (
ExternalModel,
ExternalModelTrainer,
)
from snorkelflow.types.finetuning import (
FineTuningColumnType,
FinetuningProvider,
)
First, set all your AWS secrets, making sure you have created a SageMaker execution role with AmazonSageMakerFullAccess.
>>> AWS_ACCESS_KEY_ID = "aws::finetuning::access_key_id"
>>> AWS_SECRET_ACCESS_KEY = "aws::finetuning::secret_access_key"
>>> SAGEMAKER_EXECUTION_ROLE = "aws::finetuning::sagemaker_execution_role"
>>> FINETUNING_AWS_REGION = "aws::finetuning::region"
>>> sf.set_secret(AWS_ACCESS_KEY_ID, "<YOUR_ACCESS_KEY>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(AWS_SECRET_ACCESS_KEY, "<YOUR_SECRET_KEY>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(SAGEMAKER_EXECUTION_ROLE, "arn:aws:iam::<YOUR_EXECUTION_ROLE>", secret_store='local_store', workspace_uid=1, kwargs=None)
>>> sf.set_secret(FINETUNING_AWS_REGION, "<YOUR_REGION>", secret_store='local_store', workspace_uid=1, kwargs=None)
Define your data and training settings.
df = pd.DataFrame({
"context_uid": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"context": ["", "", "", "", "", "", "", "", "", ""],
"prompt": [
"What is the capital of India?",
"What is the capital of France?",
"What is the capital of Nigeria?",
"What is the capital of Germany?",
"What is the capital of Japan?",
"What is the capital of Italy?",
"What is the capital of Brazil?",
"What is the capital of Spain?",
"What is the capital of Australia?",
"What is the capital of Mozambique?",
],
"response": [
"New Delhi",
"Paris",
"Abuja",
"Berlin",
"Tokyo",
"Rome",
"Brasília",
"Madrid",
"Canberra",
"Maputo"
],
})
training_configs = {
"instance_type": "ml.g5.12xlarge",
"instance_count": "1"
}
finetuning_configs = {
"epoch": "1",
"instruction_tuned": "True",
"peft_type": "lora",
"gradient_checkpointing": "False",
}
column_mapping = {
"instruction": FineTuningColumnType.INSTRUCTION,
"response": FineTuningColumnType.RESPONSE,
"context": FineTuningColumnType.CONTEXT,
}
Create a new Dataset and Datasource.
>>> ft_dataset = Dataset.create("ExternalModelTrainer-Quickstart")
Successfully created dataset ExternalModelTrainer-Quickstart with UID 0.
>>> datasource_uid = int(ft_dataset.create_datasource(df, uid_col="context_uid", split="train"))
Now it's time to check our AWS permissions and kick off a fine-tuning job.
>>> external_model_trainer = ExternalModelTrainer(
>>> column_mappings=column_mapping,
>>> finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER,
>>> )
>>> external_model_trainer.check_finetuning_provider_authentication()
>>> external_model = external_model_trainer.finetune(
>>> base_model_id="meta-textgeneration-llama-3-8b",
>>> base_model_version="2.0.1",
>>> finetuning_configs=finetuning_configs,
>>> training_configs=training_configs,
>>> datasource_uids=[datasource_uid],
>>> sync=True,
>>> )
>>> assert isinstance(external_model, ExternalModel)  # mypy
Now that we have a finetuned model, we can run inference on it.
>>> model_source_uid = external_model.inference(
>>> datasource_uids=[datasource_uid],
>>> sync=True
>>> )
Once the inference job completes, we can inspect the new datasources.
>>> datasources = sf.datasources.get_datasources(ft_dataset.uid)
>>> new_datasources = [ds for ds in datasources if ds['source_uid'] == model_source_uid]
>>> new_df = ft_dataset.get_dataframe(datasource_uid=new_datasources[0]['datasource_uid'])
>>> new_df.head()
- __init__(column_mappings={'context': FineTuningColumnType.CONTEXT, 'instruction': FineTuningColumnType.INSTRUCTION, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX, 'response': FineTuningColumnType.RESPONSE}, finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER)
Class to fine-tune an external model with SnorkelFlow data.
Parameters

- column_mappings (Dict[str, FineTuningColumnType], default: {'instruction': FineTuningColumnType.INSTRUCTION, 'context': FineTuningColumnType.CONTEXT, 'response': FineTuningColumnType.RESPONSE, 'prompt_prefix': FineTuningColumnType.PROMPT_PREFIX}): Mapping from FineTuningColumnType to the specific columns in the dataset.
- finetuning_provider_type (FinetuningProvider, default: FinetuningProvider.AWS_SAGEMAKER): The external model provider. Defaults to FinetuningProvider.AWS_SAGEMAKER.
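As a minimal sketch, and assuming only the defaults shown in the signature above (imports as in the quickstart), constructing the trainer with no arguments is equivalent to spelling out the default column mapping and provider explicitly:

>>> # Default construction: uses the built-in column mapping and the AWS SageMaker provider.
>>> trainer = ExternalModelTrainer()
>>> # Equivalent explicit form of the defaults:
>>> trainer = ExternalModelTrainer(
>>>     column_mappings={
>>>         "context": FineTuningColumnType.CONTEXT,
>>>         "instruction": FineTuningColumnType.INSTRUCTION,
>>>         "prompt_prefix": FineTuningColumnType.PROMPT_PREFIX,
>>>         "response": FineTuningColumnType.RESPONSE,
>>>     },
>>>     finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER,
>>> )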
Methods

- __init__([column_mappings, ...]): Class to fine-tune an external model with SnorkelFlow data.
- check_finetuning_provider_authentication(): Check if the user has the necessary permissions to use the finetuning provider.
- finetune(base_model_id, finetuning_configs, ...): Finetune a 3rd Party Model.
- check_finetuning_provider_authentication()
Check if the user has the necessary permissions to use the finetuning provider.
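A minimal usage sketch, mirroring the quickstart above and assuming the AWS secrets have already been registered with sf.set_secret as shown there:

>>> trainer = ExternalModelTrainer(
>>>     column_mappings=column_mapping,
>>>     finetuning_provider_type=FinetuningProvider.AWS_SAGEMAKER,
>>> )
>>> # Verify that the stored credentials and SageMaker execution role grant
>>> # the needed permissions before calling finetune().
>>> trainer.check_finetuning_provider_authentication()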
- finetune(base_model_id, finetuning_configs, training_configs, datasource_uids, base_model_version='*', x_uids=None, sync=True)
Finetune a 3rd Party Model
Parameters

- base_model_id (str): The id of the base model. See available models: https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-latest.html
- finetuning_configs (Dict[str, Any]): The finetuning hyperparameters, e.g.:
  {
      "epoch": "1",
      "instruction_tuned": "True",
      "validation_split_ratio": "0.1",
      "max_input_length": "1024"
  }
- training_configs (Dict[str, Any]): The training infrastructure configurations, e.g.:
  {
      "instance_type": "ml.g5.12xlarge",
  }
- datasource_uids (List[int]): The datasource uids to use for finetuning.
- base_model_version (str, default: '*'): The version of the base model. Defaults to "*" for the latest version. See the AWS docs above for available versions.
- x_uids (Optional[List[str]], default: None): Optional x_uids for filtering the datasources.
- sync (bool, default: True): Whether to wait for the finetuning job to complete before returning.

Return type

Union[ExternalModel, str]

Returns

- ExternalModel: if sync=True, returns an ExternalModel object for running inference
- str: if sync=False, returns the job id without waiting for the job to finish
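Because sync controls the return type, here is a minimal sketch of an asynchronous launch, assuming a trainer and configs set up exactly as in the quickstart above (the job id handling is illustrative):

>>> job_id = external_model_trainer.finetune(
>>>     base_model_id="meta-textgeneration-llama-3-8b",
>>>     base_model_version="2.0.1",
>>>     finetuning_configs=finetuning_configs,
>>>     training_configs=training_configs,
>>>     datasource_uids=[datasource_uid],
>>>     sync=False,  # return immediately instead of waiting for the job
>>> )
>>> # With sync=False the call returns the job id as a string rather than
>>> # an ExternalModel object.
>>> assert isinstance(job_id, str)

The synchronous form (sync=True), which returns an ExternalModel directly, is shown in the quickstart above.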