Skip to content

Commit 54db09b

Browse files
authored
feature: Add support for SageMaker Serverless inference Provisioned Concurrency feature (aws#3851)
1 parent bfbb35f commit 54db09b

File tree

4 files changed

+62
-4
lines changed

4 files changed

+62
-4
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def read_requirements(filename):
4848
# Declare minimal set for installation
4949
required_packages = [
5050
"attrs>=20.3.0,<23",
51-
"boto3>=1.26.28,<2.0",
51+
"boto3>=1.26.131,<2.0",
5252
"cloudpickle==2.2.1",
5353
"google-pasta",
5454
"numpy>=1.9.0,<2.0",

src/sagemaker/serverless/serverless_inference_config.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains code related to the ServerlessInferenceConfig class.
1414
15-
Codes are used for configuring async inference endpoint. Use it when deploying
15+
This code is used for configuring serverless inference endpoints. Use it when deploying
1616
the model to the endpoints.
1717
"""
1818
from __future__ import print_function, absolute_import
19+
from typing import Optional
1920

2021

2122
class ServerlessInferenceConfig(object):
@@ -29,6 +30,7 @@ def __init__(
2930
self,
3031
memory_size_in_mb: int = 2048,
3132
max_concurrency: int = 5,
33+
provisioned_concurrency: Optional[int] = None,
3234
):
3335
"""Initialize a ServerlessInferenceConfig object for serverless inference configuration.
3436
@@ -40,9 +42,13 @@ def __init__(
4042
max_concurrency (int): Optional. The maximum number of concurrent invocations
4143
your serverless endpoint can process. If no value is provided, Amazon
4244
SageMaker will choose the default value for you. (Default: 5)
45+
provisioned_concurrency (int): Optional. The provisioned concurrency of your
46+
serverless endpoint. If no value is provided, Amazon SageMaker will not
47+
apply provisioned concurrency to your Serverless endpoint. (Default: None)
4348
"""
4449
self.memory_size_in_mb = memory_size_in_mb
4550
self.max_concurrency = max_concurrency
51+
self.provisioned_concurrency = provisioned_concurrency
4652

4753
def _to_request_dict(self):
4854
"""Generates a request dictionary using the parameters provided to the class."""
@@ -51,4 +57,7 @@ def _to_request_dict(self):
5157
"MaxConcurrency": self.max_concurrency,
5258
}
5359

60+
if self.provisioned_concurrency is not None:
61+
request_dict["ProvisionedConcurrency"] = self.provisioned_concurrency
62+
5463
return request_dict

tests/integ/test_serverless_inference.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,30 @@ def test_serverless_walkthrough(sagemaker_session, cpu_instance_type, training_s
4444
pca.extra_components = 5
4545
pca.fit(pca.record_set(training_set[0][:100]), job_name=job_name)
4646

47-
with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
47+
serverless_name = unique_name_from_base("pca-serverless")
48+
with timeout_and_delete_endpoint_by_name(serverless_name, sagemaker_session):
4849

4950
predictor_serverless = pca.deploy(
50-
endpoint_name=job_name, serverless_inference_config=ServerlessInferenceConfig()
51+
endpoint_name=serverless_name, serverless_inference_config=ServerlessInferenceConfig()
5152
)
5253

5354
result = predictor_serverless.predict(training_set[0][:5])
5455

5556
assert len(result) == 5
5657
for record in result:
5758
assert record.label["projection"] is not None
59+
60+
# Test out Serverless Provisioned Concurrency endpoint happy case
61+
serverless_pc_name = unique_name_from_base("pca-serverless-pc")
62+
with timeout_and_delete_endpoint_by_name(serverless_pc_name, sagemaker_session):
63+
64+
predictor_serverless_pc = pca.deploy(
65+
endpoint_name=serverless_pc_name,
66+
serverless_inference_config=ServerlessInferenceConfig(provisioned_concurrency=1),
67+
)
68+
69+
result = predictor_serverless_pc.predict(training_set[0][:5])
70+
71+
assert len(result) == 5
72+
for record in result:
73+
assert record.label["projection"] is not None

tests/unit/sagemaker/serverless/test_serverless_inference_config.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,54 @@
1616

1717
DEFAULT_MEMORY_SIZE_IN_MB = 2048
1818
DEFAULT_MAX_CONCURRENCY = 5
19+
DEFAULT_PROVISIONED_CONCURRENCY = 5
1920

2021
DEFAULT_REQUEST_DICT = {
2122
"MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB,
2223
"MaxConcurrency": DEFAULT_MAX_CONCURRENCY,
2324
}
2425

26+
PROVISIONED_CONCURRENCY_REQUEST_DICT = {
27+
"MemorySizeInMB": DEFAULT_MEMORY_SIZE_IN_MB,
28+
"MaxConcurrency": DEFAULT_MAX_CONCURRENCY,
29+
"ProvisionedConcurrency": DEFAULT_PROVISIONED_CONCURRENCY,
30+
}
31+
2532

2633
def test_init():
2734
serverless_inference_config = ServerlessInferenceConfig()
2835

2936
assert serverless_inference_config.memory_size_in_mb == DEFAULT_MEMORY_SIZE_IN_MB
3037
assert serverless_inference_config.max_concurrency == DEFAULT_MAX_CONCURRENCY
3138

39+
serverless_provisioned_concurrency_inference_config = ServerlessInferenceConfig(
40+
provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY
41+
)
42+
43+
assert (
44+
serverless_provisioned_concurrency_inference_config.memory_size_in_mb
45+
== DEFAULT_MEMORY_SIZE_IN_MB
46+
)
47+
assert (
48+
serverless_provisioned_concurrency_inference_config.max_concurrency
49+
== DEFAULT_MAX_CONCURRENCY
50+
)
51+
assert (
52+
serverless_provisioned_concurrency_inference_config.provisioned_concurrency
53+
== DEFAULT_PROVISIONED_CONCURRENCY
54+
)
55+
3256

3357
def test_to_request_dict():
3458
serverless_inference_config_dict = ServerlessInferenceConfig()._to_request_dict()
3559

3660
assert serverless_inference_config_dict == DEFAULT_REQUEST_DICT
61+
62+
serverless_provisioned_concurrency_inference_config_dict = ServerlessInferenceConfig(
63+
provisioned_concurrency=DEFAULT_PROVISIONED_CONCURRENCY
64+
)._to_request_dict()
65+
66+
assert (
67+
serverless_provisioned_concurrency_inference_config_dict
68+
== PROVISIONED_CONCURRENCY_REQUEST_DICT
69+
)

0 commit comments

Comments
 (0)