From 88f6b2cdb086058b62d231a17af7e6f22c5a3199 Mon Sep 17 00:00:00 2001 From: Joseph Tindall <51231103+JoeyT1994@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:32:12 -0400 Subject: [PATCH] Update gauge_walk functionality and fix bug in `VidalITensorNetwork` (#222) --- Project.toml | 2 +- src/ITensorNetworks.jl | 4 +- src/abstractitensornetwork.jl | 38 +++++++++++++++---- src/apply.jl | 2 +- src/caches/abstractbeliefpropagationcache.jl | 3 -- src/formnetworks/abstractformnetwork.jl | 4 ++ .../abstracttreetensornetwork.jl | 8 +--- 7 files changed, 40 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index 5509be2a..3347daa6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.12.0" +version = "0.12.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index 57d13e88..bcdae764 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -25,11 +25,11 @@ include("specialitensornetworks.jl") include("boundarymps.jl") include("partitioneditensornetwork.jl") include("edge_sequences.jl") +include("caches/abstractbeliefpropagationcache.jl") +include("caches/beliefpropagationcache.jl") include("formnetworks/abstractformnetwork.jl") include("formnetworks/bilinearformnetwork.jl") include("formnetworks/quadraticformnetwork.jl") -include("caches/abstractbeliefpropagationcache.jl") -include("caches/beliefpropagationcache.jl") include("contraction_tree_to_graph.jl") include("gauging.jl") include("utils.jl") diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index 73c915a1..6c80e36e 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -623,18 +623,42 @@ function gauge_walk( return gauge_walk(alg, tn, edgetype(tn).(edges); kwargs...) end +function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region) + return tree_gauge(alg, ψ, [region]) +end + +#Get the path that moves the gauge from a to b on a tree +#TODO: Move to NamedGraphs +function edge_sequence_between_regions(g::AbstractGraph, region_a::Vector, region_b::Vector) + issetequal(region_a, region_b) && return edgetype(g)[] + st = steiner_tree(g, union(region_a, region_b)) + path = post_order_dfs_edges(st, first(region_b)) + path = filter(e -> !((src(e) ∈ region_b) && (dst(e) ∈ region_b)), path) + return path +end + +# Gauge a ITensorNetwork from cur_region towards new_region, treating +# the network as a tree spanned by a spanning tree. +function tree_gauge( + alg::Algorithm, + ψ::AbstractITensorNetwork, + cur_region::Vector, + new_region::Vector; + kwargs..., +) + es = edge_sequence_between_regions(ψ, cur_region, new_region) + ψ = gauge_walk(alg, ψ, es; kwargs...) + return ψ +end + # Gauge a ITensorNetwork towards a region, treating # the network as a tree spanned by a spanning tree. function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region::Vector) - region_center = - length(region) != 1 ? first(center(steiner_tree(ψ, region))) : only(region) - path = post_order_dfs_edges(bfs_tree(ψ, region_center), region_center) - path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) - return gauge_walk(alg, ψ, path) + return tree_gauge(alg, ψ, collect(vertices(ψ)), region) end -function tree_gauge(alg::Algorithm, ψ::AbstractITensorNetwork, region) - return tree_gauge(alg, ψ, [region]) +function tree_orthogonalize(ψ::AbstractITensorNetwork, cur_region, new_region; kwargs...) + return tree_gauge(Algorithm("orthogonalize"), ψ, cur_region, new_region; kwargs...) end function tree_orthogonalize(ψ::AbstractITensorNetwork, region; kwargs...) diff --git a/src/apply.jl b/src/apply.jl index 6a55f45f..3d2cfc0e 100644 --- a/src/apply.jl +++ b/src/apply.jl @@ -378,7 +378,7 @@ function ITensors.apply(o, ψ::VidalITensorNetwork; normalize=false, apply_kwarg else updated_ψ = apply(o, updated_ψ; normalize) - return VidalITensorNetwork(ψ, updated_bond_tensors) + return VidalITensorNetwork(updated_ψ, updated_bond_tensors) end end diff --git a/src/caches/abstractbeliefpropagationcache.jl b/src/caches/abstractbeliefpropagationcache.jl index 01c90c04..b5163d3b 100644 --- a/src/caches/abstractbeliefpropagationcache.jl +++ b/src/caches/abstractbeliefpropagationcache.jl @@ -40,9 +40,6 @@ default_messages(ptn::PartitionedGraph) = Dictionary() return default_bp_maxiter(undirected_graph(underlying_graph(g))) end default_partitioned_vertices(ψ::AbstractITensorNetwork) = group(v -> v, vertices(ψ)) -function default_partitioned_vertices(f::AbstractFormNetwork) - return group(v -> original_state_vertex(f, v), vertices(f)) -end partitioned_tensornetwork(bpc::AbstractBeliefPropagationCache) = not_implemented() messages(bpc::AbstractBeliefPropagationCache) = not_implemented() diff --git a/src/formnetworks/abstractformnetwork.jl b/src/formnetworks/abstractformnetwork.jl index 66776ec8..07ddffdf 100644 --- a/src/formnetworks/abstractformnetwork.jl +++ b/src/formnetworks/abstractformnetwork.jl @@ -80,3 +80,7 @@ operator_vertex(f::AbstractFormNetwork, v) = operator_vertex_map(f)(v) bra_vertex(f::AbstractFormNetwork, v) = bra_vertex_map(f)(v) ket_vertex(f::AbstractFormNetwork, v) = ket_vertex_map(f)(v) original_state_vertex(f::AbstractFormNetwork, v) = inv_vertex_map(f)(v) + +function default_partitioned_vertices(f::AbstractFormNetwork) + return group(v -> original_state_vertex(f, v), vertices(f)) +end diff --git a/src/treetensornetworks/abstracttreetensornetwork.jl b/src/treetensornetworks/abstracttreetensornetwork.jl index f6c8f49f..f2f82ef1 100644 --- a/src/treetensornetworks/abstracttreetensornetwork.jl +++ b/src/treetensornetworks/abstracttreetensornetwork.jl @@ -36,13 +36,7 @@ function set_ortho_region(tn::AbstractTTN, new_region) end function gauge(alg::Algorithm, ttn::AbstractTTN, region::Vector; kwargs...) - issetequal(region, ortho_region(ttn)) && return ttn - st = steiner_tree(ttn, union(region, ortho_region(ttn))) - path = post_order_dfs_edges(st, first(region)) - path = filter(e -> !((src(e) ∈ region) && (dst(e) ∈ region)), path) - if !isempty(path) - ttn = typeof(ttn)(gauge_walk(alg, ITensorNetwork(ttn), path; kwargs...)) - end + ttn = tree_gauge(alg, ttn, collect(ortho_region(ttn)), region; kwargs...) return set_ortho_region(ttn, region) end