13
13
*/
14
14
package com .williamfiset .algorithms .datastructures .segmenttree ;
15
15
16
+ import java .util .Arrays ;
17
+ import java .util .function .BinaryOperator ;
18
+
16
19
public class RangeQueryPointUpdateSegmentTree {
17
- // TODO(william): make the members of this class private
18
20
19
- // Tree segment values.
20
- Integer [] t ;
21
+ // The type of segment combination function to use
22
+ public static enum Operation {
23
+ SUM ,
24
+ MIN ,
25
+ MAX
26
+ }
21
27
22
28
// The number of values in the original input values array.
23
- int n ;
29
+ private int n ;
30
+
31
+ private long [] t ;
24
32
25
- // The size of the segment tree `t`
26
- // NOTE: the size is not necessarily = number of segments.
27
- int N ;
33
+ private Operation op ;
28
34
29
- public RangeQueryPointUpdateSegmentTree (int [] values ) {
35
+ // The chosen range combination function
36
+ private BinaryOperator <Long > fn ;
37
+
38
+ private BinaryOperator <Long > sumFn = (a , b ) -> a + b ;
39
+ private BinaryOperator <Long > minFn = (a , b ) -> Math .min (a , b );
40
+ private BinaryOperator <Long > maxFn = (a , b ) -> Math .max (a , b );
41
+
42
+ public RangeQueryPointUpdateSegmentTree (long [] values , Operation op ) {
30
43
if (values == null ) {
31
44
throw new NullPointerException ("Segment tree values cannot be null." );
32
45
}
46
+ if (op == null ) {
47
+ throw new NullPointerException ("Please specify a valid segment combination operation." );
48
+ }
33
49
n = values .length ;
50
+ this .op = op ;
51
+
52
+ // The size of the segment tree `t`
53
+ //
34
54
// TODO(william): Investigate to reduce this space. There are only 2n-1 segments, so we should
35
55
// be able to reduce the space, but may need to reorganize the tree/queries. One idea is to use
36
56
// the Eulerian tour structure of the tree to densely pack the segments.
37
- N = 4 * n ;
38
- t = new Integer [N ];
57
+ int N = 4 * n ;
58
+
59
+ t = new long [N ];
60
+
61
+ if (op == Operation .SUM ) {
62
+ fn = sumFn ;
63
+ } else if (op == Operation .MIN ) {
64
+ Arrays .fill (t , Long .MAX_VALUE );
65
+ fn = minFn ;
66
+ } else if (op == Operation .MAX ) {
67
+ Arrays .fill (t , Long .MIN_VALUE );
68
+ fn = maxFn ;
69
+ }
39
70
40
- buildTree (0 , 0 , n - 1 , values );
41
- // System.out.println(java.util.Arrays.toString(values));
42
- // System.out.println(java.util.Arrays.toString(t));
71
+ buildSegmentTree (0 , 0 , n - 1 , values );
43
72
}
44
73
45
74
/**
46
- * Builds the segment tree starting with leaf nodes and combining values on callback. This
47
- * construction method takes O(n) time since there are only 2n - 1 segments in the segment tree.
75
+ * Builds a segment tree by starting with the leaf nodes and combining segment values on callback.
48
76
*
49
77
* @param i the index of the segment in the segment tree
50
- * @param l the left index of the range on the values array
51
- * @param r the right index of the range on the values array
78
+ * @param l the left index (inclusive) of the range in the values array
79
+ * @param r the right index (inclusive) of the range in the values array
52
80
* @param values the initial values array
53
- * <p>The range [l, r] over the values array is inclusive.
54
81
*/
55
- private void buildTree (int i , int tl , int tr , int [] values ) {
82
+ private void buildSegmentTree (int i , int tl , int tr , long [] values ) {
56
83
if (tl == tr ) {
57
84
t [i ] = values [tl ];
58
85
return ;
59
86
}
60
87
int mid = (tl + tr ) / 2 ;
88
+ buildSegmentTree (2 * i + 1 , tl , mid , values );
89
+ buildSegmentTree (2 * i + 2 , mid + 1 , tr , values );
61
90
62
- buildTree ( 2 * i + 1 , tl , mid , values );
63
- buildTree ( 2 * i + 2 , mid + 1 , tr , values );
91
+ t [ i ] = fn . apply ( t [ 2 * i + 1 ], t [ 2 * i + 2 ] );
92
+ }
64
93
65
- // TODO(william): Make generic to support min, max and other queries. One idea is to keep
66
- // segment multiple trees for each query type?
67
- t [i ] = t [2 * i + 1 ] + t [2 * i + 2 ];
94
+ /**
95
+ * Returns the query of the range [l, r] on the original `values` array (+ any updates made to it)
96
+ *
97
+ * @param l the left endpoint of the range query (inclusive)
98
+ * @param r the right endpoint of the range query (inclusive)
99
+ */
100
+ public long rangeQuery (int l , int r ) {
101
+ return rangeQuery (0 , 0 , n - 1 , l , r );
68
102
}
69
103
70
104
/**
71
- * Returns the sum of the range [l, r] in the original `values` array.
105
+ * Returns the query of the range [l, r] on the original `values` array (+ any updates made to it)
72
106
*
73
- * @param l the left endpoint of the sum range query (inclusive)
74
- * @param r the right endpoint of the sum range query (inclusive)
107
+ * @param l the left endpoint of the range query (inclusive)
108
+ * @param r the right endpoint of the range query (inclusive)
75
109
*/
76
- public long sumQuery (int l , int r ) {
77
- return sumQuery (0 , 0 , n - 1 , l , r );
110
+ public long rangeQuery2 (int l , int r ) {
111
+ return rangeQuery2 (0 , 0 , n - 1 , l , r );
78
112
}
79
113
80
114
/**
@@ -84,23 +118,34 @@ public long sumQuery(int l, int r) {
84
118
* @param l the target left endpoint for the range query
85
119
* @param r the target right endpoint for the range query
86
120
*/
87
- private long sumQuery (int i , int tl , int tr , int l , int r ) {
121
+ private long rangeQuery (int i , int tl , int tr , int l , int r ) {
88
122
if (l > r ) {
89
- return 0 ;
123
+ // Different segment tree types have different base cases:
124
+ if (op == Operation .SUM ) {
125
+ return 0 ;
126
+ } else if (op == Operation .MIN ) {
127
+ return Long .MAX_VALUE ;
128
+ } else if (op == Operation .MAX ) {
129
+ return Long .MIN_VALUE ;
130
+ }
90
131
}
91
132
if (tl == l && tr == r ) {
92
133
return t [i ];
93
134
}
94
135
int tm = (tl + tr ) / 2 ;
95
136
// Instead of checking if [tl, tm] overlaps [l, r] and [tm+1, tr] overlaps
96
137
// [l, r], simply recurse on both and return a sum of 0 if the interval is invalid.
97
- return sumQuery (2 * i + 1 , tl , tm , l , Math .min (tm , r ))
98
- + sumQuery (2 * i + 2 , tm + 1 , tr , Math .max (l , tm + 1 ), r );
138
+ return fn .apply (
139
+ rangeQuery (2 * i + 1 , tl , tm , l , Math .min (tm , r )),
140
+ rangeQuery (2 * i + 2 , tm + 1 , tr , Math .max (l , tm + 1 ), r ));
99
141
}
100
142
101
143
// Alternative implementation of summing that intelligently only digs into
102
- // the branches which overlap with the query [l, r]
103
- private long sumQuery2 (int i , int tl , int tr , int l , int r ) {
144
+ // the branches which overlap with the query [l, r].
145
+ //
146
+ // This version of the range query impl also has the advantage that it doesn't
147
+ // need to know the explicit base case value for each query type.
148
+ private long rangeQuery2 (int i , int tl , int tr , int l , int r ) {
104
149
if (tl == l && tr == r ) {
105
150
return t [i ];
106
151
}
@@ -109,18 +154,19 @@ private long sumQuery2(int i, int tl, int tr, int l, int r) {
109
154
boolean overlapsLeftSegment = (l <= tm );
110
155
boolean overlapsRightSegment = (r > tm );
111
156
if (overlapsLeftSegment && overlapsRightSegment ) {
112
- return sumQuery2 (2 * i + 1 , tl , tm , l , Math .min (tm , r ))
113
- + sumQuery2 (2 * i + 2 , tm + 1 , tr , Math .max (l , tm + 1 ), r );
157
+ return fn .apply (
158
+ rangeQuery2 (2 * i + 1 , tl , tm , l , Math .min (tm , r )),
159
+ rangeQuery2 (2 * i + 2 , tm + 1 , tr , Math .max (l , tm + 1 ), r ));
114
160
} else if (overlapsLeftSegment ) {
115
- return sumQuery2 (2 * i + 1 , tl , tm , l , Math .min (tm , r ));
161
+ return rangeQuery2 (2 * i + 1 , tl , tm , l , Math .min (tm , r ));
116
162
} else {
117
- return sumQuery2 (2 * i + 2 , tm + 1 , tr , Math .max (l , tm + 1 ), r );
163
+ return rangeQuery2 (2 * i + 2 , tm + 1 , tr , Math .max (l , tm + 1 ), r );
118
164
}
119
165
}
120
166
121
167
// Updates the segment tree to reflect that index `i` in the original `values` array was updated
122
168
// to `newValue`.
123
- public void update (int i , int newValue ) {
169
+ public void update (int i , long newValue ) {
124
170
update (0 , i , 0 , n - 1 , newValue );
125
171
}
126
172
@@ -137,7 +183,7 @@ public void update(int i, int newValue) {
137
183
* @param tr the right segment endpoint
138
184
* @param newValue the new value to update
139
185
*/
140
- private void update (int at , int pos , int tl , int tr , int newValue ) {
186
+ private void update (int at , int pos , int tl , int tr , long newValue ) {
141
187
if (tl == tr ) { // `tl == pos && tr == pos` might be clearer
142
188
t [at ] = newValue ;
143
189
return ;
@@ -151,18 +197,52 @@ private void update(int at, int pos, int tl, int tr, int newValue) {
151
197
update (2 * at + 2 , pos , tm + 1 , tr , newValue );
152
198
}
153
199
// Re-compute the segment value of the current segment on the callback
154
- t [at ] = t [2 * at + 1 ] + t [2 * at + 2 ];
200
+ t [at ] = fn . apply ( t [2 * at + 1 ], t [2 * at + 2 ]) ;
155
201
}
156
202
203
+ ////////////////////////////////////////////////////
204
+ // Example usage: //
205
+ ////////////////////////////////////////////////////
206
+
157
207
public static void main (String [] args ) {
158
- int [] values = new int [6 ];
159
- java .util .Arrays .fill (values , 1 );
160
- RangeQueryPointUpdateSegmentTree st = new RangeQueryPointUpdateSegmentTree (values );
161
- System .out .println (st .sumQuery (1 , 4 )); // 4
162
-
163
- st .update (1 , 2 );
164
- System .out .println (st .sumQuery (1 , 1 )); // 2
165
- System .out .println (st .sumQuery (0 , 1 )); // 3
166
- System .out .println (st .sumQuery (0 , 2 )); // 4
208
+ rangeSumQueryExample ();
209
+ rangeMinQueryExample ();
210
+ rangeMaxQueryExample ();
211
+ }
212
+
213
+ private static void rangeSumQueryExample () {
214
+ // 0 1 2 3
215
+ long [] values = {1 , 2 , 3 , 2 };
216
+ RangeQueryPointUpdateSegmentTree st =
217
+ new RangeQueryPointUpdateSegmentTree (values , Operation .SUM );
218
+
219
+ int l = 0 , r = 3 ;
220
+ System .out .printf ("The sum between indeces [%d, %d] is: %d\n " , l , r , st .rangeQuery (l , r ));
221
+ // Prints:
222
+ // The sum between indeces [0, 3] is: 8
223
+ }
224
+
225
+ private static void rangeMinQueryExample () {
226
+ // 0 1 2 3
227
+ long [] values = {1 , 2 , 3 , 2 };
228
+ RangeQueryPointUpdateSegmentTree st =
229
+ new RangeQueryPointUpdateSegmentTree (values , Operation .MIN );
230
+
231
+ int l = 0 , r = 3 ;
232
+ System .out .printf ("The sum between indeces [%d, %d] is: %d\n " , l , r , st .rangeQuery (l , r ));
233
+ // Prints:
234
+ // The sum between indeces [0, 3] is: 1
235
+ }
236
+
237
+ private static void rangeMaxQueryExample () {
238
+ // 0 1 2 3
239
+ long [] values = {1 , 2 , 3 , 2 };
240
+ RangeQueryPointUpdateSegmentTree st =
241
+ new RangeQueryPointUpdateSegmentTree (values , Operation .MAX );
242
+
243
+ int l = 0 , r = 3 ;
244
+ System .out .printf ("The sum between indeces [%d, %d] is: %d\n " , l , r , st .rangeQuery (l , r ));
245
+ // Prints:
246
+ // The sum between indeces [0, 3] is: 3
167
247
}
168
248
}
0 commit comments