Skip to content

Commit 6abbd2a

Browse files
committed
Adding IterateT#foldCut for folding with early termination
1 parent a2505ba commit 6abbd2a

File tree

2 files changed

+36
-3
lines changed
  • src
    • main/java/com/jnape/palatable/lambda/monad/transformer/builtin
    • test/java/com/jnape/palatable/lambda/monad/transformer/builtin

2 files changed

+36
-3
lines changed

src/main/java/com/jnape/palatable/lambda/monad/transformer/builtin/IterateT.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import static com.jnape.palatable.lambda.adt.choice.Choice2.a;
2929
import static com.jnape.palatable.lambda.adt.choice.Choice2.b;
3030
import static com.jnape.palatable.lambda.adt.hlist.HList.tuple;
31-
import static com.jnape.palatable.lambda.functions.builtin.fn1.Constantly.constantly;
3231
import static com.jnape.palatable.lambda.functions.builtin.fn2.Into.into;
3332
import static com.jnape.palatable.lambda.functions.builtin.fn2.Tupler2.tupler;
3433
import static com.jnape.palatable.lambda.functions.builtin.fn3.FoldLeft.foldLeft;
@@ -138,11 +137,30 @@ public IterateT<M, A> concat(IterateT<M, A> other) {
138137
*/
139138
public <B, MB extends MonadRec<B, M>> MB fold(Fn2<? super B, ? super A, ? extends MonadRec<B, M>> fn,
140139
MonadRec<B, M> acc) {
140+
return foldCut((b, a) -> fn.apply(b, a).fmap(RecursiveResult::recurse), acc);
141+
}
142+
143+
/**
144+
* Monolithically fold the spine of this {@link IterateT} (with the possibility of early termination) by
145+
* {@link MonadRec#trampolineM(Fn1) trampolining} the underlying effects (for iterative folding, use
146+
* {@link IterateT#trampolineM(Fn1) trampolineM} directly).
147+
*
148+
* @param fn the folding function
149+
* @param acc the starting accumulation effect
150+
* @param <B> the accumulation type
151+
* @param <MB> the witnessed target result type
152+
* @return the folded effect result
153+
*/
154+
public <B, MB extends MonadRec<B, M>> MB foldCut(
155+
Fn2<? super B, ? super A, ? extends MonadRec<RecursiveResult<B, B>, M>> fn,
156+
MonadRec<B, M> acc) {
141157
return acc.fmap(tupler(this))
142158
.trampolineM(into((as, b) -> maybeT(as.runIterateT())
143-
.flatMap(into((a, aas) -> maybeT(fn.apply(b, a).fmap(tupler(aas)).fmap(Maybe::just))))
159+
.flatMap(into((a, aas) -> maybeT(fn.apply(b, a).fmap(Maybe::just)).fmap(tupler(aas))))
144160
.runMaybeT()
145-
.fmap(maybeRecur -> maybeRecur.match(constantly(terminate(b)), RecursiveResult::recurse))))
161+
.fmap(maybeR -> maybeR.match(
162+
__ -> terminate(b),
163+
into((rest, rr) -> rr.biMapL(tupler(rest)))))))
146164
.coerce();
147165
}
148166

src/test/java/com/jnape/palatable/lambda/monad/transformer/builtin/IterateTTest.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import static com.jnape.palatable.lambda.adt.hlist.HList.tuple;
3131
import static com.jnape.palatable.lambda.functions.builtin.fn2.LTE.lte;
3232
import static com.jnape.palatable.lambda.functions.builtin.fn3.Times.times;
33+
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.recurse;
34+
import static com.jnape.palatable.lambda.functions.recursion.RecursiveResult.terminate;
3335
import static com.jnape.palatable.lambda.functor.builtin.Identity.pureIdentity;
3436
import static com.jnape.palatable.lambda.functor.builtin.Lazy.lazy;
3537
import static com.jnape.palatable.lambda.functor.builtin.Writer.listen;
@@ -41,6 +43,7 @@
4143
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.singleton;
4244
import static com.jnape.palatable.lambda.monad.transformer.builtin.IterateT.unfold;
4345
import static com.jnape.palatable.lambda.monoid.builtin.AddAll.addAll;
46+
import static com.jnape.palatable.lambda.monoid.builtin.Join.join;
4447
import static com.jnape.palatable.traitor.framework.Subjects.subjects;
4548
import static java.util.Arrays.asList;
4649
import static java.util.Collections.emptyList;
@@ -152,6 +155,18 @@ public void fold() {
152155
.runWriter(addAll(ArrayList::new)));
153156
}
154157

158+
@Test
159+
public void foldCut() {
160+
assertEquals(tuple(3, "012"),
161+
IterateT.of(writer(tuple(1, "1")),
162+
writer(tuple(2, "2")),
163+
writer(tuple(3, "3")))
164+
.<Integer, Writer<String, Integer>>foldCut(
165+
(x, y) -> listen(y == 2 ? terminate(x + y) : recurse(x + y)),
166+
writer(tuple(0, "0")))
167+
.runWriter(join()));
168+
}
169+
155170
@Test
156171
public void zipUsesCartesianProduct() {
157172
assertThat(IterateT.of(new Identity<>(1), new Identity<>(2), new Identity<>(3))

0 commit comments

Comments
 (0)