Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 150 additions & 3 deletions sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def deploy(
Automatically detects if the model is a Nova model and uses the appropriate
Bedrock API (create_custom_model for Nova, create_model_import_job for others).
For Nova models, also creates a custom model deployment for inference.
For OSS models, creates a model import job and waits for it to complete.
Once complete, the model is ready for on-demand inference. If provisioned
throughput is needed, use the separate create_provisioned_throughput() method.

Args:
job_name: Name for the model import job (non-Nova models only).
Expand All @@ -140,11 +143,12 @@ def deploy(

Returns:
Response from Bedrock API. For Nova models, returns the
create_custom_model_deployment response. For others, returns
the create_model_import_job response.
create_custom_model_deployment response. For OSS models, returns
the get_model_import_job response after the job completes.

Raises:
ValueError: If model_package is not set or required parameters are missing.
RuntimeError: If the import job fails or times out.
"""
if not self.model_package:
raise ValueError(
Expand Down Expand Up @@ -190,7 +194,16 @@ def deploy(
params = {k: v for k, v in params.items() if v is not None}

logger.info("Creating model import job for non-Nova deployment")
return self._get_bedrock_client().create_model_import_job(**params)
import_response = self._get_bedrock_client().create_model_import_job(**params)

job_arn = import_response.get("jobArn")
self._wait_for_import_job_complete(job_arn)

# Return the completed job details
job_details = self._get_bedrock_client().get_model_import_job(
jobIdentifier=job_arn
)
return job_details

def create_deployment(
self,
Expand Down Expand Up @@ -243,6 +256,140 @@ def create_deployment(

return response

def create_provisioned_throughput(
    self,
    model_id: str,
    provisioned_model_name: str,
    model_units: int = 1,
    commitment_duration: Optional[str] = None,
    tags: Optional[list] = None,
    poll_interval: int = 60,
    max_wait: int = 3600,
) -> Dict[str, Any]:
    """Provision throughput for an imported Bedrock model and wait for it.

    Issues a CreateProvisionedModelThroughput request and, when the response
    carries a ``provisionedModelArn``, blocks until that resource reports
    InService. If the response has no ARN, the method returns without
    waiting.

    Args:
        model_id: ARN or ID of the imported model.
        provisioned_model_name: Name for the provisioned throughput resource.
        model_units: Number of model units to provision. Defaults to 1.
        commitment_duration: Commitment duration. Valid values: 'OneMonth',
            'SixMonths'. When omitted or empty, the ``commitmentDuration``
            field is left out of the request.
        tags: Tags for the provisioned throughput resource. Omitted from the
            request when empty or None.
        poll_interval: Seconds between status checks. Defaults to 60.
        max_wait: Maximum seconds to wait. Defaults to 3600.

    Returns:
        The response of the create_provisioned_model_throughput call (not the
        final status lookup).

    Raises:
        ValueError: If model_id or provisioned_model_name is not provided.
        RuntimeError: If the provisioned throughput fails or times out.
    """
    # Validate required identifiers up front, in signature order.
    for value, message in (
        (model_id, "model_id is required for create_provisioned_throughput."),
        (
            provisioned_model_name,
            "provisioned_model_name is required for create_provisioned_throughput.",
        ),
    ):
        if not value:
            raise ValueError(message)

    request = {
        "modelId": model_id,
        "provisionedModelName": provisioned_model_name,
        "modelUnits": model_units,
    }
    # Optional fields are only sent when they carry a truthy value.
    optional_fields = {"commitmentDuration": commitment_duration, "tags": tags}
    request.update({key: val for key, val in optional_fields.items() if val})

    logger.info(
        "Creating provisioned throughput '%s' for model %s with %d model units",
        provisioned_model_name,
        model_id,
        model_units,
    )
    bedrock = self._get_bedrock_client()
    response = bedrock.create_provisioned_model_throughput(**request)

    arn = response.get("provisionedModelArn")
    if not arn:
        # Nothing to poll on; hand back the raw create response.
        return response

    self._wait_for_provisioned_throughput_in_service(
        arn, poll_interval=poll_interval, max_wait=max_wait
    )
    return response

def _wait_for_import_job_complete(
    self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
    """Poll Bedrock until the model import job reaches Completed status.

    The job status is always checked at least once. Sleeps are clamped to the
    remaining wait budget, so a final check happens when the budget runs out
    instead of sleeping a full interval and then timing out unchecked.
    Elapsed time is measured with a monotonic clock, so time spent inside the
    status calls counts against ``max_wait``.

    Args:
        job_arn: ARN of the model import job.
        poll_interval: Seconds between status checks. Defaults to 60.
            Zero or negative values poll back-to-back without sleeping
            (still bounded by ``max_wait``).
        max_wait: Maximum seconds to wait. Defaults to 3600. Zero or a
            negative value performs a single status check.

    Raises:
        RuntimeError: If the import job fails or times out.
    """
    started = time.monotonic()
    while True:
        resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn)
        status = resp.get("status")
        elapsed = time.monotonic() - started
        logger.info("Import job status: %s (elapsed %ds)", status, int(elapsed))
        if status == "Completed":
            return
        if status == "Failed":
            failure_reason = resp.get("failureMessage", "Unknown")
            raise RuntimeError(
                f"Model import job {job_arn} failed. Reason: {failure_reason}"
            )
        remaining = max_wait - elapsed
        if remaining <= 0:
            # Budget exhausted and the status just observed was not terminal.
            break
        # Never sleep past the deadline; clamp at 0 so a non-positive
        # poll_interval cannot make time.sleep raise.
        time.sleep(max(0, min(poll_interval, remaining)))
    raise RuntimeError(
        f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. "
        f"Last status: {status}"
    )

def _wait_for_provisioned_throughput_in_service(
    self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
    """Poll Bedrock until provisioned throughput reaches InService status.

    The status is always checked at least once. Sleeps are clamped to the
    remaining wait budget, so a final check happens when the budget runs out
    instead of sleeping a full interval and then timing out unchecked.
    Elapsed time is measured with a monotonic clock, so time spent inside the
    status calls counts against ``max_wait``.

    Args:
        provisioned_model_arn: ARN of the provisioned model throughput.
        poll_interval: Seconds between status checks. Defaults to 60.
            Zero or negative values poll back-to-back without sleeping
            (still bounded by ``max_wait``).
        max_wait: Maximum seconds to wait. Defaults to 3600. Zero or a
            negative value performs a single status check.

    Raises:
        RuntimeError: If the provisioned throughput fails or times out.
    """
    started = time.monotonic()
    while True:
        resp = self._get_bedrock_client().get_provisioned_model_throughput(
            provisionedModelId=provisioned_model_arn
        )
        status = resp.get("status")
        elapsed = time.monotonic() - started
        logger.info(
            "Provisioned throughput status: %s (elapsed %ds)", status, int(elapsed)
        )
        if status == "InService":
            return
        if status == "Failed":
            failure_reason = resp.get("failureMessage", "Unknown")
            raise RuntimeError(
                f"Provisioned throughput {provisioned_model_arn} failed. "
                f"Reason: {failure_reason}"
            )
        remaining = max_wait - elapsed
        if remaining <= 0:
            # Budget exhausted and the status just observed was not terminal.
            break
        # Never sleep past the deadline; clamp at 0 so a non-positive
        # poll_interval cannot make time.sleep raise.
        time.sleep(max(0, min(poll_interval, remaining)))
    raise RuntimeError(
        f"Timed out after {max_wait}s waiting for provisioned throughput "
        f"{provisioned_model_arn} to become InService. Last status: {status}"
    )

def _wait_for_model_active(
self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600
):
Expand Down
Loading
Loading