# beliefpropagationcache.jl
using Graphs: IsDirected
using SplitApplyCombine: group
using LinearAlgebra: diag, dot
using ITensors: dir
using NamedGraphs.PartitionedGraphs:
PartitionedGraphs,
PartitionedGraph,
PartitionVertex,
boundary_partitionedges,
partitionvertices,
partitionedges,
unpartitioned_graph
using SimpleTraits: SimpleTraits, Not, @traitfn
using NDTensors: NDTensors
"""
    default_cache_construction_kwargs(alg::Algorithm"bp", ψ::AbstractITensorNetwork)

Return the default cache-construction keyword arguments for the `"bp"` algorithm and the
tensor network `ψ`: a named tuple whose only entry, `partitioned_vertices`, is set to
`default_partitioned_vertices(ψ)`.
"""
function default_cache_construction_kwargs(::Algorithm"bp", ψ::AbstractITensorNetwork)
    partitioned_vertices = default_partitioned_vertices(ψ)
    return (; partitioned_vertices)
end
"""
    default_cache_construction_kwargs(alg::Algorithm"bp", pg::PartitionedGraph)

Return an empty named tuple: no default construction keyword arguments are supplied when
the input is a `PartitionedGraph`.
"""
default_cache_construction_kwargs(::Algorithm"bp", ::PartitionedGraph) = NamedTuple()
"""
    BeliefPropagationCache{Network,Messages} <: AbstractBeliefPropagationCache

Immutable container pairing a partitioned tensor network with a collection of messages.

# Fields
- `partitioned_tensornetwork::Network`: the partitioned network held by the cache.
- `messages::Messages`: the messages held by the cache.
"""
struct BeliefPropagationCache{Network,Messages} <: AbstractBeliefPropagationCache
    partitioned_tensornetwork::Network
    messages::Messages
end
#Constructors...
"""
    BeliefPropagationCache(ptn::PartitionedGraph; messages=default_messages(ptn))

Build a cache around the partitioned graph `ptn`.

# Keywords
- `messages`: the messages to store in the cache; defaults to `default_messages(ptn)`.
"""
BeliefPropagationCache(ptn::PartitionedGraph; messages=default_messages(ptn)) =
    BeliefPropagationCache(ptn, messages)
"""
    BeliefPropagationCache(tn::AbstractITensorNetwork, partitioned_vertices; kwargs...)

Partition `tn` with `PartitionedGraph(tn, partitioned_vertices)` and build a cache from
the result. All `kwargs` are forwarded to `BeliefPropagationCache` called on that
partitioned graph.
"""
function BeliefPropagationCache(tn::AbstractITensorNetwork, partitioned_vertices; kwargs...)
    return BeliefPropagationCache(PartitionedGraph(tn, partitioned_vertices); kwargs...)
end
"""
    BeliefPropagationCache(tn::AbstractITensorNetwork; partitioned_vertices, kwargs...)

Build a cache from the tensor network `tn`, delegating to the two-argument method
`BeliefPropagationCache(tn, partitioned_vertices; kwargs...)`.

# Keywords
- `partitioned_vertices`: partitioning passed on positionally; defaults to
  `default_partitioned_vertices(tn)`.
- `kwargs...`: any remaining keywords, forwarded unchanged.
"""
function BeliefPropagationCache(
    network::AbstractITensorNetwork;
    partitioned_vertices=default_partitioned_vertices(network),
    ctor_kwargs...,
)
    return BeliefPropagationCache(network, partitioned_vertices; ctor_kwargs...)
end
"""
    cache(alg::Algorithm"bp", tn; kwargs...)

Construct a `BeliefPropagationCache` from `tn` for the `"bp"` algorithm, forwarding all
`kwargs` to the constructor.
"""
cache(::Algorithm"bp", tn; kwargs...) = BeliefPropagationCache(tn; kwargs...)
"""
    default_cache_update_kwargs(alg::Algorithm"bp")

Return the default cache-update keyword arguments for the `"bp"` algorithm, namely
`maxiter=25` and `tol=1e-8`.
"""
function default_cache_update_kwargs(::Algorithm"bp")
    return (; maxiter=25, tol=1e-8)
end
"""
    partitioned_tensornetwork(bp_cache::BeliefPropagationCache)

Return the `partitioned_tensornetwork` field of `bp_cache`.
"""
partitioned_tensornetwork(bp_cache::BeliefPropagationCache) =
    bp_cache.partitioned_tensornetwork
"""
    messages(bp_cache::BeliefPropagationCache)

Return the `messages` field of `bp_cache`.
"""
function messages(bp_cache::BeliefPropagationCache)
    return bp_cache.messages
