Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.item_reader.resource_eval.resource_eval import (
ResourceEval,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
StateCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
)
Expand All @@ -15,31 +18,41 @@

class ResourceEvalS3(ResourceEval):
_HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_"
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None]
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart, StateCredentials], None]

@staticmethod
def _get_s3_client(resource_runtime_part: ResourceRuntimePart):
def _get_s3_client(
resource_runtime_part: ResourceRuntimePart, state_credentials: StateCredentials
):
return boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service="s3",
region=resource_runtime_part.region, service="s3", state_credentials=state_credentials
)

@staticmethod
def _handle_get_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None:
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
def _handle_get_object(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
state_credentials: StateCredentials,
) -> None:
s3_client = ResourceEvalS3._get_s3_client(
resource_runtime_part=resource_runtime_part, state_credentials=state_credentials
)
parameters = env.stack.pop()
response = s3_client.get_object(**parameters)
response = s3_client.get_object(**parameters) # noqa
content = to_str(response["Body"].read())
env.stack.append(content)

@staticmethod
def _handle_list_objects_v2(
env: Environment, resource_runtime_part: ResourceRuntimePart
env: Environment,
resource_runtime_part: ResourceRuntimePart,
state_credentials: StateCredentials,
) -> None:
s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
s3_client = ResourceEvalS3._get_s3_client(
resource_runtime_part=resource_runtime_part, state_credentials=state_credentials
)
parameters = env.stack.pop()
response = s3_client.list_objects_v2(**parameters)
response = s3_client.list_objects_v2(**parameters) # noqa
contents = response["Contents"]
env.stack.append(contents)

Expand All @@ -55,4 +68,5 @@ def eval_resource(self, env: Environment) -> None:
self.resource.eval(env=env)
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
resolver_handler = self._get_api_action_handler()
resolver_handler(env, resource_runtime_part)
state_credentials = StateCredentials(role_arn=env.aws_execution_details.role_arn)
resolver_handler(env, resource_runtime_part, state_credentials)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.result_writer.resource_eval.resource_eval import (
ResourceEval,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
StateCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
)
Expand All @@ -16,22 +19,28 @@

class ResourceEvalS3(ResourceEval):
_HANDLER_REFLECTION_PREFIX: Final[str] = "_handle_"
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart], None]
_API_ACTION_HANDLER_TYPE = Callable[[Environment, ResourceRuntimePart, StateCredentials], None]

@staticmethod
def _get_s3_client(resource_runtime_part: ResourceRuntimePart):
def _get_s3_client(
resource_runtime_part: ResourceRuntimePart, state_credentials: StateCredentials
):
return boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service="s3",
service="s3", region=resource_runtime_part.region, state_credentials=state_credentials
)

@staticmethod
def _handle_put_object(env: Environment, resource_runtime_part: ResourceRuntimePart) -> None:
def _handle_put_object(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
state_credentials: StateCredentials,
) -> None:
parameters = env.stack.pop()
env.stack.pop() # TODO: results

s3_client = ResourceEvalS3._get_s3_client(resource_runtime_part=resource_runtime_part)
s3_client = ResourceEvalS3._get_s3_client(
resource_runtime_part=resource_runtime_part, state_credentials=state_credentials
)
map_run_record = env.map_run_record_pool_manager.get_all().pop()
map_run_uuid = map_run_record.map_run_arn.split(":")[-1]
if parameters["Prefix"] != "" and not parameters["Prefix"].endswith("/"):
Expand Down Expand Up @@ -66,4 +75,5 @@ def eval_resource(self, env: Environment) -> None:
self.resource.eval(env=env)
resource_runtime_part: ResourceRuntimePart = env.stack.pop()
resolver_handler = self._get_api_action_handler()
resolver_handler(env, resource_runtime_part)
state_credentials = StateCredentials(role_arn=env.aws_execution_details.role_arn)
resolver_handler(env, resource_runtime_part, state_credentials)
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Final, Optional
from dataclasses import dataclass
from typing import Final

from localstack.services.stepfunctions.asl.component.common.string.string_expression import (
StringExpression,
)
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
from localstack.services.stepfunctions.asl.eval.environment import Environment

_CREDENTIALS_ROLE_ARN_KEY: Final[str] = "RoleArn"
ComputedCredentials = dict

@dataclass
class StateCredentials:
role_arn: str


class RoleArn(EvalComponent):
Expand All @@ -26,12 +29,8 @@ class Credentials(EvalComponent):
def __init__(self, role_arn: RoleArn):
self.role_arn = role_arn

@staticmethod
def get_role_arn_from(computed_credentials: ComputedCredentials) -> Optional[str]:
return computed_credentials.get(_CREDENTIALS_ROLE_ARN_KEY)

def _eval_body(self, env: Environment) -> None:
self.role_arn.eval(env=env)
role_arn = env.stack.pop()
computes_credentials: ComputedCredentials = {_CREDENTIALS_ROLE_ARN_KEY: role_arn}
computes_credentials = StateCredentials(role_arn=role_arn)
env.stack.append(computes_credentials)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from localstack.aws.api.lambda_ import InvocationResponse
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
StateCredentials,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
Expand Down Expand Up @@ -39,10 +39,10 @@ def _from_payload(payload_streaming_body: IO[bytes]) -> Union[json, str]:


