Skip to content

Commit cda383d

Browse files
committed
IterateT#flatMap is now stack-safe regardless of how many consecutive empty IterateTs are returned
1 parent 2cdc1cf commit cda383d

File tree

3 files changed

+77
-14
lines changed

3 files changed

+77
-14
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/).
1919
### Fixed
2020
- `IterateT#trampolineM` now yields and stages all recursive result values, rather
2121
than prematurely terminating on the first termination result
22+
- `IterateT#flatMap` is now stack-safe regardless of how many consecutive empty `IterateT`s
23+
are returned and regardless of whether the monad is strict or lazy or internally trampolined
2224

2325
## [5.2.0] - 2020-02-12
2426

src/main/java/com/jnape/palatable/lambda/monad/transformer/builtin/IterateT.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import static com.jnape.palatable.lambda.adt.choice.Choice2.a;
3030
import static com.jnape.palatable.lambda.adt.choice.Choice2.b;
3131
import static com.jnape.palatable.lambda.adt.hlist.HList.tuple;
32+
import static com.jnape.palatable.lambda.functions.Fn0.fn0;
3233
import static com.jnape.palatable.lambda.functions.Fn1.withSelf;
3334
import static com.jnape.palatable.lambda.functions.builtin.fn2.$.$;
3435
import static com.jnape.palatable.lambda.functions.builtin.fn2.Into.into;
@@ -65,7 +66,7 @@
6566
* @param <A> the element type
6667
*/
6768
public class IterateT<M extends MonadRec<?, M>, A> implements
68-
MonadT<M, A, IterateT<M, ?>, IterateT<?, ?>> {
69+
MonadT<M, A, IterateT<M, ?>, IterateT<?, ?>> {
6970

7071
private final Pure<M> pureM;
7172
private final ImmutableQueue<Choice2<Fn0<MonadRec<Maybe<Tuple2<A, IterateT<M, A>>>, M>>, MonadRec<A, M>>> spine;
@@ -208,11 +209,15 @@ public <B> IterateT<M, B> trampolineM(
208209
@Override
209210
public <B> IterateT<M, B> flatMap(Fn1<? super A, ? extends Monad<B, IterateT<M, ?>>> f) {
210211
return suspended(() -> maybeT(runIterateT())
211-
.flatMap(into((a, as) -> maybeT(f.apply(a)
212-
.<IterateT<M, B>>coerce()
213-
.concat(as.flatMap(f))
214-
.runIterateT())))
215-
.runMaybeT(), pureM);
212+
.trampolineM(into((a, as) -> maybeT(
213+
f.apply(a).<IterateT<M, B>>coerce().runIterateT()
214+
.flatMap(maybePair -> maybePair.match(
215+
fn0(() -> as.runIterateT()
216+
.fmap(maybeResult -> maybeResult.fmap(RecursiveResult::recurse))),
217+
t -> pureM.apply(just(terminate(t.fmap(mb -> mb.concat(as.flatMap(f))))))
218+
)))))
219+
.runMaybeT(),
220+
pureM);
216221
}
217222

218223
/**

src/test/java/com/jnape/palatable/lambda/monad/transformer/builtin/IterateTTest.java

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,31 +8,47 @@
88
import com.jnape.palatable.lambda.functor.builtin.Lazy;
99
import com.jnape.palatable.lambda.functor.builtin.Writer;
1010
import com.jnape.palatable.lambda.io.IO;
11+
import com.jnape.palatable.lambda.monoid.Monoid;
1112
import com.jnape.palatable.traitor.annotations.TestTraits;
1213
import com.jnape.palatable.traitor.framework.Subjects;
1314
import com.jnape.palatable.traitor.runners.Traits;
1415
import org.junit.Test;
1516
import org.junit.runner.RunWith;
16-
import testsupport.traits.*;
17+
import testsupport.traits.ApplicativeLaws;
18+
import testsupport.traits.Equivalence;
19+
import testsupport.traits.FunctorLaws;
20+
import testsupport.traits.MonadLaws;
21+
import testsupport.traits.MonadRecLaws;
1722

1823
import java.util.ArrayList;
1924
import java.util.Collection;
2025
import java.util.List;
2126
import java.util.concurrent.CountDownLatch;
27+
import java.util.concurrent.atomic.AtomicInteger;
2228

2329
import static com.jnape.palatable.lambda.adt.Maybe.just;
2430
import static com.jnape.palatable.lambda.adt.Maybe.nothing;
2531
import static com.jnape.palatable.lambda.adt.Unit.UNIT;
2632
import static com.jnape.palatable.lambda.adt.hlist.HList.tuple;
33+
import static com.jnape.palatable.lambda.functions.builtin.fn1.Constantly.constantly;
2734
import static com.jnape.palatable.lambda.functions.builtin.fn2.LTE.lte;
2835
import static com.jnape.palatable.lambda.functions.builtin.fn3.Times.times;
2936
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.recurse;
3037
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.terminate;
3138
import static com.jnape.palatable.lambda.functor.builtin.Identity.pureIdentity;
3239
import static com.jnape.palatable.lambda.functor.builtin.Lazy.lazy;
33-
import static com.jnape.palatable.lambda.functor.builtin.Writer.*;
40+
import static com.jnape.palatable.lambda.functor.builtin.Writer.listen;
41+
import static com.jnape.palatable.lambda.functor.builtin.Writer.pureWriter;
42+
import static com.jnape.palatable.lambda.functor.builtin.Writer.tell;
43+
import static com.jnape.palatable.lambda.functor.builtin.Writer.writer;
3444
import static com.jnape.palatable.lambda.io.IO.io;
35-
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.*;
45+
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.empty;
46+
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.iterateT;
47+
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.liftIterateT;
48+
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.of;
49+
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.pureIterateT;
50+
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.singleton;
51+
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.unfold;
3652
import static com.jnape.palatable.lambda.monoid.builtin.AddAll.addAll;
3753
import static com.jnape.palatable.lambda.monoid.builtin.Join.join;
3854
import static com.jnape.palatable.traitor.framework.Subjects.subjects;
@@ -44,7 +60,9 @@
4460
import static org.junit.Assert.assertThat;
4561
import static testsupport.Constants.STACK_EXPLODING_NUMBER;
4662
import static testsupport.matchers.IOMatcher.yieldsValue;
47-
import static testsupport.matchers.IterateTMatcher.*;
63+
import static testsupport.matchers.IterateTMatcher.isEmpty;
64+
import static testsupport.matchers.IterateTMatcher.iterates;
65+
import static testsupport.matchers.IterateTMatcher.iteratesAll;
4866
import static testsupport.traits.Equivalence.equivalence;
4967

5068
@RunWith(Traits.class)
@@ -236,16 +254,16 @@ public void concatIsStackSafe() {
236254
public void staticPure() {
237255
assertEquals(new Identity<>(singletonList(1)),
238256
pureIterateT(pureIdentity())
239-
.<Integer, IterateT<Identity<?>, Integer>>apply(1)
240-
.<List<Integer>, Identity<List<Integer>>>toCollection(ArrayList::new));
257+
.<Integer, IterateT<Identity<?>, Integer>>apply(1)
258+
.<List<Integer>, Identity<List<Integer>>>toCollection(ArrayList::new));
241259
}
242260

243261
@Test
244262
public void staticLift() {
245263
assertEquals(new Identity<>(singletonList(1)),
246264
liftIterateT()
247-
.<Integer, Identity<?>, IterateT<Identity<?>, Integer>>apply(new Identity<>(1))
248-
.<List<Integer>, Identity<List<Integer>>>toCollection(ArrayList::new));
265+
.<Integer, Identity<?>, IterateT<Identity<?>, Integer>>apply(new Identity<>(1))
266+
.<List<Integer>, Identity<List<Integer>>>toCollection(ArrayList::new));
249267
}
250268

251269
@Test
@@ -257,4 +275,42 @@ public void trampolineMRecursesBreadth() {
257275
: singleton(new Identity<>(terminate(x))));
258276
assertThat(trampolined, iterates(1, 2, 13, 14, 25, 26, 37, 38, 39, 40, 28, 16, 4));
259277
}
278+
279+
@Test
280+
public void flatMapToEmptyStackSafety() {
281+
assertEquals(new Identity<>(UNIT),
282+
unfold(x -> new Identity<>(x <= STACK_EXPLODING_NUMBER ? just(tuple(x, x + 1)) : nothing()),
283+
new Identity<>(1))
284+
.flatMap(constantly(iterateT(new Identity<>(nothing()))))
285+
.forEach(constantly(new Identity<>(UNIT))));
286+
287+
assertEquals((Integer) 1_250_025_000,
288+
unfold(x -> listen(x <= STACK_EXPLODING_NUMBER ? just(tuple(x, x + 1)) : nothing()),
289+
Writer.<Integer, Integer>listen(1))
290+
.flatMap(x -> iterateT(writer(tuple(nothing(), x))))
291+
.<Writer<Integer, Unit>>forEach(constantly(listen(UNIT)))
292+
.runWriter(Monoid.monoid(Integer::sum, 0))
293+
._2());
294+
}
295+
296+
@Test
297+
public void flatMapCostsNoMoreEffortThanRequiredToYieldFirstValue() {
298+
AtomicInteger flatMapCost = new AtomicInteger(0);
299+
AtomicInteger unfoldCost = new AtomicInteger(0);
300+
assertEquals(just(1),
301+
unfold(x -> {
302+
unfoldCost.incrementAndGet();
303+
return new Identity<>(x <= 10 ? just(tuple(x, x + 1)) : nothing());
304+
},
305+
new Identity<>(1))
306+
.flatMap(x -> {
307+
flatMapCost.incrementAndGet();
308+
return singleton(new Identity<>(x));
309+
})
310+
.<Identity<Maybe<Tuple2<Integer, IterateT<Identity<?>, Integer>>>>>runIterateT()
311+
.runIdentity()
312+
.fmap(Tuple2::_1));
313+
assertEquals(1, flatMapCost.get());
314+
assertEquals(1, unfoldCost.get());
315+
}
260316
}

0 commit comments

Comments
 (0)