Skip to content

Commit d83cae8

Browse files
William FisetWilliam Fiset
William Fiset
authored and
William Fiset
committed
Segment tree work
1 parent ae13588 commit d83cae8

File tree

2 files changed

+300
-66
lines changed

2 files changed

+300
-66
lines changed

src/main/java/com/williamfiset/algorithms/datastructures/segmenttree/RangeQueryPointUpdateSegmentTree.java

Lines changed: 131 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,68 +13,102 @@
1313
*/
1414
package com.williamfiset.algorithms.datastructures.segmenttree;
1515

16+
import java.util.Arrays;
17+
import java.util.function.BinaryOperator;
18+
1619
public class RangeQueryPointUpdateSegmentTree {
17-
// TODO(william): make the members of this class private
1820

19-
// Tree segment values.
20-
Integer[] t;
21+
// The type of segment combination function to use
22+
public static enum Operation {
23+
SUM,
24+
MIN,
25+
MAX
26+
}
2127

2228
// The number of values in the original input values array.
23-
int n;
29+
private int n;
30+
31+
private long[] t;
2432

25-
// The size of the segment tree `t`
26-
// NOTE: the size is not necessarily = number of segments.
27-
int N;
33+
private Operation op;
2834

29-
public RangeQueryPointUpdateSegmentTree(int[] values) {
35+
// The chosen range combination function
36+
private BinaryOperator<Long> fn;
37+
38+
private BinaryOperator<Long> sumFn = (a, b) -> a + b;
39+
private BinaryOperator<Long> minFn = (a, b) -> Math.min(a, b);
40+
private BinaryOperator<Long> maxFn = (a, b) -> Math.max(a, b);
41+
42+
public RangeQueryPointUpdateSegmentTree(long[] values, Operation op) {
3043
if (values == null) {
3144
throw new NullPointerException("Segment tree values cannot be null.");
3245
}
46+
if (op == null) {
47+
throw new NullPointerException("Please specify a valid segment combination operation.");
48+
}
3349
n = values.length;
50+
this.op = op;
51+
52+
// The size of the segment tree `t`
53+
//
3454
// TODO(william): Investigate to reduce this space. There are only 2n-1 segments, so we should
3555
// be able to reduce the space, but may need to reorganize the tree/queries. One idea is to use
3656
// the Eulerian tour structure of the tree to densely pack the segments.
37-
N = 4 * n;
38-
t = new Integer[N];
57+
int N = 4 * n;
58+
59+
t = new long[N];
60+
61+
if (op == Operation.SUM) {
62+
fn = sumFn;
63+
} else if (op == Operation.MIN) {
64+
Arrays.fill(t, Long.MAX_VALUE);
65+
fn = minFn;
66+
} else if (op == Operation.MAX) {
67+
Arrays.fill(t, Long.MIN_VALUE);
68+
fn = maxFn;
69+
}
3970

40-
buildTree(0, 0, n - 1, values);
41-
// System.out.println(java.util.Arrays.toString(values));
42-
// System.out.println(java.util.Arrays.toString(t));
71+
buildSegmentTree(0, 0, n - 1, values);
4372
}
4473

4574
/**
46-
* Builds the segment tree starting with leaf nodes and combining values on callback. This
47-
* construction method takes O(n) time since there are only 2n - 1 segments in the segment tree.
75+
* Builds a segment tree by starting with the leaf nodes and combining segment values on callback.
4876
*
4977
* @param i the index of the segment in the segment tree
50-
* @param l the left index of the range on the values array
51-
* @param r the right index of the range on the values array
78+
* @param l the left index (inclusive) of the range in the values array
79+
* @param r the right index (inclusive) of the range in the values array
5280
* @param values the initial values array
53-
* <p>The range [l, r] over the values array is inclusive.
5481
*/
55-
private void buildTree(int i, int tl, int tr, int[] values) {
82+
private void buildSegmentTree(int i, int tl, int tr, long[] values) {
5683
if (tl == tr) {
5784
t[i] = values[tl];
5885
return;
5986
}
6087
int mid = (tl + tr) / 2;
88+
buildSegmentTree(2 * i + 1, tl, mid, values);
89+
buildSegmentTree(2 * i + 2, mid + 1, tr, values);
6190

62-
buildTree(2 * i + 1, tl, mid, values);
63-
buildTree(2 * i + 2, mid + 1, tr, values);
91+
t[i] = fn.apply(t[2 * i + 1], t[2 * i + 2]);
92+
}
6493

65-
// TODO(william): Make generic to support min, max and other queries. One idea is to keep
66-
// segment multiple trees for each query type?
67-
t[i] = t[2 * i + 1] + t[2 * i + 2];
94+
/**
95+
* Returns the query of the range [l, r] on the original `values` array (+ any updates made to it)
96+
*
97+
* @param l the left endpoint of the range query (inclusive)
98+
* @param r the right endpoint of the range query (inclusive)
99+
*/
100+
public long rangeQuery(int l, int r) {
101+
return rangeQuery(0, 0, n - 1, l, r);
68102
}
69103

70104
/**
71-
* Returns the sum of the range [l, r] in the original `values` array.
105+
* Returns the query of the range [l, r] on the original `values` array (+ any updates made to it)
72106
*
73-
* @param l the left endpoint of the sum range query (inclusive)
74-
* @param r the right endpoint of the sum range query (inclusive)
107+
* @param l the left endpoint of the range query (inclusive)
108+
* @param r the right endpoint of the range query (inclusive)
75109
*/
76-
public long sumQuery(int l, int r) {
77-
return sumQuery(0, 0, n - 1, l, r);
110+
public long rangeQuery2(int l, int r) {
111+
return rangeQuery2(0, 0, n - 1, l, r);
78112
}
79113

80114
/**
@@ -84,23 +118,34 @@ public long sumQuery(int l, int r) {
84118
* @param l the target left endpoint for the range query
85119
* @param r the target right endpoint for the range query
86120
*/
87-
private long sumQuery(int i, int tl, int tr, int l, int r) {
121+
private long rangeQuery(int i, int tl, int tr, int l, int r) {
88122
if (l > r) {
89-
return 0;
123+
// Different segment tree types have different base cases:
124+
if (op == Operation.SUM) {
125+
return 0;
126+
} else if (op == Operation.MIN) {
127+
return Long.MAX_VALUE;
128+
} else if (op == Operation.MAX) {
129+
return Long.MIN_VALUE;
130+
}
90131
}
91132
if (tl == l && tr == r) {
92133
return t[i];
93134
}
94135
int tm = (tl + tr) / 2;
95136
// Instead of checking if [tl, tm] overlaps [l, r] and [tm+1, tr] overlaps
96137
// [l, r], simply recurse on both and return a sum of 0 if the interval is invalid.
97-
return sumQuery(2 * i + 1, tl, tm, l, Math.min(tm, r))
98-
+ sumQuery(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r);
138+
return fn.apply(
139+
rangeQuery(2 * i + 1, tl, tm, l, Math.min(tm, r)),
140+
rangeQuery(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r));
99141
}
100142

101143
// Alternative implementation of summing that intelligently only digs into
102-
// the branches which overlap with the query [l, r]
103-
private long sumQuery2(int i, int tl, int tr, int l, int r) {
144+
// the branches which overlap with the query [l, r].
145+
//
146+
// This version of the range query impl also has the advantage that it doesn't
147+
// need to know the explicit base case value for each query type.
148+
private long rangeQuery2(int i, int tl, int tr, int l, int r) {
104149
if (tl == l && tr == r) {
105150
return t[i];
106151
}
@@ -109,18 +154,19 @@ private long sumQuery2(int i, int tl, int tr, int l, int r) {
109154
boolean overlapsLeftSegment = (l <= tm);
110155
boolean overlapsRightSegment = (r > tm);
111156
if (overlapsLeftSegment && overlapsRightSegment) {
112-
return sumQuery2(2 * i + 1, tl, tm, l, Math.min(tm, r))
113-
+ sumQuery2(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r);
157+
return fn.apply(
158+
rangeQuery2(2 * i + 1, tl, tm, l, Math.min(tm, r)),
159+
rangeQuery2(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r));
114160
} else if (overlapsLeftSegment) {
115-
return sumQuery2(2 * i + 1, tl, tm, l, Math.min(tm, r));
161+
return rangeQuery2(2 * i + 1, tl, tm, l, Math.min(tm, r));
116162
} else {
117-
return sumQuery2(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r);
163+
return rangeQuery2(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r);
118164
}
119165
}
120166

121167
// Updates the segment tree to reflect that index `i` in the original `values` array was updated
122168
// to `newValue`.
123-
public void update(int i, int newValue) {
169+
public void update(int i, long newValue) {
124170
update(0, i, 0, n - 1, newValue);
125171
}
126172

@@ -137,7 +183,7 @@ public void update(int i, int newValue) {
137183
* @param tr the right segment endpoint
138184
* @param newValue the new value to update
139185
*/
140-
private void update(int at, int pos, int tl, int tr, int newValue) {
186+
private void update(int at, int pos, int tl, int tr, long newValue) {
141187
if (tl == tr) { // `tl == pos && tr == pos` might be clearer
142188
t[at] = newValue;
143189
return;
@@ -151,18 +197,52 @@ private void update(int at, int pos, int tl, int tr, int newValue) {
151197
update(2 * at + 2, pos, tm + 1, tr, newValue);
152198
}
153199
// Re-compute the segment value of the current segment on the callback
154-
t[at] = t[2 * at + 1] + t[2 * at + 2];
200+
t[at] = fn.apply(t[2 * at + 1], t[2 * at + 2]);
155201
}
156202

203+
////////////////////////////////////////////////////
204+
// Example usage: //
205+
////////////////////////////////////////////////////
206+
157207
public static void main(String[] args) {
158-
int[] values = new int[6];
159-
java.util.Arrays.fill(values, 1);
160-
RangeQueryPointUpdateSegmentTree st = new RangeQueryPointUpdateSegmentTree(values);
161-
System.out.println(st.sumQuery(1, 4)); // 4
162-
163-
st.update(1, 2);
164-
System.out.println(st.sumQuery(1, 1)); // 2
165-
System.out.println(st.sumQuery(0, 1)); // 3
166-
System.out.println(st.sumQuery(0, 2)); // 4
208+
rangeSumQueryExample();
209+
rangeMinQueryExample();
210+
rangeMaxQueryExample();
211+
}
212+
213+
private static void rangeSumQueryExample() {
214+
// 0 1 2 3
215+
long[] values = {1, 2, 3, 2};
216+
RangeQueryPointUpdateSegmentTree st =
217+
new RangeQueryPointUpdateSegmentTree(values, Operation.SUM);
218+
219+
int l = 0, r = 3;
220+
System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery(l, r));
221+
// Prints:
222+
// The sum between indeces [0, 3] is: 8
223+
}
224+
225+
private static void rangeMinQueryExample() {
226+
// 0 1 2 3
227+
long[] values = {1, 2, 3, 2};
228+
RangeQueryPointUpdateSegmentTree st =
229+
new RangeQueryPointUpdateSegmentTree(values, Operation.MIN);
230+
231+
int l = 0, r = 3;
232+
System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery(l, r));
233+
// Prints:
234+
// The sum between indeces [0, 3] is: 1
235+
}
236+
237+
private static void rangeMaxQueryExample() {
238+
// 0 1 2 3
239+
long[] values = {1, 2, 3, 2};
240+
RangeQueryPointUpdateSegmentTree st =
241+
new RangeQueryPointUpdateSegmentTree(values, Operation.MAX);
242+
243+
int l = 0, r = 3;
244+
System.out.printf("The sum between indeces [%d, %d] is: %d\n", l, r, st.rangeQuery(l, r));
245+
// Prints:
246+
// The sum between indeces [0, 3] is: 3
167247
}
168248
}

0 commit comments

Comments
 (0)