def exec_lambda_function(
env: Environment, parameters: dict, region: str, account: str, credentials: ComputedCredentials
env: Environment, parameters: dict, region: str, state_credentials: StateCredentials
) -> None:
lambda_client = boto_client_for(
region=region, account=account, service="lambda", credentials=credentials
service="lambda", region=region, state_credentials=state_credentials
)

invocation_resp: InvocationResponse = lambda_client.invoke(**parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
Credentials,
StateCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
Expand Down Expand Up @@ -235,15 +234,15 @@ def _eval_service_task(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: ComputedCredentials,
state_credentials: StateCredentials,
): ...

def _before_eval_execution(
self,
env: Environment,
resource_runtime_part: ResourceRuntimePart,
raw_parameters: dict,
task_credentials: TaskCredentials,
state_credentials: StateCredentials,
) -> None:
parameters_str = to_json_str(raw_parameters)

Expand All @@ -263,7 +262,7 @@ def _before_eval_execution(
scheduled_event_details["heartbeatInSeconds"] = heartbeat_seconds
if self.credentials:
scheduled_event_details["taskCredentials"] = TaskCredentials(
roleArn=Credentials.get_role_arn_from(computed_credentials=task_credentials)
roleArn=state_credentials.role_arn
)
env.event_manager.add_event(
context=env.event_history_context,
Expand All @@ -286,7 +285,7 @@ def _after_eval_execution(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: ComputedCredentials,
state_credentials: StateCredentials,
) -> None:
output = env.stack[-1]
self._verify_size_quota(env=env, value=output)
Expand All @@ -308,13 +307,13 @@ def _eval_execution(self, env: Environment) -> None:
resource_runtime_part: ResourceRuntimePart = env.stack.pop()

raw_parameters = self._eval_parameters(env=env)
task_credentials = self._eval_credentials(env=env)
state_credentials = self._eval_state_credentials(env=env)

self._before_eval_execution(
env=env,
resource_runtime_part=resource_runtime_part,
raw_parameters=raw_parameters,
task_credentials=task_credentials,
state_credentials=state_credentials,
)

normalised_parameters = copy.deepcopy(raw_parameters)
Expand All @@ -324,7 +323,7 @@ def _eval_execution(self, env: Environment) -> None:
env=env,
resource_runtime_part=resource_runtime_part,
normalised_parameters=normalised_parameters,
task_credentials=task_credentials,
state_credentials=state_credentials,
)

output_value = env.stack[-1]
Expand All @@ -334,5 +333,5 @@ def _eval_execution(self, env: Environment) -> None:
env=env,
resource_runtime_part=resource_runtime_part,
normalised_parameters=normalised_parameters,
task_credentials=task_credentials,
state_credentials=state_credentials,
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
FailureEvent,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
StateCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceCondition,
Expand Down Expand Up @@ -294,7 +294,7 @@ def _eval_service_task(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: ComputedCredentials,
state_credentials: StateCredentials,
):
# TODO: add support for task credentials

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
StateCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceCondition,
Expand Down Expand Up @@ -128,15 +128,14 @@ def _eval_service_task(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: ComputedCredentials,
state_credentials: StateCredentials,
):
service_name = self._get_boto_service_name()
api_action = self._get_boto_service_action()
api_client = boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service=service_name,
credentials=task_credentials,
region=resource_runtime_part.region,
state_credentials=state_credentials,
)
response = getattr(api_client, api_action)(**normalised_parameters) or dict()
if response:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
StateCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceCondition,
Expand Down Expand Up @@ -90,15 +90,15 @@ def _before_eval_execution(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
raw_parameters: dict,
task_credentials: ComputedCredentials,
state_credentials: StateCredentials,
) -> None:
if self.resource.condition == ResourceCondition.Sync:
self._attach_aws_environment_variables(parameters=raw_parameters)
super()._before_eval_execution(
env=env,
resource_runtime_part=resource_runtime_part,
raw_parameters=raw_parameters,
task_credentials=task_credentials,
state_credentials=state_credentials,
)

def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
Expand Down Expand Up @@ -138,12 +138,12 @@ def _build_sync_resolver(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: ComputedCredentials,
state_credentials: StateCredentials,
) -> Callable[[], Optional[Any]]:
batch_client = boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service="batch",
region=resource_runtime_part.region,
state_credentials=state_credentials,
)
submission_output: dict = env.stack.pop()
job_id = submission_output["JobId"]
Expand Down Expand Up @@ -186,15 +186,14 @@ def _eval_service_task(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: ComputedCredentials,
state_credentials: StateCredentials,
):
service_name = self._get_boto_service_name()
api_action = self._get_boto_service_action()
batch_client = boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service=service_name,
credentials=task_credentials,
state_credentials=state_credentials,
)
response = getattr(batch_client, api_action)(**normalised_parameters)
response.pop("ResponseMetadata", None)
Expand Down
Loading
Loading