@@ -75,6 +75,103 @@ where
75
75
self . _fold ( start, end)
76
76
}
77
77
78
+ /// `f(fold(l..r)) = true` となる最大の `r` を返します。
79
+ ///
80
+ /// # Panics
81
+ ///
82
+ /// if `f(e) = false`
83
+ pub fn max_right < P > ( & self , l : usize , f : P ) -> usize
84
+ where
85
+ P : Fn ( & T ) -> bool ,
86
+ {
87
+ assert ! ( l <= self . original_n) ;
88
+ assert ! ( f( & self . e) , "f(e) must be true" ) ;
89
+
90
+ if l == self . original_n {
91
+ return self . original_n ;
92
+ }
93
+
94
+ let mut l = l + self . n ;
95
+ let mut sum = self . e . clone ( ) ;
96
+
97
+ loop {
98
+ // l を含む区間の右端まで進む
99
+ while l % 2 == 0 {
100
+ l >>= 1 ;
101
+ }
102
+
103
+ let new_sum = ( self . multiply ) ( & sum, & self . dat [ l] ) ;
104
+ if !f ( & new_sum) {
105
+ while l < self . n {
106
+ l <<= 1 ;
107
+ let new_sum = ( self . multiply ) ( & sum, & self . dat [ l] ) ;
108
+ if f ( & new_sum) {
109
+ sum = new_sum;
110
+ l += 1 ;
111
+ }
112
+ }
113
+ return l - self . n ;
114
+ }
115
+
116
+ sum = new_sum;
117
+ l += 1 ;
118
+
119
+ if ( l & ( l. wrapping_neg ( ) ) ) == l {
120
+ break ;
121
+ }
122
+ }
123
+
124
+ self . original_n
125
+ }
126
+
127
+ /// `f(fold(l..r)) = true` となる最小の `l` を返します。
128
+ ///
129
+ /// # Panics
130
+ ///
131
+ /// if `f(e) = false`
132
+ pub fn min_left < P > ( & self , r : usize , f : P ) -> usize
133
+ where
134
+ P : Fn ( & T ) -> bool ,
135
+ {
136
+ assert ! ( r <= self . original_n) ;
137
+ assert ! ( f( & self . e) , "f(e) must be true" ) ;
138
+
139
+ if r == 0 {
140
+ return 0 ;
141
+ }
142
+
143
+ let mut r = r + self . n ;
144
+ let mut sum = self . e . clone ( ) ;
145
+
146
+ loop {
147
+ r -= 1 ;
148
+ while r > 1 && r % 2 == 1 {
149
+ r >>= 1 ;
150
+ }
151
+
152
+ let new_sum = ( self . multiply ) ( & self . dat [ r] , & sum) ;
153
+ if !f ( & new_sum) {
154
+ while r < self . n {
155
+ r = r * 2 + 1 ;
156
+ let new_sum = ( self . multiply ) ( & self . dat [ r] , & sum) ;
157
+ if f ( & new_sum) {
158
+ sum = new_sum;
159
+ r -= 1 ;
160
+ }
161
+ }
162
+ return r + 1 - self . n ;
163
+ }
164
+
165
+ sum = new_sum;
166
+
167
+ if ( r & ( r. wrapping_neg ( ) ) ) == r {
168
+ break ;
169
+ }
170
+ }
171
+
172
+ 0
173
+ }
174
+
78
175
fn _fold ( & self , mut l : usize , mut r : usize ) -> T {
79
176
let mut acc_l = self . e . clone ( ) ;
80
177
let mut acc_r = self . e . clone ( ) ;
@@ -154,4 +251,49 @@ mod tests {
154
251
seg. set ( 0 , 42 ) ;
155
252
assert_eq ! ( seg[ 0 ] , 42 ) ;
156
253
}
254
+
255
+ #[ test]
256
+ fn test_max_right ( ) {
257
+ let n = 9 ;
258
+ let mut seg = SegmentTree :: new ( n, 0 , |a, b| a + b) ;
259
+ let values = vec ! [ 3 , 1 , 4 , 1 , 5 , 9 , 2 , 6 , 5 ] ;
260
+ for ( i, & v) in values. iter ( ) . enumerate ( ) {
261
+ seg. set ( i, v) ;
262
+ }
263
+
264
+ // 区間和
265
+ assert_eq ! ( seg. max_right( 0 , |& sum| sum < 9 ) , 3 ) ; // 3 + 1 + 4 = 8
266
+ assert_eq ! ( seg. max_right( 0 , |& sum| sum <= 9 ) , 4 ) ; // 3 + 1 + 4 + 1 = 9
267
+
268
+ assert_eq ! ( seg. max_right( 1 , |& sum| sum < 11 ) , 4 ) ; // 1 + 4 + 1 = 6
269
+ assert_eq ! ( seg. max_right( 1 , |& sum| sum <= 11 ) , 5 ) ; // 1 + 4 + 1 + 5 = 11
270
+
271
+ assert_eq ! ( seg. max_right( 2 , |& sum| sum < 4 ) , 2 ) ;
272
+ assert_eq ! ( seg. max_right( 2 , |& sum| sum <= 4 ) , 3 ) ;
273
+ assert_eq ! ( seg. max_right( 2 , |& sum| sum <= 100 ) , n) ;
274
+
275
+ assert_eq ! ( seg. max_right( n, |& sum| sum <= 0 ) , n) ;
276
+ assert_eq ! ( seg. max_right( n, |& sum| sum <= 100 ) , n) ;
277
+ }
278
+
279
+ #[ test]
280
+ fn test_min_left ( ) {
281
+ let n = 9 ;
282
+ let mut seg = SegmentTree :: new ( n, 0 , |a, b| a + b) ;
283
+ let values = vec ! [ 3 , 1 , 4 , 1 , 5 , 9 , 2 , 6 , 5 ] ;
284
+ for ( i, & v) in values. iter ( ) . enumerate ( ) {
285
+ seg. set ( i, v) ;
286
+ }
287
+
288
+ // 区間和
289
+ assert_eq ! ( seg. min_left( n, |& sum| sum <= 22 ) , 5 ) ; // 9 + 2 + 6 + 5 = 22
290
+ assert_eq ! ( seg. min_left( n, |& sum| sum < 22 ) , 6 ) ; // 2 + 6 + 5 = 13
291
+
292
+ assert_eq ! ( seg. min_left( n - 1 , |& sum| sum <= 27 ) , 2 ) ; // 4 + 1 + 5 + 9 + 2 + 6 = 27
293
+ assert_eq ! ( seg. min_left( n - 1 , |& sum| sum < 27 ) , 3 ) ; // 1 + 5 + 9 + 2 + 6 = 23
294
+ assert_eq ! ( seg. min_left( n - 1 , |& sum| sum < 100 ) , 0 ) ;
295
+
296
+ assert_eq ! ( seg. min_left( 0 , |& sum| sum <= 0 ) , 0 ) ;
297
+ assert_eq ! ( seg. min_left( 0 , |& sum| sum <= 100 ) , 0 ) ;
298
+ }
157
299
}
0 commit comments