forked from rbeeli/websocketclient-cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path corochain.hpp
112 lines (86 loc) · 2.73 KB
/
corochain.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#pragma once
#include <coroutine>
#include <optional>
#include <variant>
#include <memory>
#include <exception>
namespace NNet {
template <typename T>
struct TFinalSuspendContinuation;
template <typename T>
struct TValueTask;

/// State shared by every TValuePromise<T> specialization.
///
/// The coroutine body starts running as soon as it is called (initial_suspend
/// never suspends). When the body finishes, final_suspend hands control to the
/// coroutine recorded in Caller.
template <typename T>
struct TValuePromiseBase {
    /// The coroutine awaiting this one; TValueTaskBase::await_suspend stores it.
    /// Defaults to std::noop_coroutine(), so finishing without an awaiter just
    /// returns control to whoever resumed this coroutine last.
    std::coroutine_handle<> Caller = std::noop_coroutine();

    /// Eager start: run the body until its first real suspension point.
    std::suspend_never initial_suspend() { return {}; }

    /// Defined out of line, once TFinalSuspendContinuation is complete.
    TFinalSuspendContinuation<T> final_suspend() noexcept;
};
/// Promise type for coroutines that return TValueTask<T>.
///
/// ErrorOr is empty while the coroutine is still running. When it completes,
/// ErrorOr holds either the co_returned value or the exception that escaped
/// the body; TValueTask<T>::await_resume reads it from there.
template <typename T>
struct TValuePromise : public TValuePromiseBase<T> {
    TValueTask<T> get_return_object();

    /// Stores a copy of the co_returned value.
    void return_value(const T& t) {
        // emplace() only requires T to be constructible, not assignable, so
        // types with const members can still be co_returned.
        ErrorOr.emplace(t);
    }

    /// Stores the co_returned value by moving it.
    void return_value(T&& t) {
        ErrorOr.emplace(std::move(t));
    }

    /// Captures the exception escaping the coroutine body for later rethrow.
    void unhandled_exception() {
        ErrorOr.emplace(std::current_exception());
    }

    std::optional<std::variant<T, std::exception_ptr>> ErrorOr;
};
/// Owning handle to a TValuePromise<T> coroutine frame; it is also the
/// awaiter used when the task is co_awaited.
///
/// The task destroys the coroutine frame when it is destroyed or assigned
/// over. Ownership is move-only: a copy would make two tasks destroy the same
/// frame. A default-constructed or moved-from task holds a null handle, and
/// destroying it does nothing.
template <typename T>
struct TValueTaskBase : std::coroutine_handle<TValuePromise<T>> {
    using Handle = std::coroutine_handle<TValuePromise<T>>;
    using promise_type = TValuePromise<T>;

    /// Empty task (null handle).
    TValueTaskBase() noexcept = default;

    /// Adopts ownership of `h`. Intentionally non-explicit so that
    /// get_return_object can keep returning `{ TValueTask<T>::from_promise(*this) }`.
    TValueTaskBase(Handle h) noexcept : Handle(h) {}

    TValueTaskBase(const TValueTaskBase&) = delete;
    TValueTaskBase& operator=(const TValueTaskBase&) = delete;

    /// Takes over `other`'s frame and leaves `other` empty.
    TValueTaskBase(TValueTaskBase&& other) noexcept : Handle(other) {
        other.Release();
    }

    /// Destroys the currently owned frame (if any), then takes over `other`'s.
    TValueTaskBase& operator=(TValueTaskBase&& other) noexcept {
        if (this != &other) {
            Reset();
            static_cast<Handle&>(*this) = static_cast<Handle&>(other);
            other.Release();
        }
        return *this;
    }

    ~TValueTaskBase() { Reset(); }

    /// True once the coroutine has stored a result or an exception, in which
    /// case co_await does not suspend.
    bool await_ready() {
        return this->promise().ErrorOr.has_value();
    }

    /// Records the awaiting coroutine; final_suspend transfers back to it.
    void await_suspend(std::coroutine_handle<> caller) {
        this->promise().Caller = caller;
    }

private:
    /// Forgets the frame without destroying it.
    void Release() noexcept { static_cast<Handle&>(*this) = nullptr; }

    /// Destroys the owned frame, if any, and leaves the handle null.
    void Reset() noexcept {
        if (*this) {
            this->destroy();
        }
        Release();
    }
};
/// Awaitable task whose co_await expression yields the coroutine's T result.
template <typename T>
struct TValueTask : public TValueTaskBase<T> {
    /// Runs once the awaited coroutine has completed. Rethrows the stored
    /// exception if there is one; otherwise moves the stored value out, so
    /// awaiting the same task again would observe a moved-from value.
    T await_resume() {
        auto& outcome = *this->promise().ErrorOr;
        if (auto* error = std::get_if<std::exception_ptr>(&outcome)) {
            std::rethrow_exception(*error);
        }
        return std::move(std::get<T>(outcome));
    }
};
template <>
struct TValueTask<void>;

/// Promise type for coroutines that return TValueTask<void>.
///
/// ErrorOr is empty while the coroutine runs. On completion it holds a null
/// exception_ptr (plain co_return) or the exception that escaped the body.
template <>
struct TValuePromise<void> : public TValuePromiseBase<void> {
    std::optional<std::exception_ptr> ErrorOr;

    TValueTask<void> get_return_object();

    /// Captures the exception escaping the coroutine body for later rethrow.
    void unhandled_exception() {
        ErrorOr.emplace(std::current_exception());
    }

    /// Marks success: the optional becomes engaged with a *null* exception_ptr.
    void return_void() {
        ErrorOr.emplace();
    }
};
/// Awaitable task for coroutines that produce no value.
template <>
struct TValueTask<void> : public TValueTaskBase<void> {
    /// Runs once the awaited coroutine has completed; rethrows the exception it
    /// stored, if any.
    void await_resume() {
        const std::exception_ptr& error = *this->promise().ErrorOr;
        if (error != nullptr) {
            std::rethrow_exception(error);
        }
    }
};
/// Awaiter returned by TValuePromiseBase<T>::final_suspend.
///
/// It always suspends the finished coroutine, so the frame and its stored
/// result stay alive until the owning task destroys it. It then resumes the
/// coroutine recorded in the promise's Caller via symmetric transfer.
template <typename T>
struct TFinalSuspendContinuation {
    bool await_ready() noexcept {
        return false;
    }

    std::coroutine_handle<> await_suspend(std::coroutine_handle<TValuePromise<T>> finished) noexcept {
        std::coroutine_handle<> next = finished.promise().Caller;
        return next;
    }

    void await_resume() noexcept {}
};
// Out-of-line member definitions. They come last because they return
// TValueTask / TFinalSuspendContinuation by value, which must be complete types.

template <typename T>
TFinalSuspendContinuation<T> TValuePromiseBase<T>::final_suspend() noexcept {
    return TFinalSuspendContinuation<T>{};
}

template <typename T>
TValueTask<T> TValuePromise<T>::get_return_object() {
    // Hand the caller a task wrapping the handle of this promise's coroutine frame.
    return { TValueTask<T>::from_promise(*this) };
}

// TValuePromise<void> is an explicit specialization, so this member is an
// ordinary (non-template) function and needs `inline` to be defined in a header.
inline TValueTask<void> TValuePromise<void>::get_return_object() {
    return { TValueTask<void>::from_promise(*this) };
}
} // namespace NNet