Skip to content

Commit 82c40a4

Browse files
committed
MonadRec instance for LambdaIterable
1 parent e207252 commit 82c40a4

File tree

5 files changed

+199
-6
lines changed

5 files changed

+199
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/).
1616

1717
### Added
1818
- `MonadError`, monads that can be thrown to and caught from, with defaults for `IO`, `Either`, `Maybe`, and `Try`
19+
- `MonadRec`, monads that support a stack-safe `trampolineM` method with defaults for all exported monads
1920
- `Optic#andThen`, `Optic#compose`, and other defaults added
2021
- `Prism#andThen`, `Prism#compose` begets another `Prism`
2122
- `Prism#fromPartial` public interfaces
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package com.jnape.palatable.lambda.internal.iteration;
2+
3+
import com.jnape.palatable.lambda.functions.Fn0;
4+
import com.jnape.palatable.lambda.functions.Fn1;
5+
import com.jnape.palatable.lambda.functions.recursion.RecursiveResult;
6+
7+
import java.util.Iterator;
8+
import java.util.NoSuchElementException;
9+
10+
import static com.jnape.palatable.lambda.functions.builtin.fn1.Constantly.constantly;
11+
import static com.jnape.palatable.lambda.functions.builtin.fn1.Not.not;
12+
import static com.jnape.palatable.lambda.io.IO.io;
13+
14+
public final class TrampoliningIterator<A, B> implements Iterator<B> {
15+
private final Fn1<? super A, ? extends Iterable<RecursiveResult<A, B>>> fn;
16+
private final A a;
17+
18+
private ImmutableQueue<Iterator<RecursiveResult<A, B>>> remaining;
19+
private B b;
20+
21+
public TrampoliningIterator(Fn1<? super A, ? extends Iterable<RecursiveResult<A, B>>> fn, A a) {
22+
this.fn = fn;
23+
this.a = a;
24+
}
25+
26+
@Override
27+
public boolean hasNext() {
28+
queueNextIfPossible();
29+
return b != null;
30+
}
31+
32+
@Override
33+
public B next() {
34+
if (!hasNext())
35+
throw new NoSuchElementException();
36+
B next = b;
37+
b = null;
38+
return next;
39+
}
40+
41+
private void queueNextIfPossible() {
42+
if (remaining == null)
43+
pruneAfter(() -> remaining = ImmutableQueue.<Iterator<RecursiveResult<A, B>>>empty()
44+
.pushFront(fn.apply(a).iterator()));
45+
46+
while (b == null && remaining.head().match(constantly(false), constantly(true))) {
47+
tickNext();
48+
}
49+
}
50+
51+
private void tickNext() {
52+
pruneAfter(() -> remaining.head().orElseThrow(NoSuchElementException::new).next())
53+
.match(a -> io(() -> {
54+
pruneAfter(() -> remaining = remaining.pushFront(fn.apply(a).iterator()));
55+
}), b -> io(() -> {
56+
this.b = b;
57+
})).unsafePerformIO();
58+
}
59+
60+
private <R> R pruneAfter(Fn0<? extends R> fn) {
61+
R r = fn.apply();
62+
while (remaining.head().match(constantly(false), not(Iterator::hasNext))) {
63+
remaining = remaining.tail();
64+
}
65+
return r;
66+
}
67+
}

src/main/java/com/jnape/palatable/lambda/traversable/LambdaIterable.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
import com.jnape.palatable.lambda.functions.Fn1;
44
import com.jnape.palatable.lambda.functions.builtin.fn1.Empty;
55
import com.jnape.palatable.lambda.functions.builtin.fn3.FoldRight;
6+
import com.jnape.palatable.lambda.functions.recursion.RecursiveResult;
67
import com.jnape.palatable.lambda.functions.specialized.Pure;
78
import com.jnape.palatable.lambda.functor.Applicative;
89
import com.jnape.palatable.lambda.functor.builtin.Lazy;
10+
import com.jnape.palatable.lambda.internal.iteration.TrampoliningIterator;
911
import com.jnape.palatable.lambda.monad.Monad;
12+
import com.jnape.palatable.lambda.monad.MonadRec;
1013

1114
import java.util.Iterator;
1215
import java.util.Objects;
@@ -24,7 +27,10 @@
2427
* @param <A> the {@link Iterable} element type
2528
* @see LambdaMap
2629
*/
27-
public final class LambdaIterable<A> implements Monad<A, LambdaIterable<?>>, Traversable<A, LambdaIterable<?>> {
30+
public final class LambdaIterable<A> implements
31+
MonadRec<A, LambdaIterable<?>>,
32+
Traversable<A, LambdaIterable<?>> {
33+
2834
private final Iterable<A> as;
2935

3036
@SuppressWarnings("unchecked")
@@ -69,7 +75,7 @@ public <B> LambdaIterable<B> pure(B b) {
6975
*/
7076
@Override
7177
public <B> LambdaIterable<B> zip(Applicative<Fn1<? super A, ? extends B>, LambdaIterable<?>> appFn) {
72-
return Monad.super.zip(appFn).coerce();
78+
return MonadRec.super.zip(appFn).coerce();
7379
}
7480

7581
/**
@@ -80,23 +86,23 @@ public <B> Lazy<LambdaIterable<B>> lazyZip(
8086
Lazy<? extends Applicative<Fn1<? super A, ? extends B>, LambdaIterable<?>>> lazyAppFn) {
8187
return Empty.empty(as)
8288
? lazy(LambdaIterable.empty())
83-
: Monad.super.lazyZip(lazyAppFn).fmap(Monad<B, LambdaIterable<?>>::coerce);
89+
: MonadRec.super.lazyZip(lazyAppFn).fmap(Monad<B, LambdaIterable<?>>::coerce);
8490
}
8591

8692
/**
8793
* {@inheritDoc}
8894
*/
8995
@Override
9096
public <B> LambdaIterable<B> discardL(Applicative<B, LambdaIterable<?>> appB) {
91-
return Monad.super.discardL(appB).coerce();
97+
return MonadRec.super.discardL(appB).coerce();
9298
}
9399

94100
/**
95101
* {@inheritDoc}
96102
*/
97103
@Override
98104
public <B> LambdaIterable<A> discardR(Applicative<B, LambdaIterable<?>> appB) {
99-
return Monad.super.discardR(appB).coerce();
105+
return MonadRec.super.discardR(appB).coerce();
100106
}
101107

102108
/**
@@ -107,6 +113,19 @@ public <B> LambdaIterable<B> flatMap(Fn1<? super A, ? extends Monad<B, LambdaIte
107113
return wrap(flatten(map(a -> f.apply(a).<LambdaIterable<B>>coerce().unwrap(), as)));
108114
}
109115

116+
/**
117+
* {@inheritDoc}
118+
*/
119+
@Override
120+
public <B> LambdaIterable<B> trampolineM(
121+
Fn1<? super A, ? extends MonadRec<RecursiveResult<A, B>, LambdaIterable<?>>> fn) {
122+
return flatMap(a -> wrap(() -> new TrampoliningIterator<>(
123+
x -> fn.apply(x)
124+
.<LambdaIterable<RecursiveResult<A, B>>>coerce()
125+
.unwrap(),
126+
a)));
127+
}
128+
110129
/**
111130
* {@inheritDoc}
112131
*/
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package com.jnape.palatable.lambda.internal.iteration;
2+
3+
import org.junit.Test;
4+
5+
import static com.jnape.palatable.lambda.functions.builtin.fn1.Constantly.constantly;
6+
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.recurse;
7+
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.terminate;
8+
import static java.util.Arrays.asList;
9+
import static java.util.Collections.emptyList;
10+
import static java.util.Collections.singleton;
11+
import static org.junit.Assert.assertEquals;
12+
import static org.junit.Assert.assertFalse;
13+
import static org.junit.Assert.assertTrue;
14+
15+
public class TrampoliningIteratorTest {
16+
17+
@Test
18+
public void hasNextIfAnyTerminateInstructions() {
19+
TrampoliningIterator<Integer, Object> it = new TrampoliningIterator<>(x -> singleton(terminate(x + 1)), 0);
20+
assertTrue(it.hasNext());
21+
assertEquals(1, it.next());
22+
assertFalse(it.hasNext());
23+
}
24+
25+
@Test
26+
public void hasNextIfTerminateInterleavedBeforeRecurse() {
27+
TrampoliningIterator<Integer, Object> it = new TrampoliningIterator<>(
28+
x -> x < 3
29+
? asList(terminate(x), recurse(x + 1))
30+
: emptyList(),
31+
0);
32+
assertTrue(it.hasNext());
33+
assertEquals(0, it.next());
34+
assertTrue(it.hasNext());
35+
assertEquals(1, it.next());
36+
assertTrue(it.hasNext());
37+
assertEquals(2, it.next());
38+
assertFalse(it.hasNext());
39+
}
40+
41+
@Test
42+
public void hasNextIfTerminateInterleavedAfterRecurse() {
43+
TrampoliningIterator<Integer, Object> it = new TrampoliningIterator<>(
44+
x -> x < 3
45+
? asList(recurse(x + 1), terminate(x))
46+
: emptyList(),
47+
0);
48+
assertTrue(it.hasNext());
49+
assertEquals(2, it.next());
50+
assertTrue(it.hasNext());
51+
assertEquals(1, it.next());
52+
assertTrue(it.hasNext());
53+
assertEquals(0, it.next());
54+
assertFalse(it.hasNext());
55+
}
56+
57+
@Test
58+
public void doesNotHaveNextIfEmptyInitialResult() {
59+
TrampoliningIterator<Integer, Object> it = new TrampoliningIterator<>(constantly(emptyList()), 0);
60+
assertFalse(it.hasNext());
61+
}
62+
63+
@Test
64+
public void doesNotHaveNextIfNoTerminateInstruction() {
65+
TrampoliningIterator<Integer, Object> it = new TrampoliningIterator<>(
66+
x -> x < 3
67+
? singleton(recurse(x + 1))
68+
: emptyList(),
69+
0);
70+
assertFalse(it.hasNext());
71+
}
72+
}

src/test/java/com/jnape/palatable/lambda/traversable/LambdaIterableTest.java

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import testsupport.traits.ApplicativeLaws;
1111
import testsupport.traits.FunctorLaws;
1212
import testsupport.traits.MonadLaws;
13+
import testsupport.traits.MonadRecLaws;
1314
import testsupport.traits.TraversableLaws;
1415

1516
import static com.jnape.palatable.lambda.adt.Maybe.just;
@@ -19,13 +20,16 @@
1920
import static com.jnape.palatable.lambda.functions.builtin.fn1.Size.size;
2021
import static com.jnape.palatable.lambda.functions.builtin.fn2.Cons.cons;
2122
import static com.jnape.palatable.lambda.functions.builtin.fn2.Replicate.replicate;
23+
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.recurse;
24+
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.terminate;
2225
import static com.jnape.palatable.lambda.functor.builtin.Lazy.lazy;
2326
import static com.jnape.palatable.lambda.traversable.LambdaIterable.empty;
2427
import static com.jnape.palatable.lambda.traversable.LambdaIterable.pureLambdaIterable;
2528
import static com.jnape.palatable.lambda.traversable.LambdaIterable.wrap;
2629
import static com.jnape.palatable.traitor.framework.Subjects.subjects;
2730
import static java.util.Arrays.asList;
2831
import static java.util.Collections.singleton;
32+
import static java.util.Collections.singletonList;
2933
import static org.junit.Assert.assertEquals;
3034
import static org.junit.Assert.assertThat;
3135
import static testsupport.Constants.STACK_EXPLODING_NUMBER;
@@ -34,11 +38,41 @@
3438
@RunWith(Traits.class)
3539
public class LambdaIterableTest {
3640

37-
@TestTraits({FunctorLaws.class, ApplicativeLaws.class, TraversableLaws.class, MonadLaws.class})
41+
@TestTraits({FunctorLaws.class, ApplicativeLaws.class, TraversableLaws.class, MonadLaws.class, MonadRecLaws.class})
3842
public Subjects<LambdaIterable<Object>> testSubject() {
3943
return subjects(LambdaIterable.empty(), wrap(singleton(1)), wrap(replicate(100, 1)));
4044
}
4145

46+
@Test
47+
public void trampoliningWithDeferredResult() {
48+
assertThat(LambdaIterable.wrap(singletonList(0))
49+
.trampolineM(x -> wrap(x < STACK_EXPLODING_NUMBER
50+
? singleton(recurse(x + 1))
51+
: singleton(terminate(x))))
52+
.unwrap(),
53+
iterates(STACK_EXPLODING_NUMBER));
54+
}
55+
56+
@Test
57+
public void trampoliningOncePerElement() {
58+
assertThat(LambdaIterable.wrap(asList(1, 2, 3))
59+
.trampolineM(x -> wrap(x < STACK_EXPLODING_NUMBER
60+
? singleton(recurse(x + 1))
61+
: singleton(terminate(x))))
62+
.unwrap(),
63+
iterates(STACK_EXPLODING_NUMBER, STACK_EXPLODING_NUMBER, STACK_EXPLODING_NUMBER));
64+
}
65+
66+
@Test
67+
public void trampoliningWithIncrementalResults() {
68+
assertThat(LambdaIterable.wrap(singletonList(0))
69+
.trampolineM(x -> wrap(x < 10
70+
? asList(terminate(x), recurse(x + 1))
71+
: singleton(terminate(x))))
72+
.unwrap(),
73+
iterates(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
74+
}
75+
4276
@Test
4377
public void zipAppliesCartesianProductOfFunctionsAndValues() {
4478
LambdaIterable<Integer> xs = wrap(asList(1, 2, 3));

0 commit comments

Comments
 (0)