Skip to content

CloudFormation: [POC] Support Update Graph Modeling of Mappings and FindInMap #12432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class ChangeSetTerminal(ChangeSetEntity, abc.ABC): ...


class NodeTemplate(ChangeSetNode):
mappings: Final[NodeMappings]
parameters: Final[NodeParameters]
conditions: Final[NodeConditions]
resources: Final[NodeResources]
Expand All @@ -121,11 +122,13 @@ def __init__(
self,
scope: Scope,
change_type: ChangeType,
mappings: NodeMappings,
parameters: NodeParameters,
conditions: NodeConditions,
resources: NodeResources,
):
super().__init__(scope=scope, change_type=change_type)
self.mappings = mappings
self.parameters = parameters
self.conditions = conditions
self.resources = resources
Expand Down Expand Up @@ -168,6 +171,24 @@ def __init__(self, scope: Scope, change_type: ChangeType, parameters: list[NodeP
self.parameters = parameters


class NodeMapping(ChangeSetNode):
name: Final[str]
bindings: Final[NodeObject]

def __init__(self, scope: Scope, change_type: ChangeType, name: str, bindings: NodeObject):
super().__init__(scope=scope, change_type=change_type)
self.name = name
self.bindings = bindings


class NodeMappings(ChangeSetNode):
mappings: Final[list[NodeMapping]]

def __init__(self, scope: Scope, change_type: ChangeType, mappings: list[NodeMapping]):
super().__init__(scope=scope, change_type=change_type)
self.mappings = mappings


class NodeCondition(ChangeSetNode):
name: Final[str]
body: Final[ChangeSetEntity]
Expand Down Expand Up @@ -300,6 +321,7 @@ def __init__(self, scope: Scope, value: Any):
TypeKey: Final[str] = "Type"
ConditionKey: Final[str] = "Condition"
ConditionsKey: Final[str] = "Conditions"
MappingsKey: Final[str] = "Mappings"
ResourcesKey: Final[str] = "Resources"
PropertiesKey: Final[str] = "Properties"
ParametersKey: Final[str] = "Parameters"
Expand All @@ -309,7 +331,15 @@ def __init__(self, scope: Scope, value: Any):
FnNot: Final[str] = "Fn::Not"
FnGetAttKey: Final[str] = "Fn::GetAtt"
FnEqualsKey: Final[str] = "Fn::Equals"
INTRINSIC_FUNCTIONS: Final[set[str]] = {RefKey, FnIf, FnNot, FnEqualsKey, FnGetAttKey}
FnFindInMapKey: Final[str] = "Fn::FindInMap"
INTRINSIC_FUNCTIONS: Final[set[str]] = {
RefKey,
FnIf,
FnNot,
FnEqualsKey,
FnGetAttKey,
FnFindInMapKey,
}


class ChangeSetModel:
Expand Down Expand Up @@ -455,6 +485,36 @@ def _resolve_intrinsic_function_ref(self, arguments: ChangeSetEntity) -> ChangeT
node_resource = self._retrieve_or_visit_resource(resource_name=logical_id)
return node_resource.change_type

def _resolve_intrinsic_function_fn_find_in_map(self, arguments: ChangeSetEntity) -> ChangeType:
if arguments.change_type != ChangeType.UNCHANGED:
return arguments.change_type
# TODO: validate arguments structure and type.
# TODO: add support for nested functions, here we assume the arguments are string literals.

if not isinstance(arguments, NodeArray) or not arguments.array:
raise RuntimeError()
argument_mapping_name = arguments.array[0]
if not isinstance(argument_mapping_name, TerminalValue):
raise NotImplementedError()
argument_top_level_key = arguments.array[1]
if not isinstance(argument_top_level_key, TerminalValue):
raise NotImplementedError()
argument_second_level_key = arguments.array[2]
if not isinstance(argument_second_level_key, TerminalValue):
raise NotImplementedError()
mapping_name = argument_mapping_name.value
top_level_key = argument_top_level_key.value
second_level_key = argument_second_level_key.value

node_mapping = self._retrieve_mapping(mapping_name=mapping_name)
# TODO: a lookup would be beneficial in this scenario too;
# consider implications downstream and for replication.
top_level_object = node_mapping.bindings.bindings.get(top_level_key)
if not isinstance(top_level_object, NodeObject):
raise RuntimeError()
target_map_value = top_level_object.bindings.get(second_level_key)
return target_map_value.change_type

def _resolve_intrinsic_function_fn_if(self, arguments: ChangeSetEntity) -> ChangeType:
# TODO: validate arguments structure and type.
if not isinstance(arguments, NodeArray) or not arguments.array:
Expand Down Expand Up @@ -705,6 +765,36 @@ def _visit_resources(
change_type = change_type.for_child(resource.change_type)
return NodeResources(scope=scope, change_type=change_type, resources=resources)

def _visit_mapping(
self, scope: Scope, name: str, before_mapping: Maybe[dict], after_mapping: Maybe[dict]
) -> NodeMapping:
bindings = self._visit_object(
scope=scope, before_object=before_mapping, after_object=after_mapping
)
return NodeMapping(
scope=scope, change_type=bindings.change_type, name=name, bindings=bindings
)

def _visit_mappings(
self, scope: Scope, before_mappings: Maybe[dict], after_mappings: Maybe[dict]
) -> NodeMappings:
change_type = ChangeType.UNCHANGED
mappings: list[NodeMapping] = list()
mapping_names = self._safe_keys_of(before_mappings, after_mappings)
for mapping_name in mapping_names:
scope_mapping, (before_mapping, after_mapping) = self._safe_access_in(
scope, mapping_name, before_mappings, after_mappings
)
mapping = self._visit_mapping(
scope=scope,
name=mapping_name,
before_mapping=before_mapping,
after_mapping=after_mapping,
)
mappings.append(mapping)
change_type = change_type.for_child(mapping.change_type)
return NodeMappings(scope=scope, change_type=change_type, mappings=mappings)

def _visit_dynamic_parameter(self, parameter_name: str) -> ChangeSetEntity:
scope = Scope("Dynamic").open_scope("Parameters")
scope_parameter, (before_parameter, after_parameter) = self._safe_access_in(
Expand Down Expand Up @@ -845,6 +935,14 @@ def _visit_conditions(
def _model(self, before_template: Maybe[dict], after_template: Maybe[dict]) -> NodeTemplate:
root_scope = Scope()
# TODO: visit other child types

mappings_scope, (before_mappings, after_mappings) = self._safe_access_in(
root_scope, MappingsKey, before_template, after_template
)
mappings = self._visit_mappings(
scope=mappings_scope, before_mappings=before_mappings, after_mappings=after_mappings
)

parameters_scope, (before_parameters, after_parameters) = self._safe_access_in(
root_scope, ParametersKey, before_template, after_template
)
Expand Down Expand Up @@ -876,6 +974,7 @@ def _model(self, before_template: Maybe[dict], after_template: Maybe[dict]) -> N
return NodeTemplate(
scope=root_scope,
change_type=resources.change_type,
mappings=mappings,
parameters=parameters,
conditions=conditions,
resources=resources,
Expand Down Expand Up @@ -919,6 +1018,23 @@ def _retrieve_parameter_if_exists(self, parameter_name: str) -> Optional[NodePar
return node_parameter
return None

def _retrieve_mapping(self, mapping_name) -> NodeMapping:
# TODO: add caching mechanism, and raise appropriate error if missing.
scope_mappings, (before_mappings, after_mappings) = self._safe_access_in(
Scope(), MappingsKey, self._before_template, self._after_template
)
before_mappings = before_mappings or dict()
after_mappings = after_mappings or dict()
if mapping_name in before_mappings or mapping_name in after_mappings:
scope_mapping, (before_mapping, after_mapping) = self._safe_access_in(
scope_mappings, mapping_name, before_mappings, after_mappings
)
node_mapping = self._visit_mapping(
scope_mapping, mapping_name, before_mapping, after_mapping
)
return node_mapping
raise RuntimeError()

def _retrieve_or_visit_resource(self, resource_name: str) -> NodeResource:
resources_scope, (before_resources, after_resources) = self._safe_access_in(
Scope(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
NodeCondition,
NodeDivergence,
NodeIntrinsicFunction,
NodeMapping,
NodeObject,
NodeParameter,
NodeProperties,
Expand Down Expand Up @@ -74,6 +75,15 @@ def _get_node_property_for(property_name: str, node_resource: NodeResource) -> N
# TODO
raise RuntimeError()

def _get_node_mapping(self, map_name: str) -> NodeMapping:
mappings: list[NodeMapping] = self._node_template.mappings.mappings
# TODO: another scenarios suggesting property lookups might be preferable.
for mapping in mappings:
if mapping.name == map_name:
return mapping
# TODO
raise RuntimeError()

def _get_node_parameter_if_exists(self, parameter_name: str) -> Optional[NodeParameter]:
parameters: list[NodeParameter] = self._node_template.parameters.parameters
# TODO: another scenarios suggesting property lookups might be preferable.
Expand Down Expand Up @@ -109,6 +119,16 @@ def _resolve_reference(self, logica_id: str) -> DescribeUnit:
resource_unit = DescribeUnit(before_context=limitation_str, after_context=limitation_str)
return resource_unit

def _resolve_mapping(self, map_name: str, top_level_key: str, second_level_key) -> DescribeUnit:
# TODO: add support for nested intrinsic functions, and KNOWN AFTER APPLY logical ids.
node_mapping: NodeMapping = self._get_node_mapping(map_name=map_name)
top_level_value = node_mapping.bindings.bindings.get(top_level_key)
if not isinstance(top_level_value, NodeObject):
raise RuntimeError()
second_level_value = top_level_value.bindings.get(second_level_key)
mapping_value_unit = self.visit(second_level_value)
return mapping_value_unit

def _resolve_reference_binding(
self, before_logical_id: str, after_logical_id: str
) -> DescribeUnit:
Expand Down Expand Up @@ -281,8 +301,31 @@ def visit_node_intrinsic_function_fn_not(
# Implicit change type computation.
return DescribeUnit(before_context=before_context, after_context=after_context)

def visit_node_intrinsic_function_fn_find_in_map(
self, node_intrinsic_function: NodeIntrinsicFunction
) -> DescribeUnit:
# TODO: check for KNOWN AFTER APPLY values for logical ids coming from intrinsic functions as arguments.
# TODO: add type checking/validation for result unit?
arguments_unit = self.visit(node_intrinsic_function.arguments)
before_arguments = arguments_unit.before_context
after_arguments = arguments_unit.after_context
if before_arguments:
before_value_unit = self._resolve_mapping(*before_arguments)
before_context = before_value_unit.before_context
else:
before_context = None
if after_arguments:
after_value_unit = self._resolve_mapping(*after_arguments)
after_context = after_value_unit.after_context
else:
after_context = None
return DescribeUnit(before_context=before_context, after_context=after_context)

def visit_node_mapping(self, node_mapping: NodeMapping) -> DescribeUnit:
bindings_unit = self.visit(node_mapping.bindings)
return bindings_unit

def visit_node_parameter(self, node_parameter: NodeParameter) -> DescribeUnit:
# TODO: add caching for these operation, parameters may be referenced more than once.
# TODO: add support for default value sampling
dynamic_value = node_parameter.dynamic_value
describe_unit = self.visit(dynamic_value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
NodeConditions,
NodeDivergence,
NodeIntrinsicFunction,
NodeMapping,
NodeMappings,
NodeObject,
NodeParameter,
NodeParameters,
Expand Down Expand Up @@ -45,6 +47,12 @@ def visit_children(self, change_set_entity: ChangeSetEntity):
def visit_node_template(self, node_template: NodeTemplate):
self.visit_children(node_template)

def visit_node_mapping(self, node_mapping: NodeMapping):
self.visit_children(node_mapping)

def visit_node_mappings(self, node_mappings: NodeMappings):
self.visit_children(node_mappings)

def visit_node_parameters(self, node_parameters: NodeParameters):
self.visit_children(node_parameters)

Expand Down Expand Up @@ -94,6 +102,11 @@ def visit_node_intrinsic_function_fn_if(self, node_intrinsic_function: NodeIntri
def visit_node_intrinsic_function_fn_not(self, node_intrinsic_function: NodeIntrinsicFunction):
self.visit_children(node_intrinsic_function)

def visit_node_intrinsic_function_fn_find_in_map(
self, node_intrinsic_function: NodeIntrinsicFunction
):
self.visit_children(node_intrinsic_function)

def visit_node_intrinsic_function_ref(self, node_intrinsic_function: NodeIntrinsicFunction):
self.visit_children(node_intrinsic_function)

Expand Down
Loading
Loading