From a0da02f79e2cfc0e9ae612a96125950c45dc296b Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 10:42:03 -0700 Subject: [PATCH 1/2] feat: add import job polling and provisioned throughput for Bedrock OSS deployments - deploy() for non-Nova models now waits for import job completion and returns job details (model ready for on-demand inference). - New public method: create_provisioned_throughput() with polling. - New private methods: _wait_for_import_job_complete(), _wait_for_provisioned_throughput_in_service(). - Added unit tests and integ tests (serial to avoid concurrent quota issues). - Mark bedrock integ tests as serial to avoid concurrent import job quota issues. X-AI-Prompt: add import polling and PT for bedrock OSS deployments X-AI-Tool: kiro-cli --- .gitignore | 1 + .../sagemaker/serve/bedrock_model_builder.py | 164 ++++++++- .../test_bedrock_provisioned_throughput.py | 313 +++++++++++++++++ .../test_model_customization_deployment.py | 1 + .../tests/unit/test_bedrock_model_builder.py | 314 +++++++++++++----- 5 files changed, 696 insertions(+), 97 deletions(-) create mode 100644 sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py diff --git a/.gitignore b/.gitignore index 811e8b5905..378048cdf0 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ sagemaker_train/src/**/container_drivers/distributed.json docs/api/generated/ .hypothesis .kiro +bedrock_api_logs/ diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index fc269343d4..786cea18b2 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -157,15 +157,15 @@ def deploy( For Nova models, also creates a custom model deployment for inference. Args: - job_name: Name for the model import job (non-Nova models only). - imported_model_name: Name for the imported model (non-Nova models only). + job_name: Name for the model import job (OSS models only). + imported_model_name: Name for the imported model (OSS models only). custom_model_name: Name for the custom model (Nova models only). role_arn: IAM role ARN with permissions for Bedrock operations. - job_tags: Tags for the import job (non-Nova models only). - imported_model_tags: Tags for the imported model (non-Nova models only). + job_tags: Tags for the import job (OSS models only). + imported_model_tags: Tags for the imported model (OSS models only). model_tags: Tags for the custom model (Nova models only). - client_request_token: Unique token for idempotency (non-Nova models only). - imported_model_kms_key_id: KMS key ID for encryption (non-Nova models only). + client_request_token: Unique token for idempotency (OSS models only). + imported_model_kms_key_id: KMS key ID for encryption (OSS models only). deployment_name: Name for the deployment (Nova models only). If not provided, defaults to custom_model_name suffixed with '-deployment'. @@ -238,15 +238,23 @@ def deploy( } params = {k: v for k, v in params.items() if v is not None} - logger.info("Creating model import job for non-Nova deployment") + logger.info("Creating model import job for OSS model deployment") print(f"[BedrockModelBuilder] Resolved S3 artifacts path: {self.s3_model_artifacts}") print(f"[BedrockModelBuilder] create_model_import_job params: {params}") - response = self._get_bedrock_client().create_model_import_job(**params) + import_response = self._get_bedrock_client().create_model_import_job(**params) logger.warning( - "Bedrock create_model_import_job request: %s, response: %s", params, response + "Bedrock create_model_import_job request: %s, response: %s", params, import_response ) - _log_bedrock_api_call("create_model_import_job", params, response) - return response + _log_bedrock_api_call("create_model_import_job", params, import_response) + + job_arn = import_response.get("jobArn") + self._wait_for_import_job_complete(job_arn) + + # Return the completed job details + job_details = self._get_bedrock_client().get_model_import_job( + jobIdentifier=job_arn + ) + return job_details def create_deployment( self, @@ -303,6 +311,140 @@ def create_deployment( return response + def create_provisioned_throughput( + self, + model_id: str, + provisioned_model_name: str, + model_units: int = 1, + commitment_duration: Optional[str] = None, + tags: Optional[list] = None, + poll_interval: int = 60, + max_wait: int = 3600, + ) -> Dict[str, Any]: + """Create provisioned throughput for an imported model on Bedrock. + + Calls CreateProvisionedModelThroughput and polls until the provisioned + throughput reaches InService status. + + Args: + model_id: ARN or ID of the imported model. + provisioned_model_name: Name for the provisioned throughput resource. + model_units: Number of model units to provision. Defaults to 1. + commitment_duration: Commitment duration. Valid values: 'OneMonth', + 'SixMonths'. If not provided, no commitment is set (on-demand). + tags: Tags for the provisioned throughput resource. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Returns: + Response from Bedrock create_provisioned_model_throughput API. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + ValueError: If model_id or provisioned_model_name is not provided. + """ + if not model_id: + raise ValueError("model_id is required for create_provisioned_throughput.") + if not provisioned_model_name: + raise ValueError( + "provisioned_model_name is required for create_provisioned_throughput." + ) + + params = { + "modelId": model_id, + "provisionedModelName": provisioned_model_name, + "modelUnits": model_units, + } + if commitment_duration: + params["commitmentDuration"] = commitment_duration + if tags: + params["tags"] = tags + + logger.info( + "Creating provisioned throughput '%s' for model %s with %d model units", + provisioned_model_name, + model_id, + model_units, + ) + response = self._get_bedrock_client().create_provisioned_model_throughput(**params) + + provisioned_model_arn = response.get("provisionedModelArn") + if provisioned_model_arn: + self._wait_for_provisioned_throughput_in_service( + provisioned_model_arn, poll_interval=poll_interval, max_wait=max_wait + ) + + return response + + def _wait_for_import_job_complete( + self, job_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until the model import job reaches Completed status. + + Args: + job_arn: ARN of the model import job. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the import job fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_model_import_job(jobIdentifier=job_arn) + status = resp.get("status") + logger.info("Import job status: %s (elapsed %ds)", status, elapsed) + if status == "Completed": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Model import job {job_arn} failed. Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for import job {job_arn} to complete. " + f"Last status: {status}" + ) + + def _wait_for_provisioned_throughput_in_service( + self, provisioned_model_arn: str, poll_interval: int = 60, max_wait: int = 3600 + ): + """Poll Bedrock until provisioned throughput reaches InService status. + + Args: + provisioned_model_arn: ARN of the provisioned model throughput. + poll_interval: Seconds between status checks. Defaults to 60. + max_wait: Maximum seconds to wait. Defaults to 3600. + + Raises: + RuntimeError: If the provisioned throughput fails or times out. + """ + elapsed = 0 + status = None + while elapsed < max_wait: + resp = self._get_bedrock_client().get_provisioned_model_throughput( + provisionedModelId=provisioned_model_arn + ) + status = resp.get("status") + logger.info("Provisioned throughput status: %s (elapsed %ds)", status, elapsed) + if status == "InService": + return + if status == "Failed": + failure_reason = resp.get("failureMessage", "Unknown") + raise RuntimeError( + f"Provisioned throughput {provisioned_model_arn} failed. " + f"Reason: {failure_reason}" + ) + time.sleep(poll_interval) + elapsed += poll_interval + raise RuntimeError( + f"Timed out after {max_wait}s waiting for provisioned throughput " + f"{provisioned_model_arn} to become InService. Last status: {status}" + ) + def _wait_for_model_active( self, model_arn: str, poll_interval: int = 60, max_wait: int = 3600 ): diff --git a/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py new file mode 100644 index 0000000000..aa4ee03cb8 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_bedrock_provisioned_throughput.py @@ -0,0 +1,313 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration tests for BedrockModelBuilder import job polling and provisioned throughput.""" +from __future__ import absolute_import + +import json +import time +import random +import logging +from urllib.parse import urlparse + +import boto3 +import pytest + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core.resources import TrainingJob +from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder + +logger = logging.getLogger(__name__) + +AWS_REGION = "us-west-2" + + +@pytest.fixture(scope="module") +def training_job_name(): + """Training job name for testing (OSS model).""" + return "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445" + + +@pytest.fixture(scope="module") +def role_arn(): + """IAM role ARN with Bedrock permissions.""" + return get_execution_role() + + +@pytest.fixture(scope="module") +def bedrock_client(): + """Create Bedrock client.""" + return boto3.client("bedrock", region_name=AWS_REGION) + + +@pytest.fixture(scope="module") +def s3_client(): + """Create S3 client.""" + return boto3.client("s3", region_name=AWS_REGION) + + +@pytest.fixture(scope="module") +def training_job(training_job_name): + """Get the training job.""" + return TrainingJob.get( + training_job_name=training_job_name, region=AWS_REGION + ) + + +def _setup_model_files(s3_artifacts_uri, s3_client): + """Setup required model files for Bedrock deployment. + + Bedrock model import requires HuggingFace-format files (config.json, + tokenizer.json, etc.) at the root of the S3 model artifacts path. + Training jobs often store these under checkpoints/hf_merged/, so we + copy them to the expected location. + + Args: + s3_artifacts_uri: The S3 URI that BedrockModelBuilder will use for import. + s3_client: boto3 S3 client. + """ + parsed = urlparse(s3_artifacts_uri) + bucket = parsed.netloc + base_prefix = parsed.path.lstrip("/").rstrip("/") + + hf_merged_prefix = f"{base_prefix}/checkpoints/hf_merged/" + root_prefix = f"{base_prefix}/" + + files_to_copy = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "model.safetensors", + ] + + for file in files_to_copy: + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + file) + logger.info("File already exists: s3://%s/%s%s", bucket, root_prefix, file) + except Exception: + try: + s3_client.copy_object( + Bucket=bucket, + CopySource={"Bucket": bucket, "Key": hf_merged_prefix + file}, + Key=root_prefix + file, + ) + logger.info("Copied %s to root", file) + except Exception as e: + logger.warning("Could not copy %s: %s", file, e) + + try: + s3_client.head_object(Bucket=bucket, Key=root_prefix + "added_tokens.json") + except Exception: + try: + s3_client.put_object( + Bucket=bucket, + Key=root_prefix + "added_tokens.json", + Body=json.dumps({}), + ContentType="application/json", + ) + logger.info("Created added_tokens.json") + except Exception as e: + logger.warning("Could not create added_tokens.json: %s", e) + + +@pytest.mark.serial +class TestBedrockImportJobPolling: + """Test import job polling for OSS models (Option C: deploy only waits for import).""" + + @pytest.fixture(autouse=True) + def _setup(self, bedrock_client): + """Store bedrock client and track resources for cleanup.""" + self._bedrock_client = bedrock_client + self._imported_model_arn = None + yield + self._cleanup() + + def _cleanup(self): + """Clean up Bedrock resources created during the test.""" + if self._imported_model_arn: + try: + logger.info("Deleting imported model: %s", self._imported_model_arn) + self._bedrock_client.delete_imported_model( + modelIdentifier=self._imported_model_arn + ) + except Exception as e: + logger.warning("Failed to delete imported model: %s", e) + + @pytest.mark.slow + def test_deploy_oss_model_waits_for_import_completion( + self, training_job, role_arn, bedrock_client, s3_client + ): + """Test that deploy() waits for import job to complete and returns job details. + + This test verifies that BedrockModelBuilder.deploy() for OSS models: + 1. Creates a model import job + 2. Polls until the import job reaches Completed status + 3. Returns the completed job details (model is ready for on-demand invoke) + 4. Does NOT create provisioned throughput + """ + builder = BedrockModelBuilder(model=training_job) + assert builder.s3_model_artifacts is not None + + _setup_model_files(builder.s3_model_artifacts, s3_client) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + job_name = f"test-import-poll-{suffix}" + imported_model_name = f"test-import-model-{suffix}" + + result = builder.deploy( + job_name=job_name, + imported_model_name=imported_model_name, + role_arn=role_arn, + ) + + # Verify the result is the completed job details + assert result["status"] == "Completed", ( + f"Expected Completed, got {result.get('status')}" + ) + assert "importedModelName" in result + assert "importedModelArn" in result or "jobArn" in result + + # Track for cleanup + self._imported_model_arn = result.get("importedModelArn") + + # Verify model can be found (it exists and is ready) + models = bedrock_client.list_imported_models() + model_names = [m["modelName"] for m in models.get("modelSummaries", [])] + assert imported_model_name in model_names + + +@pytest.mark.serial +class TestBedrockProvisionedThroughput: + """Test create_provisioned_throughput as a standalone method. + + Uses a pre-existing Bedrock custom model (fine-tuned Llama 3.1 8B) to test + provisioned throughput creation and polling. The custom model was created via + Bedrock CreateModelCustomizationJob and persists in the CI account. + + Prerequisites: + - Account 729646638167, us-west-2 + - PT MU quota for Llama 3.1 8B (requested via Matador/Bedrock team) + - A pre-existing custom model (see below for how to recreate) + + How to recreate the custom model if it gets deleted: + + 1. Ensure training data exists at: + s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl + + If not, create it (minimal JSONL with prompt/completion pairs): + echo '{"prompt":"What is ML?","completion":"ML is a subset of AI."}' > /tmp/train.jsonl + aws s3 cp /tmp/train.jsonl s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl + + 2. Create the fine-tuning job: + aws bedrock create-model-customization-job \\ + --job-name test-llama31-8b-pt-integ \\ + --custom-model-name test-llama31-8b-pt-model \\ + --role-arn arn:aws:iam::729646638167:role/Admin \\ + --base-model-identifier meta.llama3-1-8b-instruct-v1:0:128k \\ + --customization-type FINE_TUNING \\ + --training-data-config '{"s3Uri":"s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl"}' \\ + --output-data-config '{"s3Uri":"s3://mc-flows-sdk-testing/pt-test-output/"}' \\ + --hyper-parameters '{"epochCount":"1","batchSize":"1","learningRate":"0.00001"}' \\ + --region us-west-2 + + 3. Wait for the job to complete (~2-4 hours for 8B model): + aws bedrock get-model-customization-job \\ + --job-identifier --region us-west-2 \\ + --query "status" + + 4. Update CUSTOM_MODEL_ARN below with the outputModelArn from the job. + """ + + # Pre-existing custom model created via Bedrock fine-tuning. + # Base model: meta.llama3-1-8b-instruct-v1:0:128k + # This model must exist in account 729646638167, us-west-2. + CUSTOM_MODEL_ARN = ( + "arn:aws:bedrock:us-west-2:729646638167:custom-model/" + "meta.llama3-1-8b-instruct-v1:0:128k/k2mjykwgn62p" + ) + CUSTOM_MODEL_NAME = "test-llama31-8b-pt-model" + + @pytest.fixture(autouse=True) + def _setup(self, bedrock_client): + """Store bedrock client and track resources for cleanup.""" + self._bedrock_client = bedrock_client + self._provisioned_model_arn = None + yield + # Always clean up PT, even if test fails + self._cleanup() + + def _cleanup(self): + """Clean up provisioned throughput created during the test.""" + if self._provisioned_model_arn: + try: + logger.info("Deleting provisioned throughput: %s", self._provisioned_model_arn) + self._bedrock_client.delete_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + logger.info("Provisioned throughput deleted successfully.") + except Exception as e: + logger.warning("Failed to delete provisioned throughput: %s", e) + + @pytest.mark.slow + def test_create_provisioned_throughput(self, bedrock_client): + """Test create_provisioned_throughput() with a pre-existing custom model. + + This test verifies: + 1. Calls CreateProvisionedModelThroughput with a custom model ARN + 2. Polls until provisioned throughput reaches InService + 3. Returns the provisioned throughput response + 4. Cleans up the PT after the test + """ + # Check if the pre-existing custom model exists + try: + bedrock_client.get_custom_model(modelIdentifier=self.CUSTOM_MODEL_ARN) + except Exception: + pytest.skip( + f"Pre-existing custom model not found: {self.CUSTOM_MODEL_ARN}. " + f"Recreate it with: aws bedrock create-model-customization-job " + f"--job-name test-llama31-8b-pt-integ " + f"--custom-model-name {self.CUSTOM_MODEL_NAME} " + f"--role-arn " + f"--base-model-identifier meta.llama3-1-8b-instruct-v1:0:128k " + f"--customization-type FINE_TUNING " + f"--training-data-config '{{\"s3Uri\":\"s3://mc-flows-sdk-testing/pt-test-data/train_llama31.jsonl\"}}' " + f"--output-data-config '{{\"s3Uri\":\"s3://mc-flows-sdk-testing/pt-test-output/\"}}' " + f"--hyper-parameters '{{\"epochCount\":\"1\",\"batchSize\":\"1\",\"learningRate\":\"0.00001\"}}' " + f"--region us-west-2" + ) + + suffix = f"{int(time.time())}-{random.randint(1000, 9999)}" + provisioned_model_name = f"test-pt-integ-{suffix}" + + builder = BedrockModelBuilder(model=None) + + # Create provisioned throughput + pt_result = builder.create_provisioned_throughput( + model_id=self.CUSTOM_MODEL_ARN, + provisioned_model_name=provisioned_model_name, + model_units=1, + ) + + # Verify result contains provisioned model ARN + assert "provisionedModelArn" in pt_result, ( + f"Expected 'provisionedModelArn' in result, got keys: {list(pt_result.keys())}" + ) + self._provisioned_model_arn = pt_result["provisionedModelArn"] + + # Verify provisioned throughput is InService (create_provisioned_throughput + # already polls until InService, but double-check) + pt_response = bedrock_client.get_provisioned_model_throughput( + provisionedModelId=self._provisioned_model_arn + ) + assert pt_response["status"] == "InService", ( + f"Expected InService, got {pt_response['status']}" + ) diff --git a/sagemaker-serve/tests/integ/test_model_customization_deployment.py b/sagemaker-serve/tests/integ/test_model_customization_deployment.py index b38ca249c7..5b22c16851 100644 --- a/sagemaker-serve/tests/integ/test_model_customization_deployment.py +++ b/sagemaker-serve/tests/integ/test_model_customization_deployment.py @@ -307,6 +307,7 @@ def test_dpo_trainer_build(self, training_job_name, sagemaker_session): from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder +@pytest.mark.serial class TestModelCustomizationDeployment: """Test suite for deploying fine-tuned models to Bedrock.""" diff --git a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py index 3fdfaa01a3..6a5d54ca18 100644 --- a/sagemaker-serve/tests/unit/test_bedrock_model_builder.py +++ b/sagemaker-serve/tests/unit/test_bedrock_model_builder.py @@ -69,7 +69,7 @@ def test_nova_via_recipe_name(self): def test_nova_via_hub_content_name(self): assert _is_nova_model(_make_container(hub_content_name="amazon-nova-lite")) is True - def test_non_nova(self): + def test_oss(self): assert _is_nova_model(_make_container(recipe_name="llama-3-8b", hub_content_name="llama")) is False def test_no_base_model(self): @@ -95,8 +95,7 @@ def test_none_model(self): def test_with_model(self): m = Mock() with patch.object(BedrockModelBuilder, "_fetch_model_package", return_value=Mock()), \ - patch.object(BedrockModelBuilder, "_get_s3_artifacts", return_value="s3://b/k"), \ - patch(f"{MODULE}.is_restricted_model_package", return_value=False): + patch.object(BedrockModelBuilder, "_get_s3_artifacts", return_value="s3://b/k"): b = BedrockModelBuilder(model=m) assert b.model is m assert b.s3_model_artifacts == "s3://b/k" @@ -210,13 +209,13 @@ def test_none_when_no_model_package(self): b.model_package = None assert b._get_s3_artifacts() is None - def test_non_nova_returns_s3_uri(self): + def test_oss_returns_s3_uri(self): c = _make_container(recipe_name="llama", hub_content_name="llama", s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) assert b._get_s3_artifacts() == "s3://b/m.tar.gz" - def test_non_nova_no_data_source(self): + def test_oss_no_data_source(self): c = _make_container(recipe_name="llama", hub_content_name="llama") b = _builder() b.model_package = _make_model_package(c) @@ -470,15 +469,48 @@ def test_timeout_raises(self): class TestDeploy: - def test_non_nova(self): + def test_oss_waits_for_import_and_returns_job_details(self): + """OSS deploy: import job → wait → return job details.""" c = _make_container(s3_uri="s3://b/m.tar.gz") b = _builder() b.model_package = _make_model_package(c) b.s3_model_artifacts = "s3://b/m.tar.gz" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} - result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") - assert result == {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "my-imported-model", + "importedModelArn": "arn:aws:bedrock:us-west-2:123:imported-model/abc", + } + + with patch(f"{MODULE}.time.sleep"): + result = b.deploy(job_name="j", imported_model_name="m", role_arn="r") + + b._bedrock_client.create_model_import_job.assert_called_once() + b._bedrock_client.get_model_import_job.assert_called() + # Should NOT call create_provisioned_model_throughput + b._bedrock_client.create_provisioned_model_throughput.assert_not_called() + assert result["status"] == "Completed" + assert result["importedModelName"] == "my-imported-model" + + def test_oss_does_not_create_provisioned_throughput(self): + """deploy() for OSS models should never call CreateProvisionedModelThroughput.""" + c = _make_container(s3_uri="s3://b/m.tar.gz") + b = _builder() + b.model_package = _make_model_package(c) + b.s3_model_artifacts = "s3://b/m.tar.gz" + b._bedrock_client = Mock() + b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn:job"} + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "m", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="m", role_arn="r") + + b._bedrock_client.create_provisioned_model_throughput.assert_not_called() + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() def test_nova_full_chain(self): c = _make_container(recipe_name="nova-micro", hub_content_name="nova") @@ -573,116 +605,226 @@ def test_nova_missing_role_arn_raises(self): with pytest.raises(ValueError, match="role_arn is required"): b.deploy(custom_model_name="m") - def test_non_nova_strips_none_params(self): + def test_oss_strips_none_params(self): c = _make_container() b = _builder() b.model_package = _make_model_package(c) b.s3_model_artifacts = "s3://b/k" b._bedrock_client = Mock() b._bedrock_client.create_model_import_job.return_value = {"jobArn": "arn"} - b.deploy(job_name="j", imported_model_name="m", role_arn="r") + b._bedrock_client.get_model_import_job.return_value = { + "status": "Completed", + "importedModelName": "m", + } + + with patch(f"{MODULE}.time.sleep"): + b.deploy(job_name="j", imported_model_name="m", role_arn="r") + kw = b._bedrock_client.create_model_import_job.call_args[1] assert "importedModelKmsKeyId" not in kw assert "clientRequestToken" not in kw - def test_nova_rmp_uses_model_package_arn_data_source(self): - """When model package is RMP, use customModelDataSource.""" - c = _make_container(recipe_name="nova-lite") + +# ── _wait_for_import_job_complete ─────────────────────────────────────────── + + +class TestWaitForImportJobComplete: + def test_immediate_completed(self): b = _builder() - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/my-rmp/1" - pkg.managed_storage_type = "Restricted" - b.model_package = pkg - b._is_rmp = True - b.s3_model_artifacts = None b._bedrock_client = Mock() - b._bedrock_client.create_custom_model.return_value = {"modelArn": "arn:m"} - b._bedrock_client.get_custom_model.return_value = {"modelStatus": "Active"} - b._bedrock_client.create_custom_model_deployment.return_value = { - "customModelDeploymentArn": "arn:dep" + b._bedrock_client.get_model_import_job.return_value = {"status": "Completed"} + b._wait_for_import_job_complete("arn:job") + b._bedrock_client.get_model_import_job.assert_called_once_with( + jobIdentifier="arn:job" + ) + + def test_polls_then_completed(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.side_effect = [ + {"status": "InProgress"}, + {"status": "InProgress"}, + {"status": "Completed"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=10) + assert b._bedrock_client.get_model_import_job.call_count == 3 + + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = { + "status": "Failed", + "failureMessage": "Invalid model format", } - b._bedrock_client.get_custom_model_deployment.return_value = {"status": "Active"} + with pytest.raises(RuntimeError, match="Invalid model format"): + b._wait_for_import_job_complete("arn:job") - b.deploy(custom_model_name="rmp-test", role_arn="r") - kw = b._bedrock_client.create_custom_model.call_args[1] - assert "customModelDataSource" in kw - assert kw["customModelDataSource"]["modelPackageArnDataSource"]["modelPackageArn"] == ( - "arn:aws:sagemaker:us-east-1:123456789012:model-package/my-rmp/1" + def test_failed_unknown_reason(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "Failed"} + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_import_job_complete("arn:job") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_model_import_job.return_value = {"status": "InProgress"} + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_import_job_complete("arn:job", poll_interval=1, max_wait=2) + + +# ── create_provisioned_throughput ─────────────────────────────────────────── + + +class TestCreateProvisionedThroughput: + def test_creates_and_polls(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + + result = b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="my-pt" + ) + + b._bedrock_client.create_provisioned_model_throughput.assert_called_once_with( + modelId="arn:model", + provisionedModelName="my-pt", + modelUnits=1, ) - assert "modelSourceConfig" not in kw + b._bedrock_client.get_provisioned_model_throughput.assert_called_once() + assert result["provisionedModelArn"] == "arn:pt" - def test_nova_s3_uri_uses_model_source_config(self): - """When model package is not RMP, use modelSourceConfig (existing path).""" - c = _make_container(recipe_name="nova-lite", s3_uri="s3://bucket/checkpoint/step_10/") + def test_passes_commitment_duration(self): b = _builder() - pkg = _make_model_package(c) - pkg.managed_storage_type = None - b.model_package = pkg - b.s3_model_artifacts = "s3://bucket/checkpoint/step_10/" b._bedrock_client = Mock() - b._bedrock_client.create_custom_model.return_value = {"modelArn": "arn:m"} - b._bedrock_client.get_custom_model.return_value = {"modelStatus": "Active"} - b._bedrock_client.create_custom_model_deployment.return_value = { - "customModelDeploymentArn": "arn:dep" + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" } - b._bedrock_client.get_custom_model_deployment.return_value = {"status": "Active"} - b.deploy(custom_model_name="s3-test", role_arn="r") - kw = b._bedrock_client.create_custom_model.call_args[1] - assert "modelSourceConfig" in kw - assert kw["modelSourceConfig"]["s3DataSource"]["s3Uri"] == "s3://bucket/checkpoint/step_10/" - assert "customModelDataSource" not in kw + b.create_provisioned_throughput( + model_id="arn:model", + provisioned_model_name="pt", + model_units=5, + commitment_duration="OneMonth", + ) + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["modelUnits"] == 5 + assert kw["commitmentDuration"] == "OneMonth" -# ── _get_s3_artifacts RMP detection ─────────────────────────────────────── + def test_passes_tags(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = { + "provisionedModelArn": "arn:pt" + } + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + tags = [{"Key": "team", "Value": "ml"}] + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt", tags=tags + ) -class TestGetS3ArtifactsRMP: - def test_nova_rmp_returns_none(self): - """When model package is RMP (s3_uri is None), return None.""" - c = _make_container(recipe_name="nova-lite") - s3_data = Mock() - s3_data.s3_uri = None - data_source = Mock() - data_source.s3_data_source = s3_data - c.model_data_source = data_source + kw = b._bedrock_client.create_provisioned_model_throughput.call_args[1] + assert kw["tags"] == tags - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/rmp/1" - pkg.managed_storage_type = "Restricted" + def test_skips_polling_when_no_arn_in_response(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.create_provisioned_model_throughput.return_value = {} + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="pt" + ) + b._bedrock_client.get_provisioned_model_throughput.assert_not_called() + + def test_empty_model_id_raises(self): b = _builder() - b.model = "not-a-training-job" - b.model_package = pkg - result = b._get_s3_artifacts() - assert result is None + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id="", provisioned_model_name="pt") + + def test_none_model_id_raises(self): + b = _builder() + with pytest.raises(ValueError, match="model_id is required"): + b.create_provisioned_throughput(model_id=None, provisioned_model_name="pt") + + def test_empty_provisioned_model_name_raises(self): + b = _builder() + with pytest.raises(ValueError, match="provisioned_model_name is required"): + b.create_provisioned_throughput( + model_id="arn:model", provisioned_model_name="" + ) + - def test_nova_rmp_no_data_source_returns_none(self): - """When model_data_source is None and managed_storage_type is Restricted, return None.""" - c = _make_container(recipe_name="nova-lite") - c.model_data_source = None +# ── _wait_for_provisioned_throughput_in_service ───────────────────────────── - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/rmp/2" - pkg.managed_storage_type = "Restricted" +class TestWaitForProvisionedThroughputInService: + def test_immediate_in_service(self): b = _builder() - b.model = "not-a-training-job" - b.model_package = pkg - result = b._get_s3_artifacts() - assert result is None + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "InService" + } + b._wait_for_provisioned_throughput_in_service("arn:pt") + b._bedrock_client.get_provisioned_model_throughput.assert_called_once_with( + provisionedModelId="arn:pt" + ) - def test_non_nova_rmp_returns_none(self): - """Non-Nova RMP models should also return None.""" - c = _make_container(recipe_name="llama", hub_content_name="llama") - c.model_data_source = None + def test_polls_then_in_service(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.side_effect = [ + {"status": "Creating"}, + {"status": "Creating"}, + {"status": "InService"}, + ] + with patch(f"{MODULE}.time.sleep"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=10 + ) + assert b._bedrock_client.get_provisioned_model_throughput.call_count == 3 - pkg = _make_model_package(c) - pkg.model_package_arn = "arn:aws:sagemaker:us-east-1:123456789012:model-package/rmp/1" - pkg.managed_storage_type = "Restricted" + def test_failed_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed", + "failureMessage": "Insufficient capacity", + } + with pytest.raises(RuntimeError, match="Insufficient capacity"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + def test_failed_unknown_reason(self): b = _builder() - b.model = "not-a-training-job" - b.model_package = pkg - result = b._get_s3_artifacts() - assert result is None + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Failed" + } + with pytest.raises(RuntimeError, match="Unknown"): + b._wait_for_provisioned_throughput_in_service("arn:pt") + + def test_timeout_raises(self): + b = _builder() + b._bedrock_client = Mock() + b._bedrock_client.get_provisioned_model_throughput.return_value = { + "status": "Creating" + } + with patch(f"{MODULE}.time.sleep"): + with pytest.raises(RuntimeError, match="Timed out"): + b._wait_for_provisioned_throughput_in_service( + "arn:pt", poll_interval=1, max_wait=2 + ) From f0293cff4ac10d21b54883ab0b2124cc4c00ac55 Mon Sep 17 00:00:00 2001 From: Lucas Jia Date: Wed, 3 Jun 2026 11:07:52 -0700 Subject: [PATCH 2/2] imporve docstring --- .../src/sagemaker/serve/bedrock_model_builder.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py index 786cea18b2..fea478dbb8 100644 --- a/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/bedrock_model_builder.py @@ -153,8 +153,11 @@ def deploy( """Deploy the model to Bedrock. Automatically detects if the model is a Nova model and uses the appropriate - Bedrock API (create_custom_model for Nova, create_model_import_job for others). - For Nova models, also creates a custom model deployment for inference. + Bedrock API (create_custom_model for Nova, create_model_import_job for OSS). + For Nova models, creates a custom model deployment and polls until active. + For OSS models, creates a model import job and polls until complete. Once + deploy() returns, the model is ready for on-demand inference. For provisioned + throughput, use the separate create_provisioned_throughput() method. Args: job_name: Name for the model import job (OSS models only). @@ -170,12 +173,12 @@ def deploy( defaults to custom_model_name suffixed with '-deployment'. Returns: - Response from Bedrock API. For Nova models, returns the - create_custom_model_deployment response. For others, returns - the create_model_import_job response. + For Nova models: the create_custom_model_deployment response. + For OSS models: the completed get_model_import_job response. Raises: ValueError: If model_package is not set or required parameters are missing. + RuntimeError: If the import job or deployment fails or times out. """ if not self.model_package: raise ValueError(