Skip to content

Commit 1790318

Browse files
authored
StepFunctions: Support for Output Blocks in Choice Rules, Improvments to JSONata Choice Defaults (#12075)
1 parent 0a5cd13 commit 1790318

File tree

17 files changed

+1283
-734
lines changed

17 files changed

+1283
-734
lines changed

localstack-core/localstack/services/stepfunctions/asl/antlr/ASLParser.g4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ comparison_variable_stmt:
254254
| comparison_func
255255
| next_decl
256256
| assign_decl
257+
| output_decl
257258
| comment_decl
258259
;
259260

localstack-core/localstack/services/stepfunctions/asl/antlr/runtime/ASLParser.py

Lines changed: 701 additions & 691 deletions
Large diffs are not rendered by default.

localstack-core/localstack/services/stepfunctions/asl/component/state/state.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
TaskFailedEventDetails,
1717
)
1818
from localstack.services.stepfunctions.asl.component.common.assign.assign_decl import AssignDecl
19-
from localstack.services.stepfunctions.asl.component.common.catch.catch_outcome import CatchOutcome
2019
from localstack.services.stepfunctions.asl.component.common.comment import Comment
2120
from localstack.services.stepfunctions.asl.component.common.error_name.failure_event import (
2221
FailureEvent,
@@ -187,9 +186,26 @@ def _verify_size_quota(self, env: Environment, value: Union[str, json]) -> None:
187186
)
188187
)
189188

189+
def _eval_state_input(self, env: Environment) -> None:
190+
# Filter the input onto the stack.
191+
if self.input_path:
192+
self.input_path.eval(env)
193+
else:
194+
env.stack.append(env.states.get_input())
195+
190196
@abc.abstractmethod
191197
def _eval_state(self, env: Environment) -> None: ...
192198

199+
def _eval_state_output(self, env: Environment) -> None:
200+
# Process output value as next state input.
201+
if self.output_path:
202+
self.output_path.eval(env=env)
203+
elif self.output:
204+
self.output.eval(env=env)
205+
else:
206+
current_output = env.stack.pop()
207+
env.states.reset(input_value=current_output)
208+
193209
def _eval_body(self, env: Environment) -> None:
194210
env.event_manager.add_event(
195211
context=env.event_history_context,
@@ -198,18 +214,12 @@ def _eval_body(self, env: Environment) -> None:
198214
stateEnteredEventDetails=self._get_state_entered_event_details(env=env)
199215
),
200216
)
201-
202217
env.states.context_object.context_object_data["State"] = StateData(
203218
EnteredTime=datetime.datetime.now(tz=datetime.timezone.utc).isoformat(), Name=self.name
204219
)
205220

206-
# Filter the input onto the stack.
207-
if self.input_path:
208-
self.input_path.eval(env)
209-
else:
210-
env.stack.append(env.states.get_input())
221+
self._eval_state_input(env=env)
211222

212-
# Exec the state's logic.
213223
try:
214224
self._eval_state(env)
215225
except NoSuchJsonPathError as no_such_json_path_error:
@@ -234,26 +244,11 @@ def _eval_body(self, env: Environment) -> None:
234244
if not isinstance(env.program_state(), ProgramRunning):
235245
return
236246

237-
# Obtain a reference to the state output.
238-
output = env.stack[-1]
239-
240-
# CatcherOutputs (i.e. outputs of Catch blocks) are never subjects of output normalisers,
241-
# the entire value is instead passed by value as input to the next state, or program output.
242-
if not isinstance(output, CatchOutcome):
243-
# Ensure the state's output is within state size quotas.
244-
self._verify_size_quota(env=env, value=output)
245-
246-
# Process output value as next state input.
247-
if self.output_path:
248-
self.output_path.eval(env=env)
249-
elif self.output:
250-
self.output.eval(env=env)
251-
else:
252-
current_output = env.stack.pop()
253-
env.states.reset(input_value=current_output)
254-
255-
# Set next state or halt (end).
256-
self._set_next(env)
247+
self._eval_state_output(env=env)
248+
249+
self._verify_size_quota(env=env, value=env.states.get_input())
250+
251+
self._set_next(env)
257252

258253
if self.state_exited_event_type is not None:
259254
env.event_manager.add_event(

localstack-core/localstack/services/stepfunctions/asl/component/state/state_choice/choice_rule.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from localstack.services.stepfunctions.asl.component.common.assign.assign_decl import AssignDecl
44
from localstack.services.stepfunctions.asl.component.common.comment import Comment
55
from localstack.services.stepfunctions.asl.component.common.flow.next import Next
6+
from localstack.services.stepfunctions.asl.component.common.outputdecl import Output
67
from localstack.services.stepfunctions.asl.component.eval_component import EvalComponent
78
from localstack.services.stepfunctions.asl.component.state.state_choice.comparison.comparison_type import (
89
Comparison,
@@ -15,22 +16,28 @@ class ChoiceRule(EvalComponent):
1516
next_stmt: Final[Optional[Next]]
1617
comment: Final[Optional[Comment]]
1718
assign: Final[Optional[AssignDecl]]
19+
output: Final[Optional[Output]]
1820

1921
def __init__(
2022
self,
2123
comparison: Optional[Comparison],
2224
next_stmt: Optional[Next],
2325
comment: Optional[Comment],
2426
assign: Optional[AssignDecl],
27+
output: Optional[Output],
2528
):
2629
self.comparison = comparison
2730
self.next_stmt = next_stmt
2831
self.comment = comment
2932
self.assign = assign
33+
self.output = output
3034

3135
def _eval_body(self, env: Environment) -> None:
3236
self.comparison.eval(env)
37+
is_condition_true: bool = env.stack[-1]
38+
if not is_condition_true:
39+
return
3340
if self.assign:
34-
is_condition_true: bool = env.stack[-1]
35-
if is_condition_true:
36-
self.assign.eval(env=env)
41+
self.assign.eval(env=env)
42+
if self.output:
43+
self.output.eval(env=env)

localstack-core/localstack/services/stepfunctions/asl/component/state/state_choice/state_choice.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
class StateChoice(CommonStateField):
1818
choices_decl: ChoicesDecl
1919
default_state: Optional[DefaultDecl]
20-
_next_state_name: Optional[str]
2120

2221
def __init__(self):
2322
super(StateChoice, self).__init__(
@@ -31,6 +30,7 @@ def from_state_props(self, state_props: StateProps) -> None:
3130
super(StateChoice, self).from_state_props(state_props)
3231
self.choices_decl = state_props.get(ChoicesDecl)
3332
self.default_state = state_props.get(DefaultDecl)
33+
3434
if state_props.get(Next) or state_props.get(End):
3535
raise ValueError(
3636
"Choice states don't support the End field. "
@@ -39,14 +39,9 @@ def from_state_props(self, state_props: StateProps) -> None:
3939
)
4040

4141
def _set_next(self, env: Environment) -> None:
42-
if self._next_state_name is None:
43-
raise RuntimeError(f"No Next option from state: '{self}'.")
44-
env.next_state_name = self._next_state_name
42+
pass
4543

4644
def _eval_state(self, env: Environment) -> None:
47-
if self.default_state:
48-
self._next_state_name = self.default_state.state_name
49-
5045
for rule in self.choices_decl.rules:
5146
rule.eval(env)
5247
res = env.stack.pop()
@@ -55,8 +50,29 @@ def _eval_state(self, env: Environment) -> None:
5550
raise RuntimeError(
5651
f"Missing Next definition for state_choice rule '{rule}' in choices '{self}'."
5752
)
58-
self._next_state_name = rule.next_stmt.name
59-
break
53+
env.stack.append(rule.next_stmt.name)
54+
return
55+
56+
if self.default_state is None:
57+
raise RuntimeError("No branching option reached in state %s", self.name)
58+
env.stack.append(self.default_state.state_name)
59+
60+
def _eval_state_output(self, env: Environment) -> None:
61+
next_state_name: str = env.stack.pop()
62+
63+
# No choice rule matched: the default state is evaluated.
64+
if self.default_state and self.default_state.state_name == next_state_name:
65+
if self.assign_decl:
66+
self.assign_decl.eval(env=env)
67+
if self.output:
68+
self.output.eval(env=env)
69+
70+
# Handle legacy output sequences if in JsonPath mode.
71+
if self._is_language_query_jsonpath():
72+
if self.output_path:
73+
self.output_path.eval(env=env)
74+
else:
75+
current_output = env.stack.pop()
76+
env.states.reset(input_value=current_output)
6077

61-
if self.assign_decl:
62-
self.assign_decl.eval(env=env)
78+
env.next_state_name = next_state_name

localstack-core/localstack/services/stepfunctions/asl/component/state/state_execution/execute_state.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,11 @@ def _eval_state(self, env: Environment) -> None:
267267
break
268268

269269
self._handle_uncaught(env=env, failure_event=failure_event)
270+
271+
def _eval_state_output(self, env: Environment) -> None:
272+
# Obtain a reference to the state output.
273+
output = env.stack[-1]
274+
# CatcherOutputs (i.e. outputs of Catch blocks) are never subjects of output normalisers,
275+
# the entire value is instead passed by value as input to the next state, or program output.
276+
if not isinstance(output, CatchOutcome):
277+
super()._eval_state_output(env=env)

localstack-core/localstack/services/stepfunctions/asl/parse/preprocessor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@ def visitChoice_rule_comparison_composite(
697697
next_stmt=composite_stmts.get(Next),
698698
comment=composite_stmts.get(Comment),
699699
assign=composite_stmts.get(AssignDecl),
700+
output=composite_stmts.get(Output),
700701
)
701702

702703
def visitChoice_rule_comparison_variable(
@@ -726,7 +727,8 @@ def visitChoice_rule_comparison_variable(
726727
comparison=comparison_variable,
727728
next_stmt=comparison_stmts.get(Next),
728729
comment=comparison_stmts.get(Comment),
729-
assign=comparison_stmts.get(AssignDecl),
730+
assign=None,
731+
output=None,
730732
)
731733
else:
732734
condition: Comparison = comparison_stmts.get(
@@ -740,6 +742,7 @@ def visitChoice_rule_comparison_variable(
740742
next_stmt=comparison_stmts.get(Next),
741743
comment=comparison_stmts.get(Comment),
742744
assign=comparison_stmts.get(AssignDecl),
745+
output=comparison_stmts.get(Output),
743746
)
744747

745748
def visitChoices_decl(self, ctx: ASLParser.Choices_declContext) -> ChoicesDecl:

tests/aws/services/stepfunctions/templates/assign/assign_templates.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,7 @@ class AssignTemplate(TemplateLoader):
128128
MAP_STATE_REFERENCE_IN_ITEM_SELECTOR = os.path.join(
129129
_THIS_FOLDER, "statemachines/map_state_reference_in_item_selector.json5"
130130
)
131+
132+
CHOICE_CONDITION_JSONATA: Final[str] = os.path.join(
133+
_THIS_FOLDER, "statemachines/choice_condition_jsonata.json5"
134+
)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"QueryLanguage": "JSONata",
3+
"StartAt": "ChoiceState",
4+
"States": {
5+
"ChoiceState": {
6+
"Type": "Choice",
7+
"Choices": [
8+
{
9+
"Condition": "{% $states.input.condition %}",
10+
"Next": "ConditionTrue",
11+
"Assign": {
12+
"Assignment": "Condition assignment"
13+
}
14+
}
15+
],
16+
"Default": "DefaultState",
17+
"Assign": {
18+
"Assignment": "Default Assignment"
19+
}
20+
},
21+
"ConditionTrue": {
22+
"Type": "Pass",
23+
"End": true
24+
},
25+
"DefaultState": {
26+
"Type": "Fail",
27+
"Cause": "Condition is false"
28+
}
29+
}
30+
}

tests/aws/services/stepfunctions/templates/outputdecl/output_templates.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ class OutputTemplates(TemplateLoader):
1414
BASE_LAMBDA = os.path.join(_THIS_FOLDER, "statemachines/base_lambda.json5")
1515
BASE_TASK_LAMBDA = os.path.join(_THIS_FOLDER, "statemachines/base_task_lambda.json5")
1616
BASE_OUTPUT_ANY = os.path.join(_THIS_FOLDER, "statemachines/base_output_any.json5")
17+
CHOICE_CONDITION_JSONATA = os.path.join(
18+
_THIS_FOLDER, "statemachines/choice_condition_jsonata.json5"
19+
)

0 commit comments

Comments
 (0)