68 changes: 65 additions & 3 deletions src/together/lib/cli/api/fine_tuning.py
@@ -13,6 +13,7 @@
from click.core import ParameterSource # type: ignore[attr-defined]

from together import Together
from together.types import fine_tuning_estimate_price_params as pe_params
from together._types import NOT_GIVEN, NotGiven
from together.lib.utils import log_warn
from together.lib.utils.tools import format_timestamp, finetune_price_to_dollars
@@ -24,13 +25,21 @@

_CONFIRMATION_MESSAGE = (
"You are about to create a fine-tuning job. "
"The cost of your job will be determined by the model size, the number of tokens "
"The estimated price of this job is {price}. "
"The actual cost of your job will be determined by the model size, the number of tokens "
"in the training file, the number of tokens in the validation file, the number of epochs, and "
"the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
"the number of evaluations. Visit https://www.together.ai/pricing to learn more about pricing.\n"
"{warning}"
"You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
"Do you want to proceed?"
)

_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
"The estimated price of this job is significantly greater than your current credit limit and balance combined. "
"It will likely get cancelled due to insufficient funds. "
"Consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
)

_FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"


@@ -323,7 +332,60 @@ def create(
elif n_evals > 0 and not validation_file:
raise click.BadParameter("You have specified a number of evaluation loops but no validation file.")

if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
training_type_cls: pe_params.TrainingType
if lora:
training_type_cls = pe_params.TrainingTypeLoRaTrainingType(
lora_alpha=int(lora_alpha or 0),
lora_r=lora_r or 0,
lora_dropout=lora_dropout or 0,
lora_trainable_modules=lora_trainable_modules or "all-linear",
type="Lora",
)
else:
training_type_cls = pe_params.TrainingTypeFullTrainingType(
type="Full",
)

training_method_cls: pe_params.TrainingMethod
if training_method == "sft":
training_method_cls = pe_params.TrainingMethodTrainingMethodSft(
method="sft",
train_on_inputs=train_on_inputs or "auto",
)
else:
training_method_cls = pe_params.TrainingMethodTrainingMethodDpo(
method="dpo",
dpo_beta=dpo_beta or 0,
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length or False,
dpo_reference_free=False,
rpo_alpha=rpo_alpha or 0,
simpo_gamma=simpo_gamma or 0,
)

finetune_price_estimation_result = client.fine_tuning.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model or "",
n_epochs=n_epochs,
n_evals=n_evals,
training_type=training_type_cls,
training_method=training_method_cls,
)
price = click.style(
f"${finetune_price_estimation_result.estimated_total_price:.2f}",
bold=True,
)
if not finetune_price_estimation_result.allowed_to_proceed:
warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
else:
warning = ""

confirmation_message = _CONFIRMATION_MESSAGE.format(
price=price,
warning=warning,
)

if confirm or click.confirm(confirmation_message, default=True, show_default=True):
response = client.fine_tuning.create(
**training_args,
verbose=True,
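For reference, a minimal sketch of calling the new price-estimation endpoint directly through the SDK, mirroring what the CLI change above does before asking for confirmation. The file ID and model name are placeholders rather than values from this PR, and the parameter classes are the same ones the CLI builds; treat this as an illustration, not the CLI's exact code path.

```python
from together import Together
from together.types import fine_tuning_estimate_price_params as pe_params

client = Together()

# Full fine-tuning with plain SFT; these mirror the param classes built in the CLI diff above.
training_type = pe_params.TrainingTypeFullTrainingType(type="Full")
training_method = pe_params.TrainingMethodTrainingMethodSft(method="sft", train_on_inputs="auto")

estimate = client.fine_tuning.estimate_price(
    training_file="file-abc123",  # placeholder uploaded-file ID
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model name
    n_epochs=1,
    n_evals=0,
    training_type=training_type,
    training_method=training_method,
)

print(f"Estimated total price: ${estimate.estimated_total_price:.2f}")
if not estimate.allowed_to_proceed:
    print("Estimated price exceeds the current credit limit and balance combined.")
```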
43 changes: 41 additions & 2 deletions src/together/lib/resources/fine_tuning.py
@@ -4,6 +4,7 @@

from rich import print as rprint

from together.types import fine_tuning_estimate_price_params as pe_params
from together.lib.utils import log_warn_once

if TYPE_CHECKING:
@@ -66,7 +67,7 @@ def create_finetune_request(
hf_model_revision: str | None = None,
hf_api_token: str | None = None,
hf_output_repo_name: str | None = None,
) -> FinetuneRequest:
) -> tuple[FinetuneRequest, pe_params.TrainingType, pe_params.TrainingMethod]:
if model is not None and from_checkpoint is not None:
raise ValueError("You must specify either a model or a checkpoint to start a job from, not both")

@@ -233,8 +234,46 @@ def create_finetune_request(
hf_output_repo_name=hf_output_repo_name,
)

return finetune_request
training_type_pe, training_method_pe = create_price_estimation_params(finetune_request)

return finetune_request, training_type_pe, training_method_pe

def create_price_estimation_params(finetune_request: FinetuneRequest) -> tuple[pe_params.TrainingType, pe_params.TrainingMethod]:
training_type_cls: pe_params.TrainingType
if isinstance(finetune_request.training_type, FullTrainingType):
training_type_cls = pe_params.TrainingTypeFullTrainingType(
type="Full",
)
elif isinstance(finetune_request.training_type, LoRATrainingType):
training_type_cls = pe_params.TrainingTypeLoRaTrainingType(
lora_alpha=finetune_request.training_type.lora_alpha,
lora_r=finetune_request.training_type.lora_r,
lora_dropout=finetune_request.training_type.lora_dropout,
lora_trainable_modules=finetune_request.training_type.lora_trainable_modules,
type="Lora",
)
else:
raise ValueError(f"Unknown training type: {finetune_request.training_type}")

training_method_cls: pe_params.TrainingMethod
if isinstance(finetune_request.training_method, TrainingMethodSFT):
training_method_cls = pe_params.TrainingMethodTrainingMethodSft(
method="sft",
train_on_inputs=finetune_request.training_method.train_on_inputs,
)
elif isinstance(finetune_request.training_method, TrainingMethodDPO):
training_method_cls = pe_params.TrainingMethodTrainingMethodDpo(
method="dpo",
dpo_beta=finetune_request.training_method.dpo_beta or 0,
dpo_normalize_logratios_by_length=finetune_request.training_method.dpo_normalize_logratios_by_length,
dpo_reference_free=finetune_request.training_method.dpo_reference_free,
rpo_alpha=finetune_request.training_method.rpo_alpha or 0,
simpo_gamma=finetune_request.training_method.simpo_gamma or 0,
)
else:
raise ValueError(f"Unknown training method: {finetune_request.training_method}")

return training_type_cls, training_method_cls

def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits:
"""
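A minimal sketch of the new three-tuple return from `create_finetune_request` and of the `create_price_estimation_params` helper added above. The model name and file ID are placeholders, and the call uses only keyword arguments visible in this diff and in the tests below.

```python
from together import Together
from together.lib.resources.fine_tuning import (
    create_finetune_request,
    create_price_estimation_params,
    get_model_limits,
)

client = Together()
model = "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference"  # placeholder model name
model_limits = get_model_limits(client, model)

# The builder now returns the request plus the two price-estimation param objects.
finetune_request, training_type_pe, training_method_pe = create_finetune_request(
    model_limits=model_limits,
    model=model,
    training_file="file-abc123",  # placeholder uploaded-file ID
)

# The same two objects can also be rebuilt from the request alone.
training_type_pe, training_method_pe = create_price_estimation_params(finetune_request)
```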
63 changes: 59 additions & 4 deletions src/together/resources/fine_tuning.py
@@ -27,9 +27,16 @@
async_to_custom_streamed_response_wrapper,
)
from .._base_client import make_request_options
from ..lib.types.fine_tuning import FinetuneResponse as FinetuneResponseLib, FinetuneTrainingLimits
from ..lib.types.fine_tuning import (
FinetuneResponse as FinetuneResponseLib,
FinetuneTrainingLimits,
)
from ..types.finetune_response import FinetuneResponse
from ..lib.resources.fine_tuning import get_model_limits, async_get_model_limits, create_finetune_request
from ..lib.resources.fine_tuning import (
get_model_limits,
async_get_model_limits,
create_finetune_request,
)
from ..types.fine_tuning_list_response import FineTuningListResponse
from ..types.fine_tuning_cancel_response import FineTuningCancelResponse
from ..types.fine_tuning_delete_response import FineTuningDeleteResponse
@@ -39,6 +46,12 @@

__all__ = ["FineTuningResource", "AsyncFineTuningResource"]

_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
"The estimated price of the fine-tuning job is {} which is significantly "
"greater than your current credit limit and balance combined. "
"It will likely get cancelled due to insufficient funds. "
"Proceed at your own risk."
)

class FineTuningResource(SyncAPIResource):
@cached_property
@@ -180,7 +193,7 @@ def create(
pass
model_limits = get_model_limits(self._client, str(model_name))

finetune_request = create_finetune_request(
finetune_request, training_type_cls, training_method_cls = create_finetune_request(
model_limits=model_limits,
training_file=training_file,
model=model,
@@ -219,11 +232,32 @@ def create(
hf_output_repo_name=hf_output_repo_name,
)


price_estimation_result = self.estimate_price(
training_file=training_file,
from_checkpoint=from_checkpoint or Omit(),
validation_file=validation_file or Omit(),
model=model or "",
n_epochs=finetune_request.n_epochs,
n_evals=finetune_request.n_evals or 0,
training_type=training_type_cls,
training_method=training_method_cls,
)


if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_estimation_result.allowed_to_proceed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
)
+ "[/red]",
)
parameter_payload = finetune_request.model_dump(exclude_none=True)

return self._client.post(
@@ -691,7 +725,7 @@ async def create(
pass
model_limits = await async_get_model_limits(self._client, str(model_name))

finetune_request = create_finetune_request(
finetune_request, training_type_cls, training_method_cls = create_finetune_request(
model_limits=model_limits,
training_file=training_file,
model=model,
@@ -730,11 +764,32 @@ async def create(
hf_output_repo_name=hf_output_repo_name,
)


price_estimation_result = await self.estimate_price(
training_file=training_file,
from_checkpoint=from_checkpoint or Omit(),
validation_file=validation_file or Omit(),
model=model or "",
n_epochs=finetune_request.n_epochs,
n_evals=finetune_request.n_evals or 0,
training_type=training_type_cls,
training_method=training_method_cls,
)


if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_estimation_result.allowed_to_proceed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price # pyright: ignore[reportPossiblyUnboundVariable]
)
+ "[/red]",
)
Contributor:
Is there any way all of this logic could be put inside the create_finetune_request method?

This should be fine - but this file is generated, so I think minimizing the code in this file will help avoid conflicts.

Contributor Author:

We can't access self within build_finetuning_request; any other way would perhaps cause conflicts in the future.
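To illustrate the constraint described in this reply, a rough structural sketch (not the actual file contents, with abbreviated signatures): `create_finetune_request` is a module-level builder with no client, while `estimate_price` is a method on the resource, so the insufficient-funds check has to stay in the generated resource where `self._client` is available.

```python
def create_finetune_request(model_limits, training_file, model=None, **kwargs):
    """Pure builder: no API access; returns (request, training_type_pe, training_method_pe)."""
    ...


class FineTuningResource(SyncAPIResource):
    def create(self, *, training_file, model=None, **kwargs):
        request, training_type_pe, training_method_pe = create_finetune_request(
            model_limits=get_model_limits(self._client, str(model)),
            training_file=training_file,
            model=model,
            **kwargs,
        )
        # Only the resource has a client, so the price check happens here.
        estimate = self.estimate_price(
            training_file=training_file,
            model=model or "",
            n_epochs=request.n_epochs,
            n_evals=request.n_evals or 0,
            training_type=training_type_pe,
            training_method=training_method_pe,
        )
```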

parameter_payload = finetune_request.model_dump(exclude_none=True)

return await self._client.post(
16 changes: 8 additions & 8 deletions tests/unit/test_fine_tuning_resources.py
@@ -36,7 +36,7 @@


def test_simple_request():
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
@@ -53,7 +53,7 @@ def test_simple_request():


def test_validation_file():
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
@@ -73,7 +73,7 @@ def test_no_training_file():


def test_lora_request():
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
@@ -93,7 +93,7 @@ def test_lora_request():
@pytest.mark.parametrize("lora_dropout", [-1, 0, 0.5, 1.0, 10.0])
def test_lora_request_with_lora_dropout(lora_dropout: float):
if 0 <= lora_dropout < 1:
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
@@ -117,7 +117,7 @@ def test_lora_request_with_lora_dropout(lora_dropout: float):


def test_dpo_request_lora():
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
@@ -136,7 +136,7 @@ def test_dpo_request_lora():


def test_dpo_request():
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,
@@ -150,7 +150,7 @@ def test_dpo_request():


def test_from_checkpoint_request():
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
training_file=_TRAINING_FILE,
from_checkpoint=_FROM_CHECKPOINT,
@@ -314,7 +314,7 @@ def test_bad_training_method():

@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None])
def test_train_on_inputs_for_sft(train_on_inputs: Union[bool, Literal["auto"], None]):
request = create_finetune_request(
request, _, _ = create_finetune_request(
model_limits=_MODEL_LIMITS,
model=_MODEL_NAME,
training_file=_TRAINING_FILE,