end
"""
    default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)

Return the default message for `edge`, built by the `default_message` method that takes
the cache's `scalartype` and the link indices of `edge` in the cache.
"""
function default_message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
    element_type = scalartype(bp_cache)
    link_indices = linkinds(bp_cache, edge)
    return default_message(element_type, link_indices)
end
"""
    Base.copy(bp_cache::BeliefPropagationCache)

Return a new `BeliefPropagationCache` whose partitioned tensor network and messages are
each obtained by calling `copy` on the corresponding part of `bp_cache`.
"""
function Base.copy(bp_cache::BeliefPropagationCache)
    ptn_copy = copy(partitioned_tensornetwork(bp_cache))
    messages_copy = copy(messages(bp_cache))
    return BeliefPropagationCache(ptn_copy, messages_copy)
end
"""
    default_message_update_alg(bp_cache::BeliefPropagationCache)

Return the name of the default message-update algorithm for this cache type, `"bp"`.
"""
function default_message_update_alg(::BeliefPropagationCache)
    return "bp"
end
"""
    default_bp_maxiter(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)

Return `default_bp_maxiter` evaluated on `partitioned_graph(bp_cache)`.
"""
function default_bp_maxiter(::Algorithm"bp", bp_cache::BeliefPropagationCache)
    pg = partitioned_graph(bp_cache)
    return default_bp_maxiter(pg)
end
"""
    default_edge_sequence(alg::Algorithm"bp", bp_cache::BeliefPropagationCache)

Return `default_edge_sequence` evaluated on the cache's partitioned tensor network.
"""
function default_edge_sequence(::Algorithm"bp", bp_cache::BeliefPropagationCache)
    ptn = partitioned_tensornetwork(bp_cache)
    return default_edge_sequence(ptn)
end
"""
    default_message_update_kwargs(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache)

Return an empty named tuple: the `"bp"` algorithm supplies no default message-update
keyword arguments here.
"""
default_message_update_kwargs(::Algorithm"bp", ::AbstractBeliefPropagationCache) = (;)
"""
    partitions(bpc::BeliefPropagationCache)

Return `partitionvertices` of the cache's partitioned tensor network.
"""
function partitions(bpc::BeliefPropagationCache)
    return partitionvertices(partitioned_tensornetwork(bpc))
end
"""
    partitionpairs(bpc::BeliefPropagationCache)

Return `partitionedges` of the cache's partitioned tensor network.
"""
function partitionpairs(bpc::BeliefPropagationCache)
    return partitionedges(partitioned_tensornetwork(bpc))
end
"""
    set_messages(cache::BeliefPropagationCache, messages)

Return a new `BeliefPropagationCache` that reuses the partitioned tensor network of
`cache` (not copied) and stores `messages` in place of the existing ones. `cache` itself
is left unmodified.
"""
function set_messages(cache::BeliefPropagationCache, messages)
    ptn = partitioned_tensornetwork(cache)
    return BeliefPropagationCache(ptn, messages)
end
"""
    environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...)

Return the environment of the vertices `verts` as a single vector of tensors.

The result concatenates, in this order:
1. the messages incoming to the partitions that contain `verts` (with `kwargs` forwarded
   to `incoming_messages`), and
2. the factors on every vertex of those partitions that is not itself in `verts`.
"""
function environment(bpc::BeliefPropagationCache, verts::Vector; kwargs...)
    pvs = partitionvertices(bpc, verts)
    incoming = incoming_messages(bpc, pvs; kwargs...)
    # Vertices sharing a partition with `verts` but not part of the requested region.
    remaining_verts = setdiff(vertices(bpc, pvs), verts)
    remaining_factors = factors(bpc, remaining_verts)
    return vcat(incoming, remaining_factors)
end
"""
    region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)

Contract the messages incoming to partition `pv` together with the factors in `pv`, using
an `"optimal"` contraction sequence, and return the single element of the result
(extracted with `[]`).
"""
function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)
    tensors = vcat(incoming_messages(bp_cache, [pv]), factors(bp_cache, pv))
    seq = contraction_sequence(tensors; alg="optimal")
    return contract(tensors; sequence=seq)[]
end
"""
    region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)

Contract the message on partition edge `pe` with the message on the reversed edge, using
an `"optimal"` contraction sequence, and return the single element of the result
(extracted with `[]`).
"""
function region_scalar(bp_cache::BeliefPropagationCache, pe::PartitionEdge)
    forward_message = message(bp_cache, pe)
    backward_message = message(bp_cache, reverse(pe))
    tensors = vcat(forward_message, backward_message)
    seq = contraction_sequence(tensors; alg="optimal")
    return contract(tensors; sequence=seq)[]
end