diff --git a/build.sbt b/build.sbt index ae0916a9..f91bac7a 100644 --- a/build.sbt +++ b/build.sbt @@ -4,7 +4,7 @@ organization := "org.typesafe.async" name := "scala-async" -version := "1.0.0-SNAPSHOT" +version := "1.0.0-M2" libraryDependencies <++= (scalaVersion) { sv => Seq( diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index afcf6bdf..bf2d7b2c 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -116,44 +116,42 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private object inline { def transformToList(tree: Tree): List[Tree] = trace("inline", tree) { + def branchWithAssign(orig: Tree, varDef: ValDef) = orig match { + case Block(stats, expr) => Block(stats, Assign(Ident(varDef.name), expr)) + case _ => Assign(Ident(varDef.name), orig) + } + + def casesWithAssign(cases: List[CaseDef], varDef: ValDef) = cases map { + case cd @ CaseDef(pat, guard, orig) => + attachCopy(cd)(CaseDef(pat, guard, branchWithAssign(orig, varDef))) + } + val stats :+ expr = anf.transformToList(tree) expr match { + // if type of if-else/try/match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + case If(_, _, _) | Try(_, _, _) | Match(_, _) if expr.tpe =:= definitions.UnitTpe => + stats :+ expr :+ Literal(Constant(())) + case Apply(fun, args) if isAwait(fun) => val valDef = defineVal(name.await, expr, tree.pos) stats :+ valDef :+ Ident(valDef.name) case If(cond, thenp, elsep) => - // if type of if-else is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } else { - val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) - def branchWithAssign(orig: Tree) = orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) - case _ => Assign(Ident(varDef.name), orig) - } - val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep)) - stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) - } + val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) + val ifWithAssign = If(cond, branchWithAssign(thenp, varDef), branchWithAssign(elsep, varDef)) + stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) + + case Try(body, catches, finalizer) => + val varDef = defineVar(name.tryRes, expr.tpe, tree.pos) + val tryWithAssign = Try(branchWithAssign(body, varDef), casesWithAssign(catches, varDef), finalizer) + stats :+ varDef :+ tryWithAssign :+ Ident(varDef.name) case Match(scrut, cases) => - // if type of match is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - stats :+ expr :+ Literal(Constant(())) - } - else { - val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) - val casesWithAssign = cases map { - case cd@CaseDef(pat, guard, Block(caseStats, caseExpr)) => - attachCopy(cd)(CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))) - case cd@CaseDef(pat, guard, body) => - attachCopy(cd)(CaseDef(pat, guard, Assign(Ident(varDef.name), body))) - } - val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign)) - stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) - } + val varDef = defineVar(name.matchRes, expr.tpe, tree.pos) + val matchWithAssign = attachCopy(tree)(Match(scrut, casesWithAssign(cases, varDef))) + stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) + case _ => stats :+ expr } @@ -220,6 +218,11 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val stats :+ expr = inline.transformToList(rhs) stats :+ attachCopy(tree)(Assign(lhs, expr)) + case Try(body, catches, finalizer) if containsAwait => + val stats :+ expr = inline.transformToList(body) + val tryType = c.typeCheck(Try(Block(stats, expr), catches, finalizer)).tpe + List(attachCopy(tree)(Try(Block(stats, expr), catches, finalizer)).setType(tryType)) + case If(cond, thenp, elsep) if containsAwait => val condStats :+ condExpr = inline.transformToList(cond) val thenBlock = inline.transformToBlock(thenp) diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 12fe428a..272cc481 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -6,7 +6,6 @@ package scala.async import scala.language.experimental.macros import scala.reflect.macros.Context -import scala.util.continuations.{cpsParam, reset} object Async extends AsyncBase { @@ -61,9 +60,7 @@ abstract class AsyncBase { @deprecated("`await` must be enclosed in an `async` block", "0.1") def await[T](awaitable: futureSystem.Fut[T]): T = ??? - def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = ??? - - def fallbackEnabled = false + protected[async] def fallbackEnabled = false def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ @@ -72,7 +69,8 @@ abstract class AsyncBase { val utils = TransformUtils[c.type](c) import utils.{name, defn} - if (!analyzer.reportUnsupportedAwaits(body.tree) || !fallbackEnabled) { + analyzer.reportUnsupportedAwaits(body.tree) + // Transform to A-normal form: // - no await calls in qualifiers or arguments, // - if/match only used in statement position. @@ -122,6 +120,13 @@ abstract class AsyncBase { val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) + + // the stack of currently active exception handlers + val handlers = ValDef(Modifiers(Flag.MUTABLE), name.handlers, TypeTree(typeOf[List[PartialFunction[Throwable, Unit]]]), (reify { List() }).tree) + + // the exception that is currently in-flight or `null` otherwise + val exception = ValDef(Modifiers(Flag.MUTABLE), name.exception, TypeTree(typeOf[Throwable]), Literal(Constant(null))) + val applyDefDef: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) val applyBody = asyncBlock.onCompleteHandler @@ -134,7 +139,7 @@ abstract class AsyncBase { val applyBody = asyncBlock.onCompleteHandler DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) } - List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) + List(utils.emptyConstructor, stateVar, result, execContext, handlers, exception) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) } val template = { Template(List(stateMachineType), emptyValDef, body) @@ -162,35 +167,6 @@ abstract class AsyncBase { AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") code - } else { - // replace `await` invocations with `awaitFallback` invocations - val awaitReplacer = new Transformer { - override def transform(tree: Tree): Tree = tree match { - case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == defn.Async_await => - val typeApp = treeCopy.TypeApply(fun, Ident(defn.Async_awaitFallback), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe))) - treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(name.result)) - case _ => - super.transform(tree) - } - } - val newBody = awaitReplacer.transform(body.tree) - - val resetBody = reify { - reset { c.Expr(c.resetAllAttrs(newBody.duplicate)).splice } - } - - val futureSystemOps = futureSystem.mkOps(c) - val code = { - val tree = Block(List( - ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree), - futureSystemOps.spawn(resetBody.tree) - ), futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](Ident(name.result))).tree) - c.Expr[futureSystem.Fut[T]](tree) - } - - AsyncUtils.vprintln(s"async CPS fallback transform expands to:\n ${code.tree}") - code - } } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 9184960b..7c667c32 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -76,16 +76,16 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: Asy } override def traverse(tree: Tree) { - def containsAwait = tree exists isAwait + def containsAwait(t: Tree) = t exists isAwait tree match { - case Try(_, _, _) if containsAwait => - reportUnsupportedAwait(tree, "try/catch") + case Try(_, catches, _) if catches exists containsAwait => + reportUnsupportedAwait(tree, "catch") super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => + case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => c.abort(tree.pos, "lazy vals are illegal within an async block") - case _ => + case _ => super.traverse(tree) } } diff --git a/src/main/scala/scala/async/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/AsyncWithCPSFallback.scala deleted file mode 100644 index 39e43a3c..00000000 --- a/src/main/scala/scala/async/AsyncWithCPSFallback.scala +++ /dev/null @@ -1,31 +0,0 @@ -package scala.async - -import scala.language.experimental.macros - -import scala.reflect.macros.Context -import scala.util.continuations._ - -object AsyncWithCPSFallback extends AsyncBase { - - import scala.concurrent.{Future, ExecutionContext} - import ExecutionContext.Implicits.global - - lazy val futureSystem = ScalaConcurrentFutureSystem - type FS = ScalaConcurrentFutureSystem.type - - /* Fall-back for `await` when it is called at an unsupported position. - */ - override def awaitFallback[T, U](awaitable: futureSystem.Fut[T], p: futureSystem.Prom[U]): T @cpsParam[U, Unit] = - shift { - (k: (T => U)) => - awaitable onComplete { - case tr => p.success(k(tr.get)) - } - } - - override def fallbackEnabled = true - - def async[T](body: T) = macro asyncImpl[T] - - override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) -} diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 180e7b91..22337d4a 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -38,11 +38,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } /** A sequence of statements the concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int) + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, excState: Option[Int]) extends AsyncState { def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) + mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply, excState) override val toString: String = s"AsyncState #$state, next = $nextState" @@ -51,9 +51,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState { + final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int, excState: Option[Int]) extends AsyncState { override def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats) + mkHandlerCase(state, stats, excState) override val toString: String = s"AsyncStateWithoutAwait #$state" @@ -63,13 +63,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * handler will unconditionally transition to `nestState`.`` */ final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int, - awaitable: Awaitable) + awaitable: Awaitable, excState: Option[Int]) extends AsyncState { override def mkHandlerCaseForState: CaseDef = { val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree - mkHandlerCase(state, stats :+ callOnComplete) + mkHandlerCase(state, stats :+ callOnComplete, excState) } override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = { @@ -97,7 +97,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: Block(List(tryGetTree, mkStateTree(nextState), mkResumeApply): _*) ) - Some(mkHandlerCase(state, List(ifIsFailureTree))) + Some(mkHandlerCase(state, List(ifIsFailureTree), excState)) } override val toString: String = @@ -106,8 +106,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /* * Builder for a single state of an async method. + * + * The `excState` parameter is implicit, so that it is passed implicitly + * when `AsyncBlockBuilder` creates new `AsyncStateBuilder`s. */ - final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { + final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name])(implicit excState: Option[Int]) { /* Statements preceding an await call. */ private val stats = ListBuffer[c.Tree]() /** The state of the target of a LabelDef application (while loop jump) */ @@ -134,12 +137,17 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: nextState: Int): AsyncState = { val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr)) val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable, excState) + } + + def resultWithoutAwait(): AsyncState = { + this += mkResumeApply + new AsyncStateWithoutAwait(stats.toList, state, excState) } def resultSimple(nextState: Int): AsyncState = { val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, excState) } def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { @@ -148,7 +156,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val cond = renameReset(condTree) def mkBranch(state: Int) = Block(mkStateTree(state), mkResumeApply) this += If(cond, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, excState) } /** @@ -173,12 +181,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } // 2. insert changed match tree at the end of the current state this += Match(renameReset(scrutTree), newCases) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, excState) } def resultWithLabel(startLabelState: Int): AsyncState = { this += Block(mkStateTree(startLabelState), mkResumeApply) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, excState) } override def toString: String = { @@ -190,14 +198,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). * - * @param stats a list of expressions - * @param expr the last expression of the block - * @param startState the start state - * @param endState the state to continue with - * @param toRename a `Map` for renaming the given key symbols to the mangled value names + * @param stats a list of expressions + * @param expr the last expression of the block + * @param startState the start state + * @param endState the state to continue with + * @param toRename a `Map` for renaming the given key symbols to the mangled value names + * @param excState the state to continue with in case of an exception + * @param parentExcState the state to continue with in case of an exception not handled by the current exception handler */ - final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - private val toRename: Map[Symbol, c.Name]) { + final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, private val toRename: Map[Symbol, c.Name], + parentExcState: Option[Int] = None)(implicit excState: Option[Int]) { val asyncStates = ListBuffer[AsyncState]() var stateBuilder = new AsyncStateBuilder(startState, toRename) @@ -209,9 +219,10 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: case _ => false }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException - def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { + def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int, + excState: Option[Int] = None, parentExcState: Option[Int] = None) = { val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename, parentExcState)(excState) } import stateAssigner.nextState @@ -250,6 +261,60 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: currState = afterIfState stateBuilder = new AsyncStateBuilder(currState, toRename) + case Try(block, catches, finalizer) if stat exists isAwait => + val tryStartState = nextState() + val afterTryState = nextState() + val ehState = nextState() + val finalizerState = if (!finalizer.isEmpty) Some(nextState()) else None + + // complete current state so that it continues with tryStartState + asyncStates += stateBuilder.resultWithLabel(tryStartState) + + if (!finalizer.isEmpty) { + val builder = nestedBlockBuilder(finalizer, finalizerState.get, afterTryState) + asyncStates ++= builder.asyncStates + } + + // create handler state + def handlersDot(m: String) = Select(Ident(name.handlers), m) + val exceptionExpr = c.Expr[Throwable](Ident(name.exception)) + // handler state does not have active exception handler --> None + val handlerStateBuilder = new AsyncStateBuilder(ehState, toRename)(None) + + val parentExpr: c.Expr[Unit] = + if (parentExcState.isEmpty) reify { throw exceptionExpr.splice } + else c.Expr[Unit](mkStateTree(parentExcState.get)) + + val handlerExpr = reify { + val h = c.Expr[PartialFunction[Throwable, Unit]](handlersDot("head")).splice + c.Expr[Unit](Assign(Ident(name.handlers), handlersDot("tail"))).splice + + if (h isDefinedAt exceptionExpr.splice) { + h(exceptionExpr.splice) + c.Expr[Unit](mkStateTree(if (!finalizer.isEmpty) finalizerState.get else afterTryState)).splice + } else { + parentExpr.splice + } + } + + handlerStateBuilder += handlerExpr.tree + asyncStates += handlerStateBuilder.resultWithoutAwait() + + val ehName = newTermName("handlerPF$" + ehState) + val partFunAssign = ValDef(Modifiers(), ehName, TypeTree(typeOf[PartialFunction[Throwable, Unit]]), Match(EmptyTree, catches)) + val newHandler = c.Expr[PartialFunction[Throwable, Unit]](Ident(ehName)) + val handlersIdent = c.Expr[List[PartialFunction[Throwable, Unit]]](Ident(name.handlers)) + val pushedHandlers = reify { handlersIdent.splice.+:(newHandler.splice) } + val pushAssign = Assign(Ident(name.handlers), pushedHandlers.tree) + + val (tryStats, tryExpr) = statsAndExpr(block) + val builder = nestedBlockBuilder(Block(partFunAssign :: pushAssign :: tryStats, tryExpr), + tryStartState, if (!finalizer.isEmpty) finalizerState.get else afterTryState, Some(ehState), excState) + asyncStates ++= builder.asyncStates + + currState = afterTryState + stateBuilder = new AsyncStateBuilder(currState, toRename) + case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) @@ -302,7 +367,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val startState = stateAssigner.nextState() val endState = Int.MaxValue - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename) + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename)(None) new AsyncBlock { def asyncStates = blockBuilder.asyncStates.toList @@ -313,7 +378,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val lastStateBody = c.Expr[T](lastState.body) val rhs = futureSystemOps.completeProm( c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) - mkHandlerCase(lastState.state, rhs.tree) + mkHandlerCase(lastState.state, rhs.tree, None) } asyncStates.toList match { case s :: Nil => @@ -386,9 +451,35 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: private def mkStateTree(nextState: Int): c.Tree = Assign(Ident(name.state), c.literal(nextState).tree) - private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = - mkHandlerCase(num, Block(rhs: _*)) + private def mkHandlerCase(num: Int, rhs: List[c.Tree], excState: Option[Int]): CaseDef = + mkHandlerCase(num, Block(rhs: _*), excState) - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + /* Generates `case` clause with wrapping try-catch: + * + * case `num` => + * try { + * rhs + * } catch { + * case NonFatal(t) => + * exception$async = t + * state$async = excState.get + * resume$async() + * } + */ + private def mkHandlerCase(num: Int, rhs: c.Tree, excState: Option[Int]): CaseDef = { + val rhsWithTry = + if (excState.isEmpty) rhs + else Try(rhs, + List( + CaseDef( + Apply(Ident(defn.NonFatalClass), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), + EmptyTree, + Block(List( + Assign(Ident(name.exception), Ident(newTermName("t"))), + mkStateTree(excState.get), + mkResumeApply + ), c.literalUnit.tree))), EmptyTree + ) + CaseDef(c.literal(num).tree, EmptyTree, rhsWithTry) + } } diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index f0b46531..a050bec0 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -54,6 +54,8 @@ trait FutureSystem { def spawn(tree: context.Tree): context.Tree = future(context.Expr[Unit](tree))(execContext).tree + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] } def mkOps(c: Context): Ops { val context: c.type } @@ -101,6 +103,10 @@ object ScalaConcurrentFutureSystem extends FutureSystem { prom.splice.complete(value.splice) context.literalUnit.splice } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = reify { + future.splice.asInstanceOf[Fut[A]] + } } } @@ -145,5 +151,7 @@ object IdentityFutureSystem extends FutureSystem { prom.splice.a = value.splice.get context.literalUnit.splice } + + def castTo[A: WeakTypeTag](future: Expr[Fut[Any]]): Expr[Fut[A]] = ??? } } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index db82ed61..00fa4309 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -29,8 +29,11 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val tr = newTermName("tr") val matchRes = "matchres" val ifRes = "ifres" + val tryRes = "tryRes" val await = "await" val bindSuffix = "$bind" + val handlers = suffixedName("handlers") + val exception = suffixedName("exception") def arg(i: Int) = "arg" + i @@ -174,8 +177,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { tpe.member(newTermName(name)).ensuring(_ != NoSymbol) } - val Async_await = asyncMember("await") - val Async_awaitFallback = asyncMember("awaitFallback") + val Async_await = asyncMember("await") } /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala new file mode 100644 index 00000000..a669cfa2 --- /dev/null +++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.util.continuations._ + +trait AsyncBaseWithCPSFallback extends AsyncBase { + + /* Fall-back for `await` using CPS plugin. + * + * Note: This method is public, but is intended only for internal use. + */ + def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[futureSystem.Fut[Any]] + + override protected[async] def fallbackEnabled = true + + /* Implements `async { ... }` using the CPS plugin. + */ + protected def cpsBasedAsyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + import c.universe._ + + def lookupMember(name: String) = { + val asyncTrait = c.mirror.staticClass("scala.async.continuations.AsyncBaseWithCPSFallback") + val tpe = asyncTrait.asType.toType + tpe.member(newTermName(name)).ensuring(_ != NoSymbol) + } + + AsyncUtils.vprintln("AsyncBaseWithCPSFallback.cpsBasedAsyncImpl") + + val utils = TransformUtils[c.type](c) + val futureSystemOps = futureSystem.mkOps(c) + val awaitSym = utils.defn.Async_await + val awaitFallbackSym = lookupMember("awaitFallback") + + // replace `await` invocations with `awaitFallback` invocations + val awaitReplacer = new Transformer { + override def transform(tree: Tree): Tree = tree match { + case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitSym => + val typeApp = treeCopy.TypeApply(fun, Ident(awaitFallbackSym), List(TypeTree(futArgTpt.tpe))) + treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate))) + case _ => + super.transform(tree) + } + } + val bodyWithAwaitFallback = awaitReplacer.transform(body.tree) + + /* generate an expression that looks like this: + reset { + val f = future { ... } + ... + val x = awaitFallback(f) + ... + future { expr } + }.asInstanceOf[Future[T]] + */ + + val bodyWithFuture = { + val tree = bodyWithAwaitFallback match { + case Block(stmts, expr) => Block(stmts, futureSystemOps.spawn(expr)) + case expr => futureSystemOps.spawn(expr) + } + c.Expr[futureSystem.Fut[Any]](c.resetAllAttrs(tree.duplicate)) + } + + val bodyWithReset: c.Expr[futureSystem.Fut[Any]] = reify { + reset { bodyWithFuture.splice } + } + val bodyWithCast = futureSystemOps.castTo[T](bodyWithReset) + + AsyncUtils.vprintln(s"CPS-based async transform expands to:\n${bodyWithCast.tree}") + bodyWithCast + } + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { + AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl") + + val analyzer = AsyncAnalysis[c.type](c, this) + + if (!analyzer.reportUnsupportedAwaits(body.tree)) + super.asyncImpl[T](c)(body) // no unsupported awaits + else + cpsBasedAsyncImpl[T](c)(body) // fallback to CPS + } +} diff --git a/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala new file mode 100644 index 00000000..fe6e1a60 --- /dev/null +++ b/src/main/scala/scala/async/continuations/AsyncWithCPSFallback.scala @@ -0,0 +1,20 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.concurrent.Future + +trait AsyncWithCPSFallback extends AsyncBaseWithCPSFallback with ScalaConcurrentCPSFallback + +object AsyncWithCPSFallback extends AsyncWithCPSFallback { + + def async[T](body: T) = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) +} diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsync.scala b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala new file mode 100644 index 00000000..922d1ac6 --- /dev/null +++ b/src/main/scala/scala/async/continuations/CPSBasedAsync.scala @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.concurrent.Future + +trait CPSBasedAsync extends CPSBasedAsyncBase with ScalaConcurrentCPSFallback + +object CPSBasedAsync extends CPSBasedAsync { + + def async[T](body: T) = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) + +} diff --git a/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala new file mode 100644 index 00000000..4e8ec80b --- /dev/null +++ b/src/main/scala/scala/async/continuations/CPSBasedAsyncBase.scala @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.language.experimental.macros + +import scala.reflect.macros.Context +import scala.util.continuations._ + +/* Specializes `AsyncBaseWithCPSFallback` to always fall back to CPS, yielding a purely CPS-based + * implementation of async/await. + */ +trait CPSBasedAsyncBase extends AsyncBaseWithCPSFallback { + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = + super.cpsBasedAsyncImpl[T](c)(body) + +} diff --git a/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala new file mode 100644 index 00000000..018ad05d --- /dev/null +++ b/src/main/scala/scala/async/continuations/ScalaConcurrentCPSFallback.scala @@ -0,0 +1,31 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package continuations + +import scala.util.continuations._ +import scala.concurrent.{Future, Promise, ExecutionContext} + +trait ScalaConcurrentCPSFallback { + self: AsyncBaseWithCPSFallback => + + import ExecutionContext.Implicits.global + + lazy val futureSystem = ScalaConcurrentFutureSystem + type FS = ScalaConcurrentFutureSystem.type + + /* Fall-back for `await` when it is called at an unsupported position. + */ + override def awaitFallback[T](awaitable: futureSystem.Fut[T]): T @cps[Future[Any]] = + shift { + (k: (T => Future[Any])) => + val fr = Promise[Any]() + awaitable onComplete { + case tr => fr completeWith k(tr.get) + } + fr.future + } + +} diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 93cfdf57..18765568 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -40,8 +40,7 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) - varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2")) + varDefs.map(_.decoded.trim).toSet mustBe (Set("state$async", "await$1", "await$2", "handlers$async", "exception$async")) val defDefs = tree1.collect { case t: Template => diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index c3537ec0..f297bede 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -102,19 +102,9 @@ class NakedAwait { } } - @Test - def tryBody() { - expectError("await must not be used under a try/catch.") { - """ - | import _root_.scala.async.AsyncId._ - | async { try { await(false) } catch { case _ => } } - """.stripMargin - } - } - @Test def catchBody() { - expectError("await must not be used under a try/catch.") { + expectError("await must not be used under a catch.") { """ | import _root_.scala.async.AsyncId._ | async { try { () } catch { case _ => await(false) } } @@ -122,16 +112,6 @@ class NakedAwait { } } - @Test - def finallyBody() { - expectError("await must not be used under a try/catch.") { - """ - | import _root_.scala.async.AsyncId._ - | async { try { () } finally { await(false) } } - """.stripMargin - } - } - @Test def nestedMethod() { expectError("await must not be used under a nested method.") { diff --git a/src/test/scala/scala/async/run/cps/CPSSpec.scala b/src/test/scala/scala/async/run/cps/CPSSpec.scala new file mode 100644 index 00000000..b56c6ad9 --- /dev/null +++ b/src/test/scala/scala/async/run/cps/CPSSpec.scala @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package run +package cps + +import scala.concurrent.{Future, Promise, ExecutionContext, future, Await} +import scala.concurrent.duration._ +import scala.async.continuations.CPSBasedAsync._ +import scala.util.continuations._ + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test + +@RunWith(classOf[JUnit4]) +class CPSSpec { + + import ExecutionContext.Implicits.global + + def m1(y: Int): Future[Int] = async { + val f = future { y + 2 } + val f2 = future { y + 3 } + val x1 = await(f) + val x2 = await(f2) + x1 + x2 + } + + def m2(y: Int): Future[Int] = async { + val f = future { y + 2 } + val res = await(f) + if (y > 0) res + 2 + else res - 2 + } + + @Test + def testCPSFallback() { + val fut1 = m1(10) + val res1 = Await.result(fut1, 2.seconds) + assert(res1 == 25, s"expected 25, got $res1") + + val fut2 = m2(10) + val res2 = Await.result(fut2, 2.seconds) + assert(res2 == 14, s"expected 14, got $res2") + } + +} diff --git a/src/test/scala/scala/async/run/trycatch/TrySpec.scala b/src/test/scala/scala/async/run/trycatch/TrySpec.scala new file mode 100644 index 00000000..4f6e93c6 --- /dev/null +++ b/src/test/scala/scala/async/run/trycatch/TrySpec.scala @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package run +package trycatch + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test + +@RunWith(classOf[JUnit4]) +class TrySpec { + + @Test + def tryCatch1() { + import AsyncId._ + + val result = async { + var xxx: Int = 0 + try { + val y = await(xxx) + xxx = xxx + 1 + y + } catch { + case e: Exception => + assert(false) + } + xxx + } + assert(result == 1) + } + + @Test + def tryCatch2() { + import AsyncId._ + + val result = async { + var xxx: Int = 0 + try { + val y = await(xxx) + throw new Exception("test msg") + assert(false) + xxx = xxx + 1 + y + } catch { + case e: Exception => + assert(e.getMessage == "test msg") + xxx = 7 + } + xxx + } + assert(result == 7) + } + + @Test + def nestedTry1() { + import AsyncId._ + + val result = async { + var xxx = 0 + try { + try { + val y = await(xxx) + throw new IllegalArgumentException("msg") + assert(false) + y + 2 + } catch { + case iae: IllegalArgumentException => + xxx = 6 + } + } catch { + case nsee: NoSuchElementException => + xxx = 7 + } + xxx + } + assert(result == 6) + } + + @Test + def nestedTry2() { + import AsyncId._ + + val result = async { + var xxx = 0 + try { + try { + val y = await(xxx) + throw new NoSuchElementException("msg") + assert(false) + y + 2 + } catch { + case iae: IllegalArgumentException => + xxx = 6 + } + } catch { + case nsee: NoSuchElementException => + xxx = 7 + } + xxx + } + assert(result == 7) + } + + @Test + def tryAsExpr() { + import AsyncId._ + + val result = async { + val xxx: Int = 0 + try { + val y = await(xxx) + y + 2 + } catch { + case e: Exception => + assert(false) + xxx + 4 + } + } + assert(result == 2) + } + + @Test + def tryFinally1() { + import AsyncId._ + + var xxx: Int = 0 + val result = async { + try { + val y = await(xxx) + y + 2 + } catch { + case e: Exception => + assert(false) + xxx + 4 + } finally { + xxx = 5 + } + } + assert(result == 2) + assert(xxx == 5) + } + + @Test + def tryFinally2() { + import AsyncId._ + + var xxx: Int = 0 + val result = async { + try { + val y = await(xxx) + throw new Exception("msg") + assert(false) + y + 2 + } catch { + case e: Exception => + xxx + 4 + } finally { + xxx = 6 + } + } + assert(result == 4) + assert(xxx == 6) + } + + @Test + def tryFinallyAwait1() { + import AsyncId._ + + var xxx: Int = 0 + var uuu: Int = 10 + val result = async { + try { + val y = await(xxx) + y + 2 + } catch { + case e: Exception => + assert(false) + xxx + 4 + } finally { + val v = await(uuu) + xxx = v + } + } + assert(result == 2) + assert(xxx == 10) + } + + @Test + def tryFinallyAwait2() { + import AsyncId._ + + var xxx: Int = 0 + var uuu: Int = 10 + val result = async { + try { + val y = await(xxx) + throw new Exception("msg") + assert(false) + y + 2 + } catch { + case e: Exception => + xxx + 4 + } finally { + val v = await(uuu) + xxx = v + } + } + assert(result == 4) + assert(xxx == 10) + } + +}