Skip to content

Commit c568c22

Browse files
committed
StateT retains a Pure m instance
1 parent 34bcc6d commit c568c22

File tree

4 files changed

+63
-46
lines changed

4 files changed

+63
-46
lines changed

src/main/java/com/jnape/palatable/lambda/functor/builtin/State.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static com.jnape.palatable.lambda.functions.builtin.fn2.Both.both;
2121
import static com.jnape.palatable.lambda.functions.builtin.fn2.Into.into;
2222
import static com.jnape.palatable.lambda.functions.recursion.Trampoline.trampoline;
23+
import static com.jnape.palatable.lambda.functor.builtin.Identity.pureIdentity;
2324
import static com.jnape.palatable.lambda.monad.transformer.builtin.StateT.stateT;
2425

2526
/**
@@ -32,7 +33,7 @@
3233
* @param <A> the result type
3334
*/
3435
public final class State<S, A> implements
35-
MonadRec<A, State<S,?>>,
36+
MonadRec<A, State<S, ?>>,
3637
MonadReader<S, A, State<S, ?>>,
3738
MonadWriter<S, A, State<S, ?>> {
3839

@@ -252,7 +253,7 @@ public static <S, A> State<S, A> state(A a) {
252253
* @return the new {@link State} instance
253254
*/
254255
public static <S, A> State<S, A> state(Fn1<? super S, ? extends Tuple2<A, S>> stateFn) {
255-
return new State<>(stateT(s -> new Identity<>(stateFn.apply(s))));
256+
return new State<>(stateT(s -> new Identity<>(stateFn.apply(s)), pureIdentity()));
256257
}
257258

258259
/**
@@ -264,7 +265,7 @@ public static <S, A> State<S, A> state(Fn1<? super S, ? extends Tuple2<A, S>> st
264265
public static <S> Pure<State<S, ?>> pureState() {
265266
return new Pure<State<S, ?>>() {
266267
@Override
267-
public <A> State<S, A> checkedApply(A a) throws Throwable {
268+
public <A> State<S, A> checkedApply(A a) {
268269
return state(s -> tuple(a, s));
269270
}
270271
};

src/main/java/com/jnape/palatable/lambda/monad/transformer/builtin/StateT.java

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@ public final class StateT<S, M extends MonadRec<?, M>, A> implements
3535
MonadReader<S, A, StateT<S, M, ?>>,
3636
MonadWriter<S, A, StateT<S, M, ?>> {
3737

38+
private final Pure<M> pureM;
3839
private final Fn1<? super S, ? extends MonadRec<Tuple2<A, S>, M>> stateFn;
3940

40-
private StateT(Fn1<? super S, ? extends MonadRec<Tuple2<A, S>, M>> stateFn) {
41+
private StateT(Pure<M> pureM, Fn1<? super S, ? extends MonadRec<Tuple2<A, S>, M>> stateFn) {
42+
this.pureM = pureM;
4143
this.stateFn = stateFn;
4244
}
4345

@@ -78,14 +80,15 @@ public <MS extends Monad<S, M>> MS execT(S s) {
7880
/**
7981
* Map both the result and the final state to a new result and final state inside the {@link Monad}.
8082
*
81-
* @param fn the mapping function
82-
* @param <N> the new {@link Monad monadic embedding} for this {@link StateT}
83-
* @param <B> the new state type
83+
* @param fn the mapping function
84+
* @param pureN the new embedded {@link MonadRec monad's} {@link Pure} instance
85+
* @param <N> the new {@link Monad monadic embedding} for this {@link StateT}
86+
* @param <B> the new state type
8487
* @return the mapped {@link StateT}
8588
*/
8689
public <N extends MonadRec<?, N>, B> StateT<S, N, B> mapStateT(
87-
Fn1<? super MonadRec<Tuple2<A, S>, M>, ? extends MonadRec<Tuple2<B, S>, N>> fn) {
88-
return stateT(s -> fn.apply(runStateT(s)));
90+
Fn1<? super MonadRec<Tuple2<A, S>, M>, ? extends MonadRec<Tuple2<B, S>, N>> fn, Pure<N> pureN) {
91+
return stateT(s -> fn.apply(runStateT(s)), pureN);
8992
}
9093

9194
/**
@@ -96,15 +99,15 @@ public <N extends MonadRec<?, N>, B> StateT<S, N, B> mapStateT(
9699
* @return the mapped {@link StateT}
97100
*/
98101
public StateT<S, M, A> withStateT(Fn1<? super S, ? extends MonadRec<S, M>> fn) {
99-
return modify(fn).flatMap(constantly(this));
102+
return modify(fn, pureM).flatMap(constantly(this));
100103
}
101104

102105
/**
103106
* {@inheritDoc}
104107
*/
105108
@Override
106109
public <B> StateT<S, M, Tuple2<A, B>> listens(Fn1<? super S, ? extends B> fn) {
107-
return mapStateT(mas -> mas.fmap(t -> t.into((a, s) -> tuple(tuple(a, fn.apply(s)), s))));
110+
return mapStateT(mas -> mas.fmap(t -> t.into((a, s) -> tuple(tuple(a, fn.apply(s)), s))), pureM);
108111
}
109112

110113
/**
@@ -120,23 +123,24 @@ public StateT<S, M, A> censor(Fn1<? super S, ? extends S> fn) {
120123
*/
121124
@Override
122125
public StateT<S, M, A> local(Fn1<? super S, ? extends S> fn) {
123-
return stateT(s -> runStateT(fn.apply(s)));
126+
return stateT(s -> runStateT(fn.apply(s)), pureM);
124127
}
125128

126129
/**
127130
* {@inheritDoc}
128131
*/
129132
@Override
130133
public <B> StateT<S, M, B> flatMap(Fn1<? super A, ? extends Monad<B, StateT<S, M, ?>>> f) {
131-
return stateT(s -> runStateT(s).flatMap(into((a, s_) -> f.apply(a).<StateT<S, M, B>>coerce().runStateT(s_))));
134+
return stateT(s -> runStateT(s).flatMap(into((a, s_) -> f.apply(a).<StateT<S, M, B>>coerce().runStateT(s_))),
135+
pureM);
132136
}
133137

134138
/**
135139
* {@inheritDoc}
136140
*/
137141
@Override
138142
public <B> StateT<S, M, B> pure(B b) {
139-
return stateT(s -> runStateT(s).pure(tuple(b, s)));
143+
return stateT(s -> pureM.apply(tuple(b, s)), pureM);
140144
}
141145

142146
/**
@@ -185,7 +189,7 @@ public <B> StateT<S, M, A> discardR(Applicative<B, StateT<S, M, ?>> appB) {
185189
*/
186190
@Override
187191
public <B, N extends MonadRec<?, N>> StateT<S, N, B> lift(MonadRec<B, N> mb) {
188-
return stateT(s -> mb.fmap(b -> tuple(b, s)));
192+
return stateT(s -> mb.fmap(b -> tuple(b, s)), Pure.of(mb));
189193
}
190194

191195
/**
@@ -194,11 +198,12 @@ public <B, N extends MonadRec<?, N>> StateT<S, N, B> lift(MonadRec<B, N> mb) {
194198
@Override
195199
public <B> StateT<S, M, B> trampolineM(
196200
Fn1<? super A, ? extends MonadRec<RecursiveResult<A, B>, StateT<S, M, ?>>> fn) {
197-
return StateT.<S, M, B>stateT((Fn1.<S, MonadRec<Tuple2<A, S>, M>>fn1(this::runStateT))
198-
.fmap(m -> m.trampolineM(into((a, s) -> fn.apply(a)
199-
.<StateT<S, M, RecursiveResult<A, B>>>coerce().runStateT(s)
200-
.fmap(into((aOrB, s_) -> aOrB.biMap(a_ -> tuple(a_, s_),
201-
b -> tuple(b, s_))))))));
201+
return stateT((Fn1.<S, MonadRec<Tuple2<A, S>, M>>fn1(this::runStateT))
202+
.fmap(m -> m.trampolineM(into((a, s) -> fn.apply(a)
203+
.<StateT<S, M, RecursiveResult<A, B>>>coerce().runStateT(s)
204+
.fmap(into((aOrB, s_) -> aOrB.biMap(a_ -> tuple(a_, s_),
205+
b -> tuple(b, s_))))))),
206+
pureM);
202207
}
203208

204209
/**
@@ -212,34 +217,37 @@ public <B> StateT<S, M, B> trampolineM(
212217
*/
213218
@SuppressWarnings("RedundantTypeArguments")
214219
public static <A, M extends MonadRec<?, M>> StateT<A, M, A> get(Pure<M> pureM) {
215-
return gets(pureM::<A, MonadRec<A, M>>apply);
220+
return gets(pureM::<A, MonadRec<A, M>>apply, pureM);
216221
}
217222

218223
/**
219224
* Given a function that produces a value inside a {@link Monad monadic effect} from a state, produce a
220225
* {@link StateT} that simply passes its state to the function and applies it.
221226
*
222-
* @param fn the function
223-
* @param <S> the state type
224-
* @param <M> the{@link Monad} embedding
225-
* @param <A> the value type
227+
* @param fn the function
228+
* @param pureM the embedded {@link MonadRec monad's} {@link Pure} instance
229+
* @param <S> the state type
230+
* @param <M> the{@link Monad} embedding
231+
* @param <A> the value type
226232
* @return the {@link StateT}
227233
*/
228-
public static <S, M extends MonadRec<?, M>, A> StateT<S, M, A> gets(Fn1<? super S, ? extends MonadRec<A, M>> fn) {
229-
return stateT(s -> fn.apply(s).fmap(a -> tuple(a, s)));
234+
public static <S, M extends MonadRec<?, M>, A> StateT<S, M, A> gets(Fn1<? super S, ? extends MonadRec<A, M>> fn,
235+
Pure<M> pureM) {
236+
return stateT(s -> fn.apply(s).fmap(a -> tuple(a, s)), pureM);
230237
}
231238

232239
/**
233240
* Lift a function that makes a stateful modification inside an {@link Monad} into {@link StateT}.
234241
*
235242
* @param updateFn the update function
243+
* @param pureM the embedded {@link MonadRec monad's} {@link Pure} instance
236244
* @param <S> the state type
237245
* @param <M> the {@link Monad} embedding
238246
* @return the {@link StateT}
239247
*/
240248
public static <S, M extends MonadRec<?, M>> StateT<S, M, Unit> modify(
241-
Fn1<? super S, ? extends MonadRec<S, M>> updateFn) {
242-
return stateT(s -> updateFn.apply(s).fmap(tupler(UNIT)));
249+
Fn1<? super S, ? extends MonadRec<S, M>> updateFn, Pure<M> pureM) {
250+
return stateT(s -> updateFn.apply(s).fmap(tupler(UNIT)), pureM);
243251
}
244252

245253
/**
@@ -251,7 +259,7 @@ public static <S, M extends MonadRec<?, M>> StateT<S, M, Unit> modify(
251259
* @return the {@link StateT}
252260
*/
253261
public static <S, M extends MonadRec<?, M>> StateT<S, M, Unit> put(MonadRec<S, M> ms) {
254-
return modify(constantly(ms));
262+
return modify(constantly(ms), Pure.of(ms));
255263
}
256264

257265
/**
@@ -264,21 +272,22 @@ public static <S, M extends MonadRec<?, M>> StateT<S, M, Unit> put(MonadRec<S, M
264272
* @return the {@link StateT}
265273
*/
266274
public static <S, M extends MonadRec<?, M>, A> StateT<S, M, A> stateT(MonadRec<A, M> ma) {
267-
return gets(constantly(ma));
275+
return gets(constantly(ma), Pure.of(ma));
268276
}
269277

270278
/**
271279
* Lift a state-sensitive {@link Monad monadically embedded} computation into {@link StateT}.
272280
*
273281
* @param stateFn the stateful operation
282+
* @param pureM the embedded {@link MonadRec monad's} {@link Pure} instance
274283
* @param <S> the state type
275284
* @param <M> the {@link Monad} embedding
276285
* @param <A> the result type
277286
* @return the {@link StateT}
278287
*/
279288
public static <S, M extends MonadRec<?, M>, A> StateT<S, M, A> stateT(
280-
Fn1<? super S, ? extends MonadRec<Tuple2<A, S>, M>> stateFn) {
281-
return new StateT<>(stateFn);
289+
Fn1<? super S, ? extends MonadRec<Tuple2<A, S>, M>> stateFn, Pure<M> pureM) {
290+
return new StateT<>(pureM, stateFn);
282291
}
283292

284293
/**
@@ -292,7 +301,7 @@ public static <S, M extends MonadRec<?, M>, A> StateT<S, M, A> stateT(
292301
public static <S, M extends MonadRec<?, M>> Pure<StateT<S, M, ?>> pureStateT(Pure<M> pureM) {
293302
return new Pure<StateT<S, M, ?>>() {
294303
@Override
295-
public <A> StateT<S, M, A> checkedApply(A a) throws Throwable {
304+
public <A> StateT<S, M, A> checkedApply(A a) {
296305
return stateT(pureM.<A, MonadRec<A, M>>apply(a));
297306
}
298307
};

src/test/java/com/jnape/palatable/lambda/matchers/StateTMatcherTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import static com.jnape.palatable.lambda.adt.Either.right;
1010
import static com.jnape.palatable.lambda.adt.hlist.HList.tuple;
1111
import static com.jnape.palatable.lambda.io.IO.io;
12+
import static com.jnape.palatable.lambda.io.IO.pureIO;
1213
import static com.jnape.palatable.lambda.monad.transformer.builtin.StateT.gets;
1314
import static com.jnape.palatable.lambda.monad.transformer.builtin.StateT.stateT;
1415
import static org.hamcrest.CoreMatchers.not;
@@ -75,7 +76,7 @@ public void whenRunWithUsingOneTupleMatcherOnObject() {
7576
public void onlyRunsStateOnceWithTupleMatcher() {
7677
AtomicInteger count = new AtomicInteger(0);
7778

78-
assertThat(gets(s -> io(count::incrementAndGet)), whenRunWith(0, yieldsValue(equalTo(tuple(1, 0)))));
79+
assertThat(gets(s -> io(count::incrementAndGet), pureIO()), whenRunWith(0, yieldsValue(equalTo(tuple(1, 0)))));
7980
assertEquals(1, count.get());
8081
}
8182
}

src/test/java/com/jnape/palatable/lambda/monad/transformer/builtin/StateTTest.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.List;
1919

2020
import static com.jnape.palatable.lambda.adt.Maybe.just;
21+
import static com.jnape.palatable.lambda.adt.Maybe.pureMaybe;
2122
import static com.jnape.palatable.lambda.adt.Unit.UNIT;
2223
import static com.jnape.palatable.lambda.adt.hlist.HList.tuple;
2324
import static com.jnape.palatable.lambda.functor.builtin.Identity.pureIdentity;
@@ -40,38 +41,43 @@ public class StateTTest {
4041
MonadReaderLaws.class,
4142
MonadWriterLaws.class})
4243
public Equivalence<StateT<String, Identity<?>, Integer>> testReader() {
43-
return equivalence(StateT.gets(s -> new Identity<>(s.length())), s -> s.runStateT("foo"));
44+
return equivalence(StateT.gets(s -> new Identity<>(s.length()), pureIdentity()), s -> s.runStateT("foo"));
4445
}
4546

4647
@Test
4748
public void evalAndExec() {
4849
StateT<String, Identity<?>, Integer> stateT =
49-
StateT.stateT(str -> new Identity<>(tuple(str.length(), str + "_")));
50+
StateT.stateT(str -> new Identity<>(tuple(str.length(), str + "_")), pureIdentity());
5051

5152
assertThat(stateT, whenExecuted("_", new Identity<>("__")));
5253
assertThat(stateT, whenEvaluated("_", new Identity<>(1)));
5354
}
5455

5556
@Test
5657
public void mapStateT() {
57-
assertThat(StateT.<String, Identity<?>, Integer>stateT(str -> new Identity<>(tuple(str.length(), str + "_")))
58+
assertThat(StateT.<String, Identity<?>, Integer>stateT(str -> new Identity<>(tuple(str.length(), str + "_")),
59+
pureIdentity())
5860
.mapStateT(id -> id.<Identity<Tuple2<Integer, String>>>coerce()
59-
.runIdentity()
60-
.into((x, str) -> just(tuple(x + 1, str.toUpperCase())))),
61+
.runIdentity()
62+
.into((x, str) -> just(tuple(x + 1, str.toUpperCase()))),
63+
pureMaybe()),
6164
whenRun("abc", just(tuple(4, "ABC_"))));
6265
}
6366

6467
@Test
6568
public void zipping() {
6669
assertThat(
67-
StateT.<List<String>, Identity<?>>modify(s -> new Identity<>(set(elementAt(s.size()), just("one"), s)))
68-
.discardL(StateT.modify(s -> new Identity<>(set(elementAt(s.size()), just("two"), s)))),
70+
StateT.<List<String>, Identity<?>>modify(s -> new Identity<>(set(elementAt(s.size()), just("one"), s)),
71+
pureIdentity())
72+
.discardL(StateT.modify(s -> new Identity<>(set(elementAt(s.size()), just("two"), s)),
73+
pureIdentity())),
6974
whenRun(new ArrayList<>(), new Identity<>(tuple(UNIT, asList("one", "two")))));
7075
}
7176

7277
@Test
7378
public void withStateT() {
74-
assertThat(StateT.<String, Identity<?>, Integer>stateT(str -> new Identity<>(tuple(str.length(), str + "_")))
79+
assertThat(StateT.<String, Identity<?>, Integer>stateT(str -> new Identity<>(tuple(str.length(), str + "_")),
80+
pureIdentity())
7581
.withStateT(str -> new Identity<>(str.toUpperCase())),
7682
whenRun("abc", new Identity<>(tuple(3, "ABC_"))));
7783
}
@@ -84,7 +90,7 @@ public void get() {
8490

8591
@Test
8692
public void gets() {
87-
assertThat(StateT.gets(s -> new Identity<>(s.length())),
93+
assertThat(StateT.gets(s -> new Identity<>(s.length()), pureIdentity()),
8894
whenRun("state", new Identity<>(tuple(5, "state"))));
8995
}
9096

@@ -96,15 +102,15 @@ public void put() {
96102

97103
@Test
98104
public void modify() {
99-
assertThat(StateT.modify(x -> new Identity<>(x + 1)),
105+
assertThat(StateT.modify(x -> new Identity<>(x + 1), pureIdentity()),
100106
whenRun(0, new Identity<>(tuple(UNIT, 1))));
101107
}
102108

103109
@Test
104110
public void stateT() {
105111
assertThat(StateT.stateT(new Identity<>(0)),
106112
whenRun("_", new Identity<>(tuple(0, "_"))));
107-
assertThat(StateT.stateT(s -> new Identity<>(tuple(s.length(), s + "1"))),
113+
assertThat(StateT.stateT(s -> new Identity<>(tuple(s.length(), s + "1")), pureIdentity()),
108114
whenRun("_", new Identity<>(tuple(1, "_1"))));
109115
}
110116

0 commit comments

Comments
 (0)