Skip to content

Commit adabf63

Browse files
authored
[Stack Switching] Add basic support for resume/suspend in the interpreter (#7771)
It turns out we can implement suspend/resume on the current interpreter with only a small amount of changes, basically using the same idea as Asyncify: pause and resume structured control flow by rewinding the stack in a "resume" mode that skips normal execution. That is, this does not use a program counter (no goto), nor continuation-passing style, and instead adds interpreter support to unwind and rewind the stack. This is not the fastest way to do things, but it is the simplest by far. The basic idea is the same as in this 2019 blogpost: https://kripken.github.io/blog/wasm/2019/07/16/asyncify.html This is quite efficient in the things we care about: suspend/ resume is slow, but normal execution hardly regresses, which is important to keep the Precompute pass from getting slower. While the more intrusive experiment #7762 made that pass 2x slower, this only adds 10-15% overhead. The main reason is that this pass keeps us down to a single indirect call per instruction, while e.g. separating decoding and execution - to maintain a normal value stack - ends up doing two indirect calls. The "trick" in this PR is that, yes, we do need a value stack (that is the only practical way to stash the values on the stack when suspending), but we can populate that stack only when inside a coroutine (when we might suspend). So we still use the normal way of getting child instruction values - just calling `visit(curr->child)` from the parent's `visitFoo()` method - and do not lose that speed, but still have a stack of values when we need it. (10-15% is still significant, but it is just on a single pass, so it seems acceptable, and there might be ways to optimize this further.) Notes: * Flow now has a "suspendTag" property, which denotes the tag we are suspending, when we suspend. * As part of this change, `callFunction` in the interpreter returns a Flow, so that we can propagate suspensions out of functions. * For testing, this adds `assert_suspension` support in `wasm-shell`. * For testing, add part of `cont.wast` from the test suite, and minor fixes for it in `wasm.cpp`. * This only adds basic support: `resume_throw` and various other parts of the stack switching proposal are TODOs, but I think this PR does the hard parts.
1 parent fc6a797 commit adabf63

File tree

14 files changed

+1387
-51
lines changed

14 files changed

+1387
-51
lines changed

scripts/test/fuzzing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@
111111
'precompute-stack-switching.wast',
112112
'unsubtyping-stack-switching.wast',
113113
'vacuum-stack-switching.wast',
114+
'cont.wast',
115+
'cont_simple.wast',
114116
# TODO: fuzzer support for custom descriptors
115117
'remove-unused-module-elements-refs-descriptors.wast',
116118
'custom-descriptors.wast',

src/ir/iteration.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace wasm {
3232
// In general, it is preferable not to use this class and to directly access the
3333
// children (using e.g. iff->ifTrue etc.), as that is faster. However, in cases
3434
// where speed does not matter, this can be convenient. TODO: reimplement these
35-
// to avoid materializing all the chilren at once.
35+
// to avoid materializing all the children at once.
3636
//
3737
// ChildIterator - Iterates over all children
3838
//

src/literal.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ namespace wasm {
3333
class Literals;
3434
struct GCData;
3535
struct ExnData;
36+
struct ContData;
3637

3738
class Literal {
3839
// store only integers, whose bits are deterministic. floats
@@ -63,6 +64,8 @@ class Literal {
6364
std::shared_ptr<GCData> gcData;
6465
// A reference to Exn data.
6566
std::shared_ptr<ExnData> exnData;
67+
// A reference to a Continuation.
68+
std::shared_ptr<ContData> contData;
6669
};
6770

6871
public:
@@ -93,6 +96,7 @@ class Literal {
9396
}
9497
explicit Literal(std::shared_ptr<GCData> gcData, HeapType type);
9598
explicit Literal(std::shared_ptr<ExnData> exnData);
99+
explicit Literal(std::shared_ptr<ContData> contData);
96100
explicit Literal(std::string_view string);
97101
Literal(const Literal& other);
98102
Literal& operator=(const Literal& other);
@@ -105,6 +109,7 @@ class Literal {
105109
// a null or i31). This includes structs, arrays, and also strings.
106110
bool isData() const { return type.isData(); }
107111
bool isExn() const { return type.isExn(); }
112+
bool isContinuation() const { return type.isContinuation(); }
108113
bool isString() const { return type.isString(); }
109114

110115
bool isNull() const { return type.isNull(); }
@@ -312,6 +317,7 @@ class Literal {
312317
}
313318
std::shared_ptr<GCData> getGCData() const;
314319
std::shared_ptr<ExnData> getExnData() const;
320+
std::shared_ptr<ContData> getContData() const;
315321

316322
// careful!
317323
int32_t* geti32Ptr() {

src/parser/wast-parser.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,25 @@ MaybeResult<Assertion> assertTrap(Lexer& in) {
381381
return Assertion{AssertModule{ModuleAssertionType::Trap, *mod}};
382382
}
383383

384+
// (assert_suspension action msg)
385+
MaybeResult<Assertion> assertSuspension(Lexer& in) {
386+
if (!in.takeSExprStart("assert_suspension"sv)) {
387+
return {};
388+
}
389+
if (auto a = maybeAction(in)) {
390+
CHECK_ERR(a);
391+
auto msg = in.takeString();
392+
if (!msg) {
393+
return in.err("expected error message");
394+
}
395+
if (!in.takeRParen()) {
396+
return in.err("expected end of assertion");
397+
}
398+
return Assertion{AssertAction{ActionAssertionType::Suspension, *a}};
399+
}
400+
return in.err("invalid assert_suspension");
401+
}
402+
384403
MaybeResult<Assertion> assertion(Lexer& in) {
385404
if (auto a = assertReturn(in)) {
386405
CHECK_ERR(a);
@@ -402,6 +421,10 @@ MaybeResult<Assertion> assertion(Lexer& in) {
402421
CHECK_ERR(a);
403422
return *a;
404423
}
424+
if (auto a = assertSuspension(in)) {
425+
CHECK_ERR(a);
426+
return *a;
427+
}
405428
return {};
406429
}
407430

src/parser/wat-parser.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct AssertReturn {
7878
ExpectedResults expected;
7979
};
8080

81-
enum class ActionAssertionType { Trap, Exhaustion, Exception };
81+
enum class ActionAssertionType { Trap, Exhaustion, Exception, Suspension };
8282

8383
struct AssertAction {
8484
ActionAssertionType type;

src/shell-interface.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ struct ShellExternalInterface : ModuleRunner::ExternalInterface {
144144
std::cout << "exit()\n";
145145
throw ExitException();
146146
} else if (auto* inst = getImportInstance(import)) {
147-
return inst->callExport(import->base, arguments);
147+
auto flow = inst->callExport(import->base, arguments);
148+
assert(!flow.suspendTag); // TODO: support stack switching on calls
149+
return flow.values;
148150
}
149151
Fatal() << "callImport: unknown import: " << import->module.str << "."
150152
<< import->name.str;
@@ -191,7 +193,9 @@ struct ShellExternalInterface : ModuleRunner::ExternalInterface {
191193
if (func->imported()) {
192194
return callImport(func, arguments);
193195
} else {
194-
return instance.callFunction(func->name, arguments);
196+
auto flow = instance.callFunction(func->name, arguments);
197+
assert(!flow.suspendTag); // TODO: support stack switching on calls
198+
return flow.values;
195199
}
196200
}
197201

src/tools/execution-results.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ struct LoggingExternalInterface : public ShellExternalInterface {
254254
}
255255

256256
// Call the function.
257-
return instance->callFunction(func->name, arguments);
257+
auto flow = instance->callFunction(func->name, arguments);
258+
assert(!flow.suspendTag);
259+
return flow.values;
258260
}
259261

260262
void setModuleRunner(ModuleRunner* instance_) { instance = instance_; }
@@ -471,7 +473,12 @@ struct ExecutionResults {
471473
}
472474
arguments.push_back(Literal::makeZero(param));
473475
}
474-
return instance.callFunction(func->name, arguments);
476+
auto flow = instance.callFunction(func->name, arguments);
477+
if (flow.suspendTag) { // TODO: support stack switching here
478+
std::cout << "[exception thrown: unhandled suspend]" << std::endl;
479+
return Exception{};
480+
}
481+
return flow.values;
475482
} catch (const TrapException&) {
476483
return Trap{};
477484
} catch (const WasmException& e) {

src/tools/wasm-ctor-eval.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,11 @@ struct CtorEvalExternalInterface : EvallingModuleRunner::ExternalInterface {
350350
targetFunc.toString());
351351
}
352352
if (!func->imported()) {
353-
return instance.callFunction(targetFunc, arguments);
353+
auto flow = instance.callFunction(targetFunc, arguments);
354+
if (flow.suspendTag) {
355+
throw FailToEvalException("unhandled suspend");
356+
}
357+
return flow.values;
354358
} else {
355359
throw FailToEvalException(
356360
std::string("callTable on imported function: ") +

src/tools/wasm-shell.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,12 @@ struct Shell {
186186
struct TrapResult {};
187187
struct HostLimitResult {};
188188
struct ExceptionResult {};
189-
using ActionResult =
190-
std::variant<Literals, TrapResult, HostLimitResult, ExceptionResult>;
189+
struct SuspensionResult {};
190+
using ActionResult = std::variant<Literals,
191+
TrapResult,
192+
HostLimitResult,
193+
ExceptionResult,
194+
SuspensionResult>;
191195

192196
std::string resultToString(ActionResult& result) {
193197
if (std::get_if<TrapResult>(&result)) {
@@ -196,6 +200,8 @@ struct Shell {
196200
return "exceeded host limit";
197201
} else if (std::get_if<ExceptionResult>(&result)) {
198202
return "exception";
203+
} else if (std::get_if<SuspensionResult>(&result)) {
204+
return "suspension";
199205
} else if (auto* vals = std::get_if<Literals>(&result)) {
200206
std::stringstream ss;
201207
ss << *vals;
@@ -213,8 +219,9 @@ struct Shell {
213219
return TrapResult{};
214220
}
215221
auto& instance = it->second;
222+
Flow flow;
216223
try {
217-
return instance->callExport(invoke->name, invoke->args);
224+
flow = instance->callExport(invoke->name, invoke->args);
218225
} catch (TrapException&) {
219226
return TrapResult{};
220227
} catch (HostLimitException&) {
@@ -224,6 +231,10 @@ struct Shell {
224231
} catch (...) {
225232
WASM_UNREACHABLE("unexpected error");
226233
}
234+
if (flow.suspendTag) {
235+
return SuspensionResult{};
236+
}
237+
return flow.values;
227238
} else if (auto* get = std::get_if<GetAction>(&act)) {
228239
auto it = instances.find(get->base ? *get->base : lastModule);
229240
if (it == instances.end()) {
@@ -390,6 +401,12 @@ struct Shell {
390401
}
391402
err << "expected exception";
392403
break;
404+
case ActionAssertionType::Suspension:
405+
if (std::get_if<SuspensionResult>(&result)) {
406+
return Ok{};
407+
}
408+
err << "expected suspension";
409+
break;
393410
}
394411
err << ", got " << resultToString(result);
395412
return Err{err.str()};

0 commit comments

Comments
 (0)