Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ PARAMETERS: '"Parameters"';

CREDENTIALS: '"Credentials"';

ROLEARN: '"RoleArn"';

ROLEARNPATH: '"RoleArn.$"';

RESULTSELECTOR: '"ResultSelector"';

ITEMREADER: '"ItemReader"';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,16 @@ max_concurrency_path_decl:

parameters_decl: PARAMETERS COLON payload_tmpl_decl;

credentials_decl: CREDENTIALS COLON payload_tmpl_decl;
credentials_decl: CREDENTIALS COLON LBRACE role_arn_decl RBRACE;

role_arn_decl:
ROLEARN COLON STRINGJSONATA # role_arn_jsonata
| ROLEARNPATH COLON STRINGPATH # role_arn_path
| ROLEARNPATH COLON STRINGPATHCONTEXTOBJ # role_arn_path_context_obj
| ROLEARNPATH COLON STRINGINTRINSICFUNC # role_arn_intrinsic_func
| ROLEARNPATH COLON variable_sample # role_arn_var
| ROLEARN COLON keyword_or_string # role_arn_str
;

timeout_seconds_decl:
TIMEOUTSECONDS COLON STRINGJSONATA # timeout_seconds_jsonata
Expand Down Expand Up @@ -648,6 +657,8 @@ keyword_or_string:
| RESULT
| PARAMETERS
| CREDENTIALS
| ROLEARN
| ROLEARNPATH
| RESULTSELECTOR
| ITEMREADER
| READERCONFIG
Expand Down
2,166 changes: 1,091 additions & 1,075 deletions localstack-core/localstack/services/stepfunctions/asl/antlr/runtime/ASLLexer.py

Large diffs are not rendered by default.

3,434 changes: 1,868 additions & 1,566 deletions localstack-core/localstack/services/stepfunctions/asl/antlr/runtime/ASLParser.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,60 @@ def exitCredentials_decl(self, ctx:ASLParser.Credentials_declContext):
pass


# Enter a parse tree produced by ASLParser#role_arn_jsonata.
def enterRole_arn_jsonata(self, ctx:ASLParser.Role_arn_jsonataContext):
pass

# Exit a parse tree produced by ASLParser#role_arn_jsonata.
def exitRole_arn_jsonata(self, ctx:ASLParser.Role_arn_jsonataContext):
pass


# Enter a parse tree produced by ASLParser#role_arn_path.
def enterRole_arn_path(self, ctx:ASLParser.Role_arn_pathContext):
pass

# Exit a parse tree produced by ASLParser#role_arn_path.
def exitRole_arn_path(self, ctx:ASLParser.Role_arn_pathContext):
pass


# Enter a parse tree produced by ASLParser#role_arn_path_context_obj.
def enterRole_arn_path_context_obj(self, ctx:ASLParser.Role_arn_path_context_objContext):
pass

# Exit a parse tree produced by ASLParser#role_arn_path_context_obj.
def exitRole_arn_path_context_obj(self, ctx:ASLParser.Role_arn_path_context_objContext):
pass


# Enter a parse tree produced by ASLParser#role_arn_intrinsic_func.
def enterRole_arn_intrinsic_func(self, ctx:ASLParser.Role_arn_intrinsic_funcContext):
pass

# Exit a parse tree produced by ASLParser#role_arn_intrinsic_func.
def exitRole_arn_intrinsic_func(self, ctx:ASLParser.Role_arn_intrinsic_funcContext):
pass


# Enter a parse tree produced by ASLParser#role_arn_var.
def enterRole_arn_var(self, ctx:ASLParser.Role_arn_varContext):
pass

# Exit a parse tree produced by ASLParser#role_arn_var.
def exitRole_arn_var(self, ctx:ASLParser.Role_arn_varContext):
pass


# Enter a parse tree produced by ASLParser#role_arn_str.
def enterRole_arn_str(self, ctx:ASLParser.Role_arn_strContext):
pass

# Exit a parse tree produced by ASLParser#role_arn_str.
def exitRole_arn_str(self, ctx:ASLParser.Role_arn_strContext):
pass


# Enter a parse tree produced by ASLParser#timeout_seconds_jsonata.
def enterTimeout_seconds_jsonata(self, ctx:ASLParser.Timeout_seconds_jsonataContext):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,36 @@ def visitCredentials_decl(self, ctx:ASLParser.Credentials_declContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#role_arn_jsonata.
def visitRole_arn_jsonata(self, ctx:ASLParser.Role_arn_jsonataContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#role_arn_path.
def visitRole_arn_path(self, ctx:ASLParser.Role_arn_pathContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#role_arn_path_context_obj.
def visitRole_arn_path_context_obj(self, ctx:ASLParser.Role_arn_path_context_objContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#role_arn_intrinsic_func.
def visitRole_arn_intrinsic_func(self, ctx:ASLParser.Role_arn_intrinsic_funcContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#role_arn_var.
def visitRole_arn_var(self, ctx:ASLParser.Role_arn_varContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#role_arn_str.
def visitRole_arn_str(self, ctx:ASLParser.Role_arn_strContext):
return self.visitChildren(ctx)


# Visit a parse tree produced by ASLParser#timeout_seconds_jsonata.
def visitTimeout_seconds_jsonata(self, ctx:ASLParser.Timeout_seconds_jsonataContext):
return self.visitChildren(ctx)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,93 @@
from typing import Final
import abc
import copy
from typing import Final, Optional

from localstack.services.stepfunctions.asl.component.common.payload.payloadvalue.payloadtmpl.payload_tmpl import (
PayloadTmpl,
from localstack.services.stepfunctions.asl.component.common.jsonata.jsonata_template_value_terminal import (
JSONataTemplateValueTerminalExpression,
)
from localstack.services.stepfunctions.asl.component.common.variable_sample import VariableSample
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
from localstack.services.stepfunctions.asl.component.intrinsic.function.function import Function
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.parse.intrinsic.intrinsic_parser import IntrinsicParser
from localstack.services.stepfunctions.asl.utils.json_path import extract_json

_CREDENTIALS_ROLE_ARN_KEY: Final[str] = "RoleArn"
ComputedCredentials = dict


class RoleArn(EvalComponent, abc.ABC): ...


class RoleArnConst(RoleArn):
value: Final[str]

def __init__(self, value: str):
self.value = value

def _eval_body(self, env: Environment) -> None:
env.stack.append(self.value)


class RoleArnJSONata(RoleArn):
jsonata_template_value_terminal_expression: Final[JSONataTemplateValueTerminalExpression]

def __init__(
self, jsonata_template_value_terminal_expression: JSONataTemplateValueTerminalExpression
):
super().__init__()
self.jsonata_template_value_terminal_expression = jsonata_template_value_terminal_expression

def _eval_body(self, env: Environment) -> None:
self.jsonata_template_value_terminal_expression.eval(env=env)


class RoleArnVar(RoleArn):
variable_sample: Final[VariableSample]

def __init__(self, variable_sample: VariableSample):
self.variable_sample = variable_sample

def _eval_body(self, env: Environment) -> None:
self.variable_sample.eval(env=env)


class RoleArnPath(RoleArnConst):
def _eval_body(self, env: Environment) -> None:
current_output = env.stack[-1]
arn = extract_json(self.value, current_output)
env.stack.append(arn)


class RoleArnContextObject(RoleArnConst):
def _eval_body(self, env: Environment) -> None:
value = extract_json(self.value, env.states.context_object.context_object_data)
env.stack.append(copy.deepcopy(value))


class RoleArnIntrinsicFunction(RoleArnConst):
function: Final[Function]

def __init__(self, value: str) -> None:
super().__init__(value=value)
self.function, _ = IntrinsicParser.parse(value)

def _eval_body(self, env: Environment) -> None:
self.function.eval(env=env)


class Credentials(EvalComponent):
payload_template: Final[PayloadTmpl]
role_arn: Final[RoleArn]

def __init__(self, role_arn: RoleArn):
self.role_arn = role_arn

def __init__(self, payload_template: PayloadTmpl):
self.payload_template = payload_template
@staticmethod
def get_role_arn_from(computed_credentials: ComputedCredentials) -> Optional[str]:
return computed_credentials.get(_CREDENTIALS_ROLE_ARN_KEY)

def _eval_body(self, env: Environment) -> None:
self.payload_template.eval(env=env)
self.role_arn.eval(env=env)
role_arn = env.stack.pop()
computes_credentials: ComputedCredentials = {_CREDENTIALS_ROLE_ARN_KEY: role_arn}
env.stack.append(computes_credentials)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from typing import IO, Any, Final, Optional, Union

from localstack.aws.api.lambda_ import InvocationResponse
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
)
from localstack.services.stepfunctions.asl.eval.environment import Environment
from localstack.services.stepfunctions.asl.utils.boto_client import boto_client_for
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
Expand Down Expand Up @@ -35,8 +38,12 @@ def _from_payload(payload_streaming_body: IO[bytes]) -> Union[json, str]:
return decoded_data


def exec_lambda_function(env: Environment, parameters: dict, region: str, account: str) -> None:
lambda_client = boto_client_for(region=region, account=account, service="lambda")
def exec_lambda_function(
env: Environment, parameters: dict, region: str, account: str, credentials: ComputedCredentials
) -> None:
lambda_client = boto_client_for(
region=region, account=account, service="lambda", credentials=credentials
)

invocation_resp: InvocationResponse = lambda_client.invoke(**parameters)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from localstack.aws.api.stepfunctions import (
HistoryEventExecutionDataDetails,
HistoryEventType,
TaskCredentials,
TaskFailedEventDetails,
TaskScheduledEventDetails,
TaskStartedEventDetails,
Expand All @@ -28,6 +29,10 @@
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
Credentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceRuntimePart,
ServiceResource,
Expand Down Expand Up @@ -231,11 +236,15 @@ def _eval_service_task(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: dict,
task_credentials: ComputedCredentials,
): ...

def _before_eval_execution(
self, env: Environment, resource_runtime_part: ResourceRuntimePart, raw_parameters: dict
self,
env: Environment,
resource_runtime_part: ResourceRuntimePart,
raw_parameters: dict,
task_credentials: TaskCredentials,
) -> None:
parameters_str = to_json_str(raw_parameters)

Expand All @@ -253,6 +262,10 @@ def _before_eval_execution(
self.heartbeat.eval(env=env)
heartbeat_seconds = env.stack.pop()
scheduled_event_details["heartbeatInSeconds"] = heartbeat_seconds
if self.credentials:
scheduled_event_details["taskCredentials"] = TaskCredentials(
roleArn=Credentials.get_role_arn_from(computed_credentials=task_credentials)
)
env.event_manager.add_event(
context=env.event_history_context,
event_type=HistoryEventType.TaskScheduled,
Expand All @@ -274,6 +287,7 @@ def _after_eval_execution(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: ComputedCredentials,
) -> None:
output = env.stack[-1]
self._verify_size_quota(env=env, value=output)
Expand All @@ -295,16 +309,18 @@ def _eval_execution(self, env: Environment) -> None:
resource_runtime_part: ResourceRuntimePart = env.stack.pop()

raw_parameters = self._eval_parameters(env=env)
task_credentials = self._eval_credentials(env=env)

self._before_eval_execution(
env=env, resource_runtime_part=resource_runtime_part, raw_parameters=raw_parameters
env=env,
resource_runtime_part=resource_runtime_part,
raw_parameters=raw_parameters,
task_credentials=task_credentials,
)

normalised_parameters = copy.deepcopy(raw_parameters)
self._normalise_parameters(normalised_parameters)

task_credentials = self._eval_credentials(env=env)

self._eval_service_task(
env=env,
resource_runtime_part=resource_runtime_part,
Expand All @@ -319,4 +335,5 @@ def _eval_execution(self, env: Environment) -> None:
env=env,
resource_runtime_part=resource_runtime_part,
normalised_parameters=normalised_parameters,
task_credentials=task_credentials,
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
FailureEvent,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceCondition,
ResourceRuntimePart,
Expand Down Expand Up @@ -291,8 +294,10 @@ def _eval_service_task(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: dict,
task_credentials: ComputedCredentials,
):
# TODO: add support for task credentials

task_parameters: TaskParameters = select_from_typed_dict(
typed_dict=TaskParameters, obj=normalised_parameters
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from localstack.services.stepfunctions.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.credentials import (
ComputedCredentials,
)
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.service.resource import (
ResourceCondition,
ResourceRuntimePart,
Expand Down Expand Up @@ -125,14 +128,15 @@ def _eval_service_task(
env: Environment,
resource_runtime_part: ResourceRuntimePart,
normalised_parameters: dict,
task_credentials: dict,
task_credentials: ComputedCredentials,
):
service_name = self._get_boto_service_name()
api_action = self._get_boto_service_action()
api_client = boto_client_for(
region=resource_runtime_part.region,
account=resource_runtime_part.account,
service=service_name,
credentials=task_credentials,
)
response = getattr(api_client, api_action)(**normalised_parameters) or dict()
if response:
Expand Down
Loading
Loading