From c3e071ac589f94d5759a0bcf1de1dac22cb6577c Mon Sep 17 00:00:00 2001 From: MEPalma <64580864+MEPalma@users.noreply.github.com> Date: Thu, 2 Jan 2025 12:06:34 +0100 Subject: [PATCH] assume role boto clients --- .../resource_eval/resource_eval_s3.py | 38 +++++++++++++------ .../resource_eval/resource_eval_s3.py | 26 +++++++++---- .../state_execution/state_task/credentials.py | 15 ++++---- .../state_task/lambda_eval_utils.py | 6 +-- .../state_task/service/state_task_service.py | 19 +++++----- .../service/state_task_service_api_gateway.py | 4 +- .../service/state_task_service_aws_sdk.py | 9 ++--- .../service/state_task_service_batch.py | 17 ++++----- .../service/state_task_service_callback.py | 18 ++++----- .../service/state_task_service_dynamodb.py | 9 ++--- .../service/state_task_service_ecs.py | 18 ++++----- .../service/state_task_service_events.py | 9 ++--- .../service/state_task_service_glue.py | 29 +++++++------- .../service/state_task_service_lambda.py | 7 ++-- .../service/state_task_service_sfn.py | 19 ++++------ .../service/state_task_service_sns.py | 9 ++--- .../service/state_task_service_sqs.py | 9 ++--- .../service/state_task_service_unsupported.py | 9 ++--- .../state_execution/state_task/state_task.py | 10 ++--- .../state_task/state_task_lambda.py | 10 ++--- .../stepfunctions/asl/utils/boto_client.py | 28 +++----------- 21 files changed, 153 insertions(+), 165 deletions(-) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/item_reader/resource_eval/resource_eval_s3.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/item_reader/resource_eval/resource_eval_s3.py index 6eed0be685eaa..262c4f00ca540 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/item_reader/resource_eval/resource_eval_s3.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/item_reader/resource_eval/resource_eval_s3.py @@ -5,6 +5,9 @@ from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.item_reader.resource_eval.resource_eval import ( ResourceEval, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( + StateCredentials, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceRuntimePart, ) @@ -15,31 +18,41 @@ class ResourceEvalS3(ResourceEval): _HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_" - _API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None] + _API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart, StateCredentials], None] @staticmethod - def _get_s3_client(resource_runtime_part: ResourceRuntimePart): + def _get_s3_client( + resource_runtime_part: ResourceRuntimePart, state_credentials: StateCredentials + ): return boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, - service="s3", + region=resource_runtime_part.region, service="s3", state_credentials=state_credentials ) @staticmethod - def _handle_get_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None: - s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part) + def _handle_get_object( + env: Environment, + resource_runtime_part: ResourceRuntimePart, + state_credentials: StateCredentials, + ) -> None: + s3_client = ResourceEvalS3._get_s3_client( + resource_runtime_part=resource_runtime_part, state_credentials=state_credentials + ) parameters = env.stack.pop() - response = s3_client.get_object(**parameters) + response = s3_client.get_object(**parameters) # noqa content = to_str(response["Body"].read()) env.stack.append(content) @staticmethod def _handle_list_objects_v2( - env: Environment, resource_runtime_part: ResourceRuntimePart + env: Environment, + resource_runtime_part: ResourceRuntimePart, + state_credentials: StateCredentials, ) -> None: - s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part) + s3_client = ResourceEvalS3._get_s3_client( + resource_runtime_part=resource_runtime_part, state_credentials=state_credentials + ) parameters = env.stack.pop() - response = s3_client.list_objects_v2(**parameters) + response = s3_client.list_objects_v2(**parameters) # noqa contents = response["Contents"] env.stack.append(contents) @@ -55,4 +68,5 @@ def eval_resource(self, env: Environment) -> None: self.resource.eval(env=env) resource_runtime_part: ResourceRuntimePart = env.stack.pop() resolver_handler = self._get_api_action_handler() - resolver_handler(env, resource_runtime_part) + state_credentials = StateCredentials(role_arn=env.aws_execution_details.role_arn) + resolver_handler(env, resource_runtime_part, state_credentials) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/result_writer/resource_eval/resource_eval_s3.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/result_writer/resource_eval/resource_eval_s3.py index 21e3157b1381f..178c9653c83c6 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/result_writer/resource_eval/resource_eval_s3.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_map/result_writer/resource_eval/resource_eval_s3.py @@ -6,6 +6,9 @@ from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.result_writer.resource_eval.resource_eval import ( ResourceEval, ) +from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( + StateCredentials, +) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceRuntimePart, ) @@ -16,22 +19,28 @@ class ResourceEvalS3(ResourceEval): _HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_" - _API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None] + _API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart, StateCredentials], None] @staticmethod - def _get_s3_client(resource_runtime_part: ResourceRuntimePart): + def _get_s3_client( + resource_runtime_part: ResourceRuntimePart, state_credentials: StateCredentials + ): return boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, - service="s3", + service="s3", region=resource_runtime_part.region, state_credentials=state_credentials ) @staticmethod - def _handle_put_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None: + def _handle_put_object( + env: Environment, + resource_runtime_part: ResourceRuntimePart, + state_credentials: StateCredentials, + ) -> None: parameters = env.stack.pop() env.stack.pop() # TODO: results - s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part) + s3_client = ResourceEvalS3._get_s3_client( + resource_runtime_part=resource_runtime_part, state_credentials=state_credentials + ) map_run_record = env.map_run_record_pool_manager.get_all().pop() map_run_uuid = map_run_record.map_run_arn.split(":")[-1] if parameters["Prefix"] != "" and not parameters["Prefix"].endswith("/"): @@ -66,4 +75,5 @@ def eval_resource(self, env: Environment) -> None: self.resource.eval(env=env) resource_runtime_part: ResourceRuntimePart = env.stack.pop() resolver_handler = self._get_api_action_handler() - resolver_handler(env, resource_runtime_part) + state_credentials = StateCredentials(role_arn=env.aws_execution_details.role_arn) + resolver_handler(env, resource_runtime_part, state_credentials) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/credentials.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/credentials.py index c15562aacaebc..6839dc1c64a97 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/credentials.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/credentials.py @@ -1,4 +1,5 @@ -from typing import Final, Optional +from dataclasses import dataclass +from typing import Final from localstack.services.stepfunctions.asl.component.common.string.string_expression import ( StringExpression, @@ -6,8 +7,10 @@ from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent from localstack.services.stepfunctions.asl.eval.environment import Environment -_CREDENTIALS_ROLE_ARN_KEY: Final[str] = "RoleArn" -ComputedCredentials = dict + +@dataclass +class StateCredentials: + role_arn: str class RoleArn(EvalComponent): @@ -26,12 +29,8 @@ class Credentials(EvalComponent): def __init__(self, role_arn: RoleArn): self.role_arn = role_arn - @staticmethod - def get_role_arn_from(computed_credentials: ComputedCredentials) -> Optional[str]: - return computed_credentials.get(_CREDENTIALS_ROLE_ARN_KEY) - def _eval_body(self, env: Environment) -> None: self.role_arn.eval(env=env) role_arn = env.stack.pop() - computes_credentials: ComputedCredentials = {_CREDENTIALS_ROLE_ARN_KEY: role_arn} + computes_credentials = StateCredentials(role_arn=role_arn) env.stack.append(computes_credentials) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py index cd09c8a841c95..94cc1fc35817d 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/lambda_eval_utils.py @@ -4,7 +4,7 @@ from localstack.aws.api.lambda_ import InvocationResponse from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.eval.environment import Environment from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for @@ -39,10 +39,10 @@ def _from_payload(payload_streaming_body: IO[bytes]) -> Union[json, str]: def exec_lambda_function( - env: Environment, parameters: dict, region: str, account: str, credentials: ComputedCredentials + env: Environment, parameters: dict, region: str, state_credentials: StateCredentials ) -> None: lambda_client = boto_client_for( - region=region, account=account, service="lambda", credentials=credentials + service="lambda", region=region, state_credentials=state_credentials ) invocation_resp: InvocationResponse = lambda_client.invoke(**parameters) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py index eb9680eb3fbb7..5cc45e024200e 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service.py @@ -30,8 +30,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, - Credentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceRuntimePart, @@ -235,7 +234,7 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): ... def _before_eval_execution( @@ -243,7 +242,7 @@ def _before_eval_execution( env: Environment, resource_runtime_part: ResourceRuntimePart, raw_parameters: dict, - task_credentials: TaskCredentials, + state_credentials: StateCredentials, ) -> None: parameters_str = to_json_str(raw_parameters) @@ -263,7 +262,7 @@ def _before_eval_execution( scheduled_event_details["heartbeatInSeconds"] = heartbeat_seconds if self.credentials: scheduled_event_details["taskCredentials"] = TaskCredentials( - roleArn=Credentials.get_role_arn_from(computed_credentials=task_credentials) + roleArn=state_credentials.role_arn ) env.event_manager.add_event( context=env.event_history_context, @@ -286,7 +285,7 @@ def _after_eval_execution( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> None: output = env.stack[-1] self._verify_size_quota(env=env, value=output) @@ -308,13 +307,13 @@ def _eval_execution(self, env: Environment) -> None: resource_runtime_part: ResourceRuntimePart = env.stack.pop() raw_parameters = self._eval_parameters(env=env) - task_credentials = self._eval_credentials(env=env) + state_credentials = self._eval_state_credentials(env=env) self._before_eval_execution( env=env, resource_runtime_part=resource_runtime_part, raw_parameters=raw_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) normalised_parameters = copy.deepcopy(raw_parameters) @@ -324,7 +323,7 @@ def _eval_execution(self, env: Environment) -> None: env=env, resource_runtime_part=resource_runtime_part, normalised_parameters=normalised_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) output_value = env.stack[-1] @@ -334,5 +333,5 @@ def _eval_execution(self, env: Environment) -> None: env=env, resource_runtime_part=resource_runtime_part, normalised_parameters=normalised_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py index 7140cca4c6d23..b4d8c660a8f81 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_api_gateway.py @@ -24,7 +24,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -294,7 +294,7 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): # TODO: add support for task credentials diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py index 984a9ecc4d7e9..2e84aa2dc64d2 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_aws_sdk.py @@ -15,7 +15,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -128,15 +128,14 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() api_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) response = getattr(api_client, api_action)(**normalised_parameters) or dict() if response: diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_batch.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_batch.py index 1b6edcb03621e..bc83e1f327121 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_batch.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_batch.py @@ -18,7 +18,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -90,7 +90,7 @@ def _before_eval_execution( env: Environment, resource_runtime_part: ResourceRuntimePart, raw_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> None: if self.resource.condition == ResourceCondition.Sync: self._attach_aws_environment_variables(parameters=raw_parameters) @@ -98,7 +98,7 @@ def _before_eval_execution( env=env, resource_runtime_part=resource_runtime_part, raw_parameters=raw_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: @@ -138,12 +138,12 @@ def _build_sync_resolver( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: batch_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service="batch", + region=resource_runtime_part.region, + state_credentials=state_credentials, ) submission_output: dict = env.stack.pop() job_id = submission_output["JobId"] @@ -186,15 +186,14 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() batch_client = boto_client_for( region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + state_credentials=state_credentials, ) response = getattr(batch_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py index 16db8a97f21e8..31c0e97dd9af5 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_callback.py @@ -17,7 +17,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -66,7 +66,7 @@ def _build_sync_resolver( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: raise RuntimeError( f"Unsupported .sync callback procedure in resource {self.resource.resource_arn}" @@ -77,7 +77,7 @@ def _build_sync2_resolver( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: raise RuntimeError( f"Unsupported .sync2 callback procedure in resource {self.resource.resource_arn}" @@ -149,7 +149,7 @@ def _eval_integration_pattern( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> None: task_output = env.stack.pop() @@ -190,7 +190,7 @@ def _eval_integration_pattern( env=env, resource_runtime_part=resource_runtime_part, normalised_parameters=normalised_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) else: # The condition checks about the resource's condition is exhaustive leaving @@ -199,7 +199,7 @@ def _eval_integration_pattern( env=env, resource_runtime_part=resource_runtime_part, normalised_parameters=normalised_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) outcome = self._eval_sync( @@ -326,7 +326,7 @@ def _after_eval_execution( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> None: if self._is_integration_pattern(): output = env.stack[-1] @@ -346,12 +346,12 @@ def _after_eval_execution( env=env, resource_runtime_part=resource_runtime_part, normalised_parameters=normalised_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) super()._after_eval_execution( env=env, resource_runtime_part=resource_runtime_part, normalised_parameters=normalised_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py index 329d358596796..9fb484abc6362 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_dynamodb.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceRuntimePart, @@ -133,15 +133,14 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() dynamodb_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) response = getattr(dynamodb_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_ecs.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_ecs.py index 64b064350a557..3b3473aaa848c 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_ecs.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_ecs.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Final, Optional from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -50,7 +50,7 @@ def _before_eval_execution( env: Environment, resource_runtime_part: ResourceRuntimePart, raw_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> None: if self.resource.condition == ResourceCondition.Sync: raw_parameters[_STARTED_BY_PARAMETER_RAW_KEY] = _STARTED_BY_PARAMETER_VALUE @@ -58,7 +58,7 @@ def _before_eval_execution( env=env, resource_runtime_part=resource_runtime_part, raw_parameters=raw_parameters, - task_credentials=task_credentials, + state_credentials=state_credentials, ) def _eval_service_task( @@ -66,15 +66,14 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() ecs_client = boto_client_for( region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + state_credentials=state_credentials, ) response = getattr(ecs_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) @@ -102,13 +101,12 @@ def _build_sync_resolver( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: ecs_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service="ecs", - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) submission_output: dict = env.stack.pop() task_arn: str = submission_output["Tasks"][0]["TaskArn"] diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py index d086ac1c98f5f..19640f84ab02f 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_events.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -87,16 +87,15 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): self._normalised_request_parameters(env=env, parameters=normalised_parameters) service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() events_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) response = getattr(events_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_glue.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_glue.py index 49c5664c8d484..b400c84300e1c 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_glue.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_glue.py @@ -18,7 +18,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -52,11 +52,11 @@ _SYNC_HANDLER_REFLECTION_PREFIX: Final[str] = "_sync_to_" # The type of (sync)handler function for StateTaskServiceGlue objects. _API_ACTION_HANDLER_TYPE = Callable[ - [Environment, ResourceRuntimePart, dict, ComputedCredentials], None + [Environment, ResourceRuntimePart, dict, StateCredentials], None ] # The type of (sync)handler builder function for StateTaskServiceGlue objects. _API_ACTION_HANDLER_BUILDER_TYPE = Callable[ - [Environment, ResourceRuntimePart, dict, ComputedCredentials], Callable[[], Optional[Any]] + [Environment, ResourceRuntimePart, dict, StateCredentials], Callable[[], Optional[Any]] ] @@ -82,13 +82,12 @@ def _get_api_action_sync_builder_handler(self) -> _API_ACTION_HANDLER_BUILDER_TY @staticmethod def _get_glue_client( - resource_runtime_part: ResourceRuntimePart, task_credentials: ComputedCredentials + resource_runtime_part: ResourceRuntimePart, state_credentials: StateCredentials ) -> boto3.client: return boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service="glue", - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) def _from_error(self, env: Environment, ex: Exception) -> FailureEvent: @@ -125,10 +124,10 @@ def _handle_start_job_run( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - computed_credentials: ComputedCredentials, + computed_credentials: StateCredentials, ): glue_client = self._get_glue_client( - resource_runtime_part=resource_runtime_part, task_credentials=computed_credentials + resource_runtime_part=resource_runtime_part, state_credentials=computed_credentials ) response = glue_client.start_job_run(**normalised_parameters) response.pop("ResponseMetadata", None) @@ -143,18 +142,18 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): # Source the action handler and delegate the evaluation. api_action_handler = self._get_api_action_handler() - api_action_handler(env, resource_runtime_part, normalised_parameters, task_credentials) + api_action_handler(env, resource_runtime_part, normalised_parameters, state_credentials) def _sync_to_start_job_run( self, env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: # Poll the job run state from glue, using GetJobRun until the job has terminated. Hence, append the output # of GetJobRun to the state. @@ -166,7 +165,7 @@ def _sync_to_start_job_run( job_run_id: str = start_job_run_output["JobRunId"] glue_client = self._get_glue_client( - resource_runtime_part=resource_runtime_part, task_credentials=task_credentials + resource_runtime_part=resource_runtime_part, state_credentials=state_credentials ) def _sync_resolver() -> Optional[Any]: @@ -212,10 +211,10 @@ def _build_sync_resolver( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: sync_resolver_builder = self._get_api_action_sync_builder_handler() sync_resolver = sync_resolver_builder( - env, resource_runtime_part, normalised_parameters, task_credentials + env, resource_runtime_part, normalised_parameters, state_credentials ) return sync_resolver diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py index 3bce9a43c828e..405dcf595d799 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_lambda.py @@ -15,7 +15,7 @@ lambda_eval_utils, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -122,12 +122,11 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): lambda_eval_utils.exec_lambda_function( env=env, parameters=normalised_parameters, region=resource_runtime_part.region, - account=resource_runtime_part.account, - credentials=task_credentials, + state_credentials=state_credentials, ) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py index a3a3326c23212..b450b8e6da582 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sfn.py @@ -23,7 +23,7 @@ StatesErrorNameType, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -114,13 +114,12 @@ def _build_sync_resolver( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: sfn_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service="stepfunctions", - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) submission_output: dict = env.stack.pop() execution_arn: str = submission_output["ExecutionArn"] @@ -176,13 +175,12 @@ def _build_sync2_resolver( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ) -> Callable[[], Optional[Any]]: sfn_client = boto_client_for( region=resource_runtime_part.region, - account=resource_runtime_part.account, service="stepfunctions", - credentials=task_credentials, + state_credentials=state_credentials, ) submission_output: dict = env.stack.pop() execution_arn: str = submission_output["ExecutionArn"] @@ -227,15 +225,14 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() sfn_client = boto_client_for( region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + state_credentials=state_credentials, ) response = getattr(sfn_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py index 7c19e01c13ddc..45c6693d0dafd 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sns.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -90,15 +90,14 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() sns_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) # Optimised integration automatically stringifies diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py index c2a2c281907ed..836cb8ad1b95b 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_sqs.py @@ -10,7 +10,7 @@ FailureEvent, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -92,7 +92,7 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): # TODO: Stepfunctions automatically dumps to json MessageBody's definitions. # Are these other similar scenarios? @@ -104,10 +104,9 @@ def _eval_service_task( service_name = self._get_boto_service_name() api_action = self._get_boto_service_action() sqs_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) response = getattr(sqs_client, api_action)(**normalised_parameters) response.pop("ResponseMetadata", None) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_unsupported.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_unsupported.py index 6b972c2af374b..421e3c8619fa6 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_unsupported.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/service/state_task_service_unsupported.py @@ -2,7 +2,7 @@ from typing import Final from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( ResourceCondition, @@ -42,7 +42,7 @@ def _eval_service_task( env: Environment, resource_runtime_part: ResourceRuntimePart, normalised_parameters: dict, - task_credentials: ComputedCredentials, + state_credentials: StateCredentials, ): # Logs that the evaluation of this optimised service integration is not supported # and relays the call to the target service with the computed parameters. @@ -50,10 +50,9 @@ def _eval_service_task( service_name = self._get_boto_service_name() boto_action = self._get_boto_service_action() boto_client = boto_client_for( - region=resource_runtime_part.region, - account=resource_runtime_part.account, service=service_name, - credentials=task_credentials, + region=resource_runtime_part.region, + state_credentials=state_credentials, ) response = getattr(boto_client, boto_action)(**normalised_parameters) response.pop("ResponseMetadata", None) diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py index 1fc2441d8e5b3..79c5f496d7bf8 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task.py @@ -18,8 +18,8 @@ ExecutionState, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, Credentials, + StateCredentials, ) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( Resource, @@ -69,13 +69,13 @@ def _eval_parameters(self, env: Environment) -> dict: return parameters - def _eval_credentials(self, env: Environment) -> ComputedCredentials: + def _eval_state_credentials(self, env: Environment) -> StateCredentials: if not self.credentials: - task_credentials = dict() + state_credentials = StateCredentials(role_arn=env.aws_execution_details.role_arn) else: self.credentials.eval(env=env) - task_credentials = env.stack.pop() - return task_credentials + state_credentials = env.stack.pop() + return state_credentials def _get_timed_out_failure_event(self, env: Environment) -> FailureEvent: return FailureEvent( diff --git a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py index ee584ed39423b..a6a9dbe0c78d3 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py +++ b/localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/state_task/state_task_lambda.py @@ -30,9 +30,6 @@ from localstack.services.stepfunctions.asl.component.state.state_execution.state_task import ( lambda_eval_utils, ) -from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - Credentials, -) from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import ( LambdaResource, ResourceRuntimePart, @@ -134,7 +131,7 @@ def _eval_parameters(self, env: Environment) -> dict: def _eval_execution(self, env: Environment) -> None: parameters = self._eval_parameters(env=env) - task_credentials = self._eval_credentials(env=env) + state_credentials = self._eval_state_credentials(env=env) payload = parameters["Payload"] scheduled_event_details = LambdaFunctionScheduledEventDetails( @@ -150,7 +147,7 @@ def _eval_execution(self, env: Environment) -> None: scheduled_event_details["timeoutInSeconds"] = timeout_seconds if self.credentials: scheduled_event_details["taskCredentials"] = TaskCredentials( - roleArn=Credentials.get_role_arn_from(computed_credentials=task_credentials) + roleArn=state_credentials.role_arn ) env.event_manager.add_event( context=env.event_history_context, @@ -171,8 +168,7 @@ def _eval_execution(self, env: Environment) -> None: env=env, parameters=parameters, region=resource_runtime_part.region, - account=resource_runtime_part.account, - credentials=task_credentials, + state_credentials=state_credentials, ) # In lambda invocations, only payload is passed on as output. diff --git a/localstack-core/localstack/services/stepfunctions/asl/utils/boto_client.py b/localstack-core/localstack/services/stepfunctions/asl/utils/boto_client.py index 021ca28425ac8..c7facf1bb532c 100644 --- a/localstack-core/localstack/services/stepfunctions/asl/utils/boto_client.py +++ b/localstack-core/localstack/services/stepfunctions/asl/utils/boto_client.py @@ -1,13 +1,10 @@ -from typing import Optional - from botocore.client import BaseClient from botocore.config import Config from localstack.aws.connect import connect_to from localstack.services.stepfunctions.asl.component.common.timeouts.timeout import TimeoutSeconds from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import ( - ComputedCredentials, - Credentials, + StateCredentials, ) from localstack.utils.aws.client_types import ServicePrincipal @@ -20,24 +17,11 @@ ) -def boto_client_for( - region: str, account: str, service: str, credentials: Optional[ComputedCredentials] = None -) -> BaseClient: - if credentials: - assume_role_arn: Optional[str] = Credentials.get_role_arn_from( - computed_credentials=credentials - ) - if assume_role_arn is not None: - client_factory = connect_to.with_assumed_role( - role_arn=assume_role_arn, - service_principal=ServicePrincipal.states, - region_name=region, - config=_BOTO_CLIENT_CONFIG, - ) - return client_factory.get_client(service=service) - return connect_to.get_client( - aws_access_key_id=account, +def boto_client_for(service: str, region: str, state_credentials: StateCredentials) -> BaseClient: + client_factory = connect_to.with_assumed_role( + role_arn=state_credentials.role_arn, + service_principal=ServicePrincipal.states, region_name=region, - service_name=service, config=_BOTO_CLIENT_CONFIG, ) + return client_factory.get_client(service=service)