diff --git a/.circleci/config.yml b/.circleci/config.yml index da8efe4..2930aff 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -36,8 +36,6 @@ jobs: - run: name: "Test" command: cd build && ctest -C Release --output-on-failure && cd .. - environment: - SOCKET_LOG: debug - run: name: "Install" command: cmake --install build --config Release @@ -75,5 +73,5 @@ workflows: - build: matrix: parameters: - os: [ "jammy", "focal" ] + os: [ "jammy", "focal", "debian-10", "debian-11" ] shared: [ "ON", "OFF" ] diff --git a/.clangd b/.clangd new file mode 100644 index 0000000..c8f81f7 --- /dev/null +++ b/.clangd @@ -0,0 +1,19 @@ +Diagnostics: + ClangTidy: + Add: [ + 'bugprone*', + 'modernize*', + 'cppcoreguidelines*', + 'performance*', + 'readability*' + ] + Remove: [ + 'modernize-use-trailing-return-type', + 'cppcoreguidelines-no-malloc', + 'cppcoreguidelines-owning-memory', + "cppcoreguidelines-macro-usage", + 'cppcoreguidelines-pro-type-reinterpret-cast', + 'cppcoreguidelines-pro-type-union-access', + 'cppcoreguidelines-special-member-functions', + 'bugprone-easily-swappable-parameters' + ] diff --git a/.github/workflows/PR.yml b/.github/workflows/PR.yml index 10f9f21..a974478 100644 --- a/.github/workflows/PR.yml +++ b/.github/workflows/PR.yml @@ -24,13 +24,15 @@ jobs: submodules: true - name: Install LLVM and Clang - run: brew install llvm@16 + run: | + brew update || true + brew install llvm@17 || true - name: Install Rust toolchain uses: actions-rs/toolchain@v1 with: - toolchain: nightly - default: true + toolchain: nightly-2023-07-07 + profile: minimal - name: Configure CMake # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. @@ -52,8 +54,6 @@ jobs: # Execute tests defined by the CMake configuration. # See https://cmake.org/cmake/help/latest/manual/ctest.1.html for more detail run: ctest -C ${{env.BUILD_TYPE}} --output-on-failure - env: - SOCKET_LOG: debug - name: Install run: sudo cmake --install build --config Release diff --git a/.github/workflows/Push-Dev-Container.yml b/.github/workflows/Push-Dev-Container.yml index eb3412b..bcd457d 100644 --- a/.github/workflows/Push-Dev-Container.yml +++ b/.github/workflows/Push-Dev-Container.yml @@ -1,6 +1,7 @@ name: Push Dev Container on: + workflow_dispatch: push: paths: - dockerfile/** @@ -11,32 +12,12 @@ on: branches: [ "main" ] jobs: - build-dev-container-focal: + build-dev-containers: runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - with: - submodules: true - - - name: Log in to Docker Hub - uses: docker/login-action@v2 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_PASSWORD }} - - - name: Build and push Docker image - uses: docker/build-push-action@v4 - with: - context: . - file: ./dockerfile/dev-containers/focal/Dockerfile - push: true - tags: congyuwang/socket-manager-dev:focal - cache-from: type=registry,ref=congyuwang/socket-manager-dev:focal - cache-to: type=inline - - build-dev-container-jammy: - runs-on: ubuntu-latest + strategy: + matrix: + os: ["focal", "jammy", "debian-10", "debian-11"] steps: - uses: actions/checkout@v3 @@ -53,8 +34,8 @@ jobs: uses: docker/build-push-action@v4 with: context: . - file: ./dockerfile/dev-containers/jammy/Dockerfile + file: ./dockerfile/dev-containers/${{ matrix.os }}/Dockerfile push: true - tags: congyuwang/socket-manager-dev:jammy - cache-from: type=registry,ref=congyuwang/socket-manager-dev:jammy + tags: congyuwang/socket-manager-dev:${{ matrix.os }} + cache-from: type=registry,ref=congyuwang/socket-manager-dev:${{ matrix.os }} cache-to: type=inline diff --git a/.gitignore b/.gitignore index 3ea6d48..3125e9d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ /target /Cargo.lock **/.idea +**/.cache build build-debug cmake-build-debug diff --git a/.gitmodules b/.gitmodules index 98dca17..153978d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "concurrentqueue"] path = tests/concurrentqueue url = https://github.com/cameron314/concurrentqueue.git +[submodule "tests/spdlog"] + path = tests/spdlog-repo + url = https://github.com/gabime/spdlog diff --git a/CMakeLists.txt b/CMakeLists.txt index 7900e6d..46a13c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_TOOLCHAIN_FILE ${CMAKE_SOURCE_DIR}/toolchain.cmake) # define project -project(socket_manager LANGUAGES C CXX VERSION 0.3.1) +project(socket_manager LANGUAGES C CXX VERSION 0.3.3) # set default build type as shared option(BUILD_SHARED_LIBS "Build using shared libraries" ON) @@ -18,11 +18,6 @@ set(Rust_TOOLCHAIN "nightly" CACHE STRING "requires nightly") # enable cross language lto corrosion_add_target_rustflags(tokio-socket-manager -C linker-plugin-lto) -# set target triple as rust target -set(BUILD_TRIPLE ${Rust_CARGO_TARGET_ARCH}-${Rust_CARGO_TARGET_VENDOR}-${Rust_CARGO_TARGET_OS}-${Rust_CARGO_TARGET_ENV}) -set(CMAKE_C_COMPILER_TARGET ${BUILD_TRIPLE}) -set(CMAKE_CXX_COMPILER_TARGET ${BUILD_TRIPLE}) - set(include_dest "include/${PROJECT_NAME}-${PROJECT_VERSION}") set(main_lib_dest "lib/${PROJECT_NAME}-${PROJECT_VERSION}") diff --git a/Cargo.toml b/Cargo.toml index b2a74a2..fd01288 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-socket-manager" -version = "0.3.1" +version = "0.3.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -8,11 +8,29 @@ edition = "2021" crate-type = ["staticlib"] [dependencies] -async-ringbuf = "0.1.3" -dashmap = "5.4.0" -futures = { version = "0.3.28", default-features = false, features = ["async-await"] } +dashmap = { version = "5.4.0", features = ["inline"] } libc = "0.2.146" socket2 = "0.5.3" -tokio = { version = "1.28.2", features = ["full"] } -tracing = "0.1.37" -tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } + +[dependencies.async-ringbuf] +git = "https://github.com/Congyuwang/ringbuf.git" + +[dependencies.tokio] +version = "1.29.1" +default-features = false +features = ["rt", "rt-multi-thread", "net", "time", "sync", "io-util", "macros", "parking_lot"] + +[dependencies.futures] +version = "0.3.28" +default-features = false +features = ["async-await"] + +[dependencies.tracing] +version = "0.1.37" +default-features = false +features = ["std"] + +[dependencies.tracing-subscriber] +version = "0.3.17" +default-features = false +features = ["std", "fmt", "env-filter", "registry"] diff --git a/README.md b/README.md index 36deac1..ec67fee 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # A C++ Library Developed In Rust Tokio To Manage Multiple TCP Connections -Easily manage multiple socket connections asynchronously in C++. +Easily manage multiple socket connections asynchronously in C++. ## Installation @@ -10,11 +10,11 @@ Easily manage multiple socket connections asynchronously in C++. curl https://sh.rustup.rs -sSf | sh -s -- -y --default-toolchain nightly ``` -- Step 2: Install LLVM 16 +- Step 2: Install LLVM 17 macOS: ```shell -brew install llvm@16 +brew install llvm@17 # get brew environment eval "$(brew shellenv)" ``` @@ -23,8 +23,8 @@ linux ```shell wget https://apt.llvm.org/llvm.sh chmod +x llvm.sh -sudo ./llvm.sh 16 all -sudo ./update-alternatives-clang.sh 16 9999 +sudo ./llvm.sh 17 all +sudo ./update-alternatives-clang.sh 17 9999 ``` - Step 3: Pull the source code diff --git a/dockerfile/dev-containers/debian-10/Dockerfile b/dockerfile/dev-containers/debian-10/Dockerfile new file mode 100644 index 0000000..a015d40 --- /dev/null +++ b/dockerfile/dev-containers/debian-10/Dockerfile @@ -0,0 +1,57 @@ +FROM --platform=linux/amd64 debian:10 + +# use ustc mirror in China, comment out if you are not in China +#RUN sed -i 's@//.*archive.ubuntu.com@//mirrors.tuna.tsinghua.edu.cn@g' /etc/apt/sources.list + +# remove security source to avoid connection error +#RUN sed -i '/^deb http:\/\/security/d' /etc/apt/sources.list + +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y -q\ + build-essential \ + software-properties-common \ + libssl-dev \ + git \ + vim \ + wget \ + curl \ + screen + +# install cmake +WORKDIR /root +RUN mkdir temp +WORKDIR /root/temp +RUN curl -OL https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4.tar.gz +RUN tar -xzvf cmake-3.27.4.tar.gz + +WORKDIR /root/temp/cmake-3.27.4 +RUN ./bootstrap -- -DCMAKE_BUILD_TYPE:STRING=Release +RUN make -j4 +RUN make install + +WORKDIR /root +RUN rm -rf temp + +# install llvm@17 +COPY ./update-alternatives-clang.sh /root +RUN wget https://apt.llvm.org/llvm.sh +RUN chmod +x llvm.sh +RUN ./llvm.sh 17 all +# use tsinghua mirror in China, comment out if you are not in China +#RUN ./llvm.sh 17 all -m https://mirrors.tuna.tsinghua.edu.cn/llvm-apt +RUN ./update-alternatives-clang.sh 17 9999 +RUN rm llvm.sh +RUN rm update-alternatives-clang.sh + +# install nightly rust +# use ustc mirror in China, comment out if you are not in China +#ARG RUSTUP_UPDATE_ROOT="https://mirrors.ustc.edu.cn/rust-static/rustup" +#ARG RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup" +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly +ENV PATH="/root/.cargo/bin:${PATH}" + +RUN chsh -s /bin/bash + +ENTRYPOINT ["/bin/bash"] diff --git a/dockerfile/dev-containers/debian-11/Dockerfile b/dockerfile/dev-containers/debian-11/Dockerfile new file mode 100644 index 0000000..8de9455 --- /dev/null +++ b/dockerfile/dev-containers/debian-11/Dockerfile @@ -0,0 +1,57 @@ +FROM --platform=linux/amd64 debian:11 + +# use ustc mirror in China, comment out if you are not in China +#RUN sed -i 's@//.*archive.ubuntu.com@//mirrors.tuna.tsinghua.edu.cn@g' /etc/apt/sources.list + +# remove security source to avoid connection error +#RUN sed -i '/^deb http:\/\/security/d' /etc/apt/sources.list + +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y -q\ + build-essential \ + software-properties-common \ + libssl-dev \ + git \ + vim \ + wget \ + curl \ + screen + +# install cmake +WORKDIR /root +RUN mkdir temp +WORKDIR /root/temp +RUN curl -OL https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4.tar.gz +RUN tar -xzvf cmake-3.27.4.tar.gz + +WORKDIR /root/temp/cmake-3.27.4 +RUN ./bootstrap -- -DCMAKE_BUILD_TYPE:STRING=Release +RUN make -j4 +RUN make install + +WORKDIR /root +RUN rm -rf temp + +# install llvm@17 +COPY ./update-alternatives-clang.sh /root +RUN wget https://apt.llvm.org/llvm.sh +RUN chmod +x llvm.sh +RUN ./llvm.sh 17 all +# use tsinghua mirror in China, comment out if you are not in China +#RUN ./llvm.sh 17 all -m https://mirrors.tuna.tsinghua.edu.cn/llvm-apt +RUN ./update-alternatives-clang.sh 17 9999 +RUN rm llvm.sh +RUN rm update-alternatives-clang.sh + +# install nightly rust +# use ustc mirror in China, comment out if you are not in China +#ARG RUSTUP_UPDATE_ROOT="https://mirrors.ustc.edu.cn/rust-static/rustup" +#ARG RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup" +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly +ENV PATH="/root/.cargo/bin:${PATH}" + +RUN chsh -s /bin/bash + +ENTRYPOINT ["/bin/bash"] diff --git a/dockerfile/dev-containers/focal/Dockerfile b/dockerfile/dev-containers/focal/Dockerfile index 6ca17fd..8d4df41 100644 --- a/dockerfile/dev-containers/focal/Dockerfile +++ b/dockerfile/dev-containers/focal/Dockerfile @@ -12,23 +12,37 @@ RUN apt-get update && \ apt-get install -y -q\ build-essential \ software-properties-common \ + libssl-dev \ lsb-core \ - cmake \ git \ vim \ wget \ curl \ screen -# install llvm@16 +# install cmake WORKDIR /root +RUN mkdir temp +WORKDIR /root/temp +RUN curl -OL https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4.tar.gz +RUN tar -xzvf cmake-3.27.4.tar.gz + +WORKDIR /root/temp/cmake-3.27.4 +RUN ./bootstrap -- -DCMAKE_BUILD_TYPE:STRING=Release +RUN make -j4 +RUN make install + +WORKDIR /root +RUN rm -rf temp + +# install llvm@17 COPY ./update-alternatives-clang.sh /root RUN wget https://apt.llvm.org/llvm.sh RUN chmod +x llvm.sh -RUN ./llvm.sh 16 all +RUN ./llvm.sh 17 all # use tsinghua mirror in China, comment out if you are not in China -#RUN ./llvm.sh 16 all -m https://mirrors.tuna.tsinghua.edu.cn/llvm-apt -RUN ./update-alternatives-clang.sh 16 9999 +#RUN ./llvm.sh 17 all -m https://mirrors.tuna.tsinghua.edu.cn/llvm-apt +RUN ./update-alternatives-clang.sh 17 9999 RUN rm llvm.sh RUN rm update-alternatives-clang.sh diff --git a/dockerfile/dev-containers/jammy/Dockerfile b/dockerfile/dev-containers/jammy/Dockerfile index 7a28d90..c7b3840 100644 --- a/dockerfile/dev-containers/jammy/Dockerfile +++ b/dockerfile/dev-containers/jammy/Dockerfile @@ -12,23 +12,37 @@ RUN apt-get update && \ apt-get install -y -q\ build-essential \ software-properties-common \ + libssl-dev \ lsb-core \ - cmake \ git \ vim \ wget \ curl \ screen -# install llvm@16 +# install cmake WORKDIR /root +RUN mkdir temp +WORKDIR /root/temp +RUN curl -OL https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4.tar.gz +RUN tar -xzvf cmake-3.27.4.tar.gz + +WORKDIR /root/temp/cmake-3.27.4 +RUN ./bootstrap -- -DCMAKE_BUILD_TYPE:STRING=Release +RUN make -j4 +RUN make install + +WORKDIR /root +RUN rm -rf temp + +# install llvm@17 COPY ./update-alternatives-clang.sh /root RUN wget https://apt.llvm.org/llvm.sh RUN chmod +x llvm.sh -RUN ./llvm.sh 16 all +RUN ./llvm.sh 17 all # use tsinghua mirror in China, comment out if you are not in China -#RUN ./llvm.sh 16 all -m https://mirrors.tuna.tsinghua.edu.cn/llvm-apt -RUN ./update-alternatives-clang.sh 16 9999 +#RUN ./llvm.sh 17 all -m https://mirrors.tuna.tsinghua.edu.cn/llvm-apt +RUN ./update-alternatives-clang.sh 17 9999 RUN rm llvm.sh RUN rm update-alternatives-clang.sh diff --git a/examples/echo_server/.clangd b/examples/echo_server/.clangd new file mode 100644 index 0000000..666b006 --- /dev/null +++ b/examples/echo_server/.clangd @@ -0,0 +1,9 @@ +Diagnostics: + ClangTidy: + Add: [ + 'bugprone*', + 'modernize*', + 'cppcoreguidelines*', + 'performance*', + 'readability*' + ] diff --git a/examples/echo_server/justfile b/examples/echo_server/justfile index 7c5944c..6fb18b0 100644 --- a/examples/echo_server/justfile +++ b/examples/echo_server/justfile @@ -14,4 +14,4 @@ build: codesign -s - -v -f --entitlements debug.plist ./build/echo_server run: - SOCKET_LOG=info ./build/echo_server + ./build/echo_server diff --git a/examples/echo_server/src/echo_server.cpp b/examples/echo_server/src/echo_server.cpp index 9dc0875..6ba1a24 100644 --- a/examples/echo_server/src/echo_server.cpp +++ b/examples/echo_server/src/echo_server.cpp @@ -1,9 +1,9 @@ -#include #include +#include #include -#include +#include #include -#include +#include /** * UniqueWaker is a wrapper of `socket_manager::Waker` @@ -13,14 +13,10 @@ class WrapWaker : public socket_manager::Notifier { public: explicit WrapWaker(socket_manager::Waker &&wake) : waker(std::move(wake)) {} - void set_waker(socket_manager::Waker &&wake) { - waker = std::move(wake); - } + void set_waker(socket_manager::Waker &&wake) { waker = std::move(wake); } private: - void wake() override { - waker.wake(); - } + void wake() override { waker.wake(); } socket_manager::Waker waker; }; @@ -33,10 +29,9 @@ class WrapWaker : public socket_manager::Notifier { */ class EchoReceiver : public socket_manager::MsgReceiverAsync { public: - explicit EchoReceiver( - std::shared_ptr &&sender, - const std::shared_ptr &waker - ) : waker(waker), sender(std::move(sender)) {}; + explicit EchoReceiver(std::shared_ptr &&sender, + const std::shared_ptr &waker) + : waker(waker), sender(std::move(sender)){}; /** * Release resources to break potential ref cycles. @@ -47,7 +42,8 @@ class EchoReceiver : public socket_manager::MsgReceiverAsync { } private: - long on_message_async(std::string_view data, socket_manager::Waker &&wake) override { + auto on_message_async(std::string_view data, socket_manager::Waker &&wake) + -> long override { waker->set_waker(std::move(wake)); return sender->send_async(data); }; @@ -73,7 +69,8 @@ class EchoCallback : public socket_manager::ConnCallback { conn->start(std::move(recv), std::move(waker)); } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override { { std::lock_guard lock(mutex); auto find = receivers.find(local_addr + peer_addr); @@ -82,17 +79,21 @@ class EchoCallback : public socket_manager::ConnCallback { find->second->close(); receivers.erase(find); } else { - throw std::runtime_error("connection not found: " + local_addr + " -> " + peer_addr); + throw std::runtime_error("connection not found: " + local_addr + + " -> " + peer_addr); } } - std::cout << "connection closed: " << local_addr << " -> " << peer_addr << std::endl; + std::cout << "connection closed: " << local_addr << " -> " << peer_addr + << std::endl; } - void on_listen_error(const std::string &addr, const std::string &err) override { + void on_listen_error(const std::string &addr, + const std::string &err) override { throw std::runtime_error("listen error: addr=" + addr + ", " + err); } - void on_connect_error(const std::string &addr, const std::string &err) override { + void on_connect_error(const std::string &addr, + const std::string &err) override { throw std::runtime_error("connect error: addr=" + addr + ", " + err); } @@ -100,7 +101,7 @@ class EchoCallback : public socket_manager::ConnCallback { std::unordered_map> receivers; }; -int main() { +auto main() -> int { // start the server auto callback = std::make_shared(); auto manager = socket_manager::SocketManager(callback); diff --git a/include/socket_manager/common/notifier.h b/include/socket_manager/common/notifier.h index 29627d8..c91cff5 100644 --- a/include/socket_manager/common/notifier.h +++ b/include/socket_manager/common/notifier.h @@ -4,34 +4,33 @@ #include "socket_manager_c_api.h" namespace socket_manager { - /** - * @brief The Notifier class is used to receive notification - * from rust runtime to request c/c++ task for further - * execution. - * - * @note Our current implementation does not require ref count - * since we bound the lifetime of the notifier to the - * lifetime of the connection, and guarantee that the passed - * notifier is valid by keeping a reference of the notifier - * in the related connection. - */ - class Notifier { - - public: - virtual ~Notifier() = default; - - private: - - virtual void wake() = 0; - - friend void::socket_manager_extern_notifier_wake(struct SOCKET_MANAGER_C_API_Notifier this_); - - }; - - class NoopNotifier : public Notifier { - public: - void wake() override {} - }; -} - -#endif //SOCKET_MANAGER_NOTIFIER_H +/** + * @brief The Notifier class is used to receive notification + * from rust runtime to request c/c++ task for further + * execution. + * + * @note Our current implementation does not require ref count + * since we bound the lifetime of the notifier to the + * lifetime of the connection, and guarantee that the passed + * notifier is valid by keeping a reference of the notifier + * in the related connection. + */ +class Notifier { + +public: + virtual ~Notifier() = default; + +private: + virtual void wake() = 0; + + friend void ::socket_manager_extern_notifier_wake( + struct SOCKET_MANAGER_C_API_Notifier this_); +}; + +class NoopNotifier : public Notifier { +public: + void wake() override {} +}; +} // namespace socket_manager + +#endif // SOCKET_MANAGER_NOTIFIER_H diff --git a/include/socket_manager/common/waker.h b/include/socket_manager/common/waker.h index 341b64b..f4654a8 100644 --- a/include/socket_manager/common/waker.h +++ b/include/socket_manager/common/waker.h @@ -5,57 +5,56 @@ namespace socket_manager { +/** + * Return `PENDING` to interrupt runtime task. + */ +const long PENDING = -1; + +/** + * @brief Waker is used to wake up a pending runtime task. + * + * The implementation of `MsgReceiverAsync::on_message_async` + * can return `PENDING = -1` to interrupt a message receiving + * task in the runtime (e.g. when the caller buffer is full). + * + * And use `waker.wake()` to resume the message receiving task + * when the caller buffer is ready. + * + *

Resource Leak Note

+ * The `Waker` must be properly destroyed to avoid resource leak. + */ +class Waker { +public: /** - * Return `PENDING` to interrupt runtime task. + * Call wake() to wake up the receiver process. */ - const long PENDING = -1; + void wake(); + + ~Waker(); /** - * @brief Waker is used to wake up a pending runtime task. - * - * The implementation of `MsgReceiverAsync::on_message_async` - * can return `PENDING = -1` to interrupt a message receiving - * task in the runtime (e.g. when the caller buffer is full). - * - * And use `waker.wake()` to resume the message receiving task - * when the caller buffer is ready. - * - *

Resource Leak Note

- * The `Waker` must be properly destroyed to avoid resource leak. + * Create an empty noop waker. */ - class Waker { - public: - /** - * Call wake() to wake up the receiver process. - */ - void wake(); - - ~Waker(); - - /** - * Create an empty noop waker. - */ - explicit Waker(); + explicit Waker(); - Waker(const Waker &) = delete; + Waker(const Waker &) = delete; - Waker &operator=(const Waker &) = delete; + Waker &operator=(const Waker &) = delete; - Waker(Waker &&) noexcept; + Waker(Waker &&) noexcept; - Waker &operator=(Waker &&) noexcept; + Waker &operator=(Waker &&) noexcept; - private: - explicit Waker(SOCKET_MANAGER_C_API_CWaker waker); +private: + explicit Waker(SOCKET_MANAGER_C_API_CWaker waker); - friend long::socket_manager_extern_on_msg( - struct SOCKET_MANAGER_C_API_OnMsgObj this_, - SOCKET_MANAGER_C_API_ConnMsg msg, - SOCKET_MANAGER_C_API_CWaker waker, - char **err); + friend long ::socket_manager_extern_on_msg( + struct SOCKET_MANAGER_C_API_OnMsgObj this_, + SOCKET_MANAGER_C_API_ConnMsg msg, SOCKET_MANAGER_C_API_CWaker waker, + char **err); - SOCKET_MANAGER_C_API_CWaker waker; - }; + SOCKET_MANAGER_C_API_CWaker waker; +}; } // namespace socket_manager -#endif //SOCKET_MANAGER_WAKER_H +#endif // SOCKET_MANAGER_WAKER_H diff --git a/include/socket_manager/conn_callback.h b/include/socket_manager/conn_callback.h index f1a2598..0d2e67a 100644 --- a/include/socket_manager/conn_callback.h +++ b/include/socket_manager/conn_callback.h @@ -1,111 +1,107 @@ #ifndef SOCKET_MANAGER_CONN_CALLBACK_H #define SOCKET_MANAGER_CONN_CALLBACK_H +#include "connection.h" +#include "socket_manager_c_api.h" +#include +#include #include #include +#include #include #include -#include -#include -#include -#include "connection.h" -#include "socket_manager_c_api.h" namespace socket_manager { +/** + * The callback object for handling connection events. + *

+ * Throwing error in the callback will cause the runtime + * to abort. + * + *

Thread Safety

+ * All callback methods must be thread safe and non-blocking. + */ +class ConnCallback { +public: + virtual ~ConnCallback() = default; + +private: + friend void ::socket_manager_extern_on_conn( + struct SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates states, char **err); + /** - * The callback object for handling connection events. + * Called when a new connection is established. + * + *

Error handling

+ * Throwing error in `on_connect` callback will close the connection + * and a `on_connection_close` callback will be evoked. *

- * Throwing error in the callback will cause the runtime - * to abort. + * It should be non-blocking. + *

+ * Drop the returned `MsgSender` to close the connection. * - *

Thread Safety

- * All callback methods must be thread safe and non-blocking. + * @param local_addr the local address of the connection. + * @param peer_addr the peer address of the connection. + * @param conn a `Connection` object for starting the connection. + * @param sender a `Sender` object for sending data. */ - class ConnCallback { - public: - - virtual ~ConnCallback() = default; - - private: - - friend void::socket_manager_extern_on_conn( - struct SOCKET_MANAGER_C_API_OnConnObj this_, - SOCKET_MANAGER_C_API_ConnStates conn, - char **err); + virtual void on_connect(const std::string &local_addr, + const std::string &peer_addr, + std::shared_ptr conn, + std::shared_ptr sender) = 0; - /** - * Called when a new connection is established. - * - *

Error handling

- * Throwing error in `on_connect` callback will close the connection - * and a `on_connection_close` callback will be evoked. - *

- * It should be non-blocking. - *

- * Drop the returned `MsgSender` to close the connection. - * - * @param local_addr the local address of the connection. - * @param peer_addr the peer address of the connection. - * @param conn a `Connection` object for starting the connection. - * @param sender a `Sender` object for sending data. - */ - virtual void on_connect(const std::string &local_addr, - const std::string &peer_addr, - std::shared_ptr conn, - std::shared_ptr sender) = 0; - - /** - * Called when a connection is closed. - * - *

Error handling

- * Throwing error in `on_connection_close` callback is logged as error, - * but ignored. - *

- * It should be non-blocking. - * - * @param local_addr the local address of the connection. - * @param peer_addr the peer address of the connection. - */ - virtual void on_connection_close(const std::string &local_addr, - const std::string &peer_addr) = 0; - - /** - * Called when an error occurs when listening on the given address. - * - *

Error handling

- * Throwing error in `on_listen_error` callback is logged as error, - * but ignored. - *

- * Should be non-blocking. - * - * @param addr the address that failed to listen on. - * @param err the error message. - */ - virtual void on_listen_error(const std::string &addr, - const std::string &err) = 0; + /** + * Called when a connection is closed. + * + *

Error handling

+ * Throwing error in `on_connection_close` callback is logged as error, + * but ignored. + *

+ * It should be non-blocking. + * + * @param local_addr the local address of the connection. + * @param peer_addr the peer address of the connection. + */ + virtual void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) = 0; - /** - * Called when an error occurs when connecting to the given address. - * - *

Error handling

- * Throwing error in `on_connect_error` callback is logged as error, - * but ignored. - *

- * Should be non-blocking. - * - * @param addr the address that failed to connect to. - * @param err the error message. - */ - virtual void on_connect_error(const std::string &addr, - const std::string &err) = 0; + /** + * Called when an error occurs when listening on the given address. + * + *

Error handling

+ * Throwing error in `on_listen_error` callback is logged as error, + * but ignored. + *

+ * Should be non-blocking. + * + * @param addr the address that failed to listen on. + * @param err the error message. + */ + virtual void on_listen_error(const std::string &addr, + const std::string &err) = 0; - // keep the connection object alive before connection closed - // to ensure that message listener is alive during connection. - std::mutex lock; - std::unordered_map> conns; + /** + * Called when an error occurs when connecting to the given address. + * + *

Error handling

+ * Throwing error in `on_connect_error` callback is logged as error, + * but ignored. + *

+ * Should be non-blocking. + * + * @param addr the address that failed to connect to. + * @param err the error message. + */ + virtual void on_connect_error(const std::string &addr, + const std::string &err) = 0; - }; -} + // keep the connection object alive before connection closed + // to ensure that message listener is alive during connection. + std::mutex lock; + std::unordered_map> conns; +}; +} // namespace socket_manager -#endif //SOCKET_MANAGER_CONN_CALLBACK_H +#endif // SOCKET_MANAGER_CONN_CALLBACK_H diff --git a/include/socket_manager/connection.h b/include/socket_manager/connection.h index 04229fc..28da97e 100644 --- a/include/socket_manager/connection.h +++ b/include/socket_manager/connection.h @@ -1,110 +1,105 @@ #ifndef SOCKET_MANAGER_CONNECTION_H #define SOCKET_MANAGER_CONNECTION_H +#include "msg_receiver.h" +#include "socket_manager_c_api.h" #include +#include #include #include -#include -#include "msg_receiver.h" -#include "msg_sender.h" -#include "socket_manager_c_api.h" namespace socket_manager { - const unsigned long long DEFAULT_WRITE_FLUSH_MILLI_SEC = 5; // 5 millisecond - const unsigned long long DEFAULT_READ_MSG_FLUSH_MILLI_SEC = 5; // 5 millisecond - const size_t DEFAULT_MSG_BUF_SIZE = 64 * 1024; // 64KB +const unsigned long long DEFAULT_WRITE_FLUSH_MILLI_SEC = 1; // 1 millisecond +const unsigned long long DEFAULT_READ_MSG_FLUSH_MILLI_SEC = 1; // 1 millisecond +const size_t DEFAULT_MSG_BUF_SIZE = + static_cast(64) * 1024; // 64KB - class MsgSender; +class MsgSender; - class Notifier; +class Notifier; - class NoopNotifier; +/** + * Use Connection to send and receive messages from + * established connections. + */ +class Connection { +public: /** - * Use Connection to send and receive messages from - * established connections. + * Start a connection. + * + *

Start / Close

+ * Exactly one of `start` or `close` should be called! + * Calling more than once will throw runtime exception. + * Not calling any of them might result in resource leak. + * + *

Close started connection

+ * Drop the returned MsgSender object to close the connection + * after starting it. + * + *

Thread Safety

+ * Thread safe, but should be called exactly once, + * otherwise throws error. + * + * @param msg_receiver the message receiver callback to receive + * messages from the peer. Non-null. + * @param send_notifier the notifier for getting notified when the + * send buffer is ready. Pass nullptr to use a noop notifier. + * This parameter is needed only for async sending. + * @param msg_buffer_size The size of the message buffer in bytes. + * Set to 0 to use no buffer (i.e., call `on_msg` immediately on receiving + * any data, expecting the user to implement buffer if needed). + * The minimum is 8KB, and the maximum is 8MB. Default to 64KB. + * @param write_flush_interval The interval in `milliseconds` + * of write buffer auto flushing. Set to 0 to disable auto flush. + * Default to 1 millisecond. + * @param read_msg_flush_interval The interval in `milliseconds` of read + * message buffer auto flushing. The value is ignored when `msg_buffer_size` + * is 0. Set to 0 to disable auto flush (which is not recommended since there + * is no manual flush, and small messages might get stuck in buffer). Default + * to 1 millisecond. */ - class Connection { - - public: - - /** - * Start a connection. - * - *

Start / Close

- * Exactly one of `start` or `close` should be called! - * Calling more than once will throw runtime exception. - * Not calling any of them might result in resource leak. - * - *

Close started connection

- * Drop the returned MsgSender object to close the connection - * after starting it. - * - *

Thread Safety

- * Thread safe, but should be called exactly once, - * otherwise throws error. - * - * @param msg_receiver the message receiver callback to receive - * messages from the peer. Non-null. - * @param send_notifier the notifier for getting notified when the - * send buffer is ready. Pass nullptr to use a noop notifier. - * This parameter is needed only for async sending. - * @param msg_buffer_size The size of the message buffer in bytes. - * Set to 0 to use no buffer (i.e., call `on_msg` immediately on receiving - * any data, expecting the user to implement buffer if needed). - * The minimum is 8KB, and the maximum is 8MB. Default to 64KB. - * @param write_flush_interval The interval in `milliseconds` - * of write buffer auto flushing. Set to 0 to disable auto flush. - * Default to 1 millisecond. - * @param read_msg_flush_interval The interval in `milliseconds` of read message buffer - * auto flushing. The value is ignored when `msg_buffer_size` is 0. - * Set to 0 to disable auto flush (which is not recommended since there is no - * manual flush, and small messages might get stuck in buffer). - * Default to 1 millisecond. - */ - void start( - std::shared_ptr msg_receiver, - std::shared_ptr send_notifier = nullptr, - size_t msg_buffer_size = DEFAULT_MSG_BUF_SIZE, - unsigned long long read_msg_flush_interval = DEFAULT_READ_MSG_FLUSH_MILLI_SEC, - unsigned long long write_flush_interval = DEFAULT_WRITE_FLUSH_MILLI_SEC); - - /** - * Close the connection without using it. - *

- * `on_connection_close` callback will be called. - * - *

Start / Close

- * Exactly one of `start` or `close` should be called! - * Calling more than once will throw runtime exception. - * Not calling any of them might result in resource leak. - */ - void close(); - - private: - - friend class MsgSender; - - friend void::socket_manager_extern_on_conn( - struct SOCKET_MANAGER_C_API_OnConnObj this_, - SOCKET_MANAGER_C_API_ConnStates conn, - char **err); - - // keep the msg_receiver alive - std::shared_ptr receiver; - - // keep the notifier alive - std::shared_ptr notifier; - - explicit Connection(SOCKET_MANAGER_C_API_Connection *inner); - - std::unique_ptr< - SOCKET_MANAGER_C_API_Connection, - std::function> inner; - - }; + void start( + std::shared_ptr msg_receiver, + std::shared_ptr send_notifier = nullptr, + size_t msg_buffer_size = DEFAULT_MSG_BUF_SIZE, + unsigned long long read_msg_flush_interval = + DEFAULT_READ_MSG_FLUSH_MILLI_SEC, + unsigned long long write_flush_interval = DEFAULT_WRITE_FLUSH_MILLI_SEC); + + /** + * Close the connection without using it. + *

+ * `on_connection_close` callback will be called. + * + *

Start / Close

+ * Exactly one of `start` or `close` should be called! + * Calling more than once will throw runtime exception. + * Not calling any of them might result in resource leak. + */ + void close(); + +private: + friend class MsgSender; + + friend void ::socket_manager_extern_on_conn( + struct SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates states, char **err); + + // keep the msg_receiver alive + std::shared_ptr receiver; + + // keep the notifier alive + std::shared_ptr notifier; + + explicit Connection(SOCKET_MANAGER_C_API_Connection *inner); + + std::unique_ptr> + inner; +}; } // namespace socket_manager -#endif //SOCKET_MANAGER_CONNECTION_H +#endif // SOCKET_MANAGER_CONNECTION_H diff --git a/include/socket_manager/msg_receiver.h b/include/socket_manager/msg_receiver.h index d5c02cc..b23cc0f 100644 --- a/include/socket_manager/msg_receiver.h +++ b/include/socket_manager/msg_receiver.h @@ -1,20 +1,45 @@ #ifndef SOCKET_MANAGER_MSG_RECEIVER_H #define SOCKET_MANAGER_MSG_RECEIVER_H -#include -#include -#include -#include -#include #include "socket_manager_c_api.h" #include "waker.h" +#include +#include +#include +#include +#include namespace socket_manager { +/** + * Implement this class to receive messages from Connection. + *

+ * Must read the following details to implement correctly! + * + *

Asynchronous Message Receiving

+ * The caller should return the exact number of bytes written + * to the runtime if some bytes are written. The runtime + * will increment the read offset accordingly. + *

+ * If the caller is unable to receive any bytes, + * it should return `PENDING = -1` to the runtime + * to interrupt message receiving task. The read offset + * will not be incremented. + *

+ * When the caller is able to receive bytes again, + * it should call `waker.wake()` to wake up the runtime. + * + *

Non-blocking

+ * This callback must be non-blocking. + */ +class MsgReceiverAsync { + +public: + virtual ~MsgReceiverAsync() = default; + +private: /** - * Implement this class to receive messages from Connection. - *

- * Must read the following details to implement correctly! + * Called when a message is received. * *

Asynchronous Message Receiving

* The caller should return the exact number of bytes written @@ -29,88 +54,57 @@ namespace socket_manager { * When the caller is able to receive bytes again, * it should call `waker.wake()` to wake up the runtime. * + *

MEMORY SAFETY

+ * The `data` is only valid during the call of this function. + * If you want to keep the data, you should copy it. + * *

Non-blocking

* This callback must be non-blocking. + * + *

Error Handling

+ * Throwing runtime_error in `on_message` callback will cause + * the connection to close. + * + * @param data the message received. */ - class MsgReceiverAsync { - - public: - - virtual ~MsgReceiverAsync() = default; - - private: - - /** - * Called when a message is received. - * - *

Asynchronous Message Receiving

- * The caller should return the exact number of bytes written - * to the runtime if some bytes are written. The runtime - * will increment the read offset accordingly. - *

- * If the caller is unable to receive any bytes, - * it should return `PENDING = -1` to the runtime - * to interrupt message receiving task. The read offset - * will not be incremented. - *

- * When the caller is able to receive bytes again, - * it should call `waker.wake()` to wake up the runtime. - * - *

MEMORY SAFETY

- * The `data` is only valid during the call of this function. - * If you want to keep the data, you should copy it. - * - *

Non-blocking

- * This callback must be non-blocking. - * - *

Error Handling

- * Throwing runtime_error in `on_message` callback will cause - * the connection to close. - * - * @param data the message received. - */ - virtual long on_message_async(std::string_view data, Waker &&waker) = 0; - - friend long::socket_manager_extern_on_msg( - struct SOCKET_MANAGER_C_API_OnMsgObj this_, - SOCKET_MANAGER_C_API_ConnMsg msg, - SOCKET_MANAGER_C_API_CWaker waker, - char **err); - - }; - + virtual long on_message_async(std::string_view data, Waker &&waker) = 0; + + friend long ::socket_manager_extern_on_msg( + struct SOCKET_MANAGER_C_API_OnMsgObj this_, + SOCKET_MANAGER_C_API_ConnMsg msg, SOCKET_MANAGER_C_API_CWaker waker, + char **err); +}; + +/** + * If the caller has unlimited buffer implementation, + * it can use this simplified class to receive messages. + *

+ * The caller should implement `on_message` method to + * store the received message in buffer or queue and immediately + * return `on_message` method, and should not block the runtime. + * + *

Non-blocking

+ * This callback must be non-blocking. + */ +class MsgReceiver : public MsgReceiverAsync { +public: + ~MsgReceiver() override = default; + +private: /** - * If the caller has unlimited buffer implementation, - * it can use this simplified class to receive messages. + * Compared to `on_message_async`, this method assumes that + * all data is received by the caller, and the caller does + * not need to report number of bytes written to the runtime. + * Nor can the caller interrupt the runtime. *

- * The caller should implement `on_message` method to - * store the received message in buffer or queue and immediately - * return `on_message` method, and should not block the runtime. - * - *

Non-blocking

- * This callback must be non-blocking. + * Notice that this callback still needs to be non-blocking. + * @param data the message received. */ - class MsgReceiver : public MsgReceiverAsync { - public: - - ~MsgReceiver() override = default; - - private: - - /** - * Compared to `on_message_async`, this method assumes that - * all data is received by the caller, and the caller does - * not need to report number of bytes written to the runtime. - * Nor can the caller interrupt the runtime. - *

- * Notice that this callback still needs to be non-blocking. - * @param data the message received. - */ - virtual void on_message(std::string_view data) = 0; + virtual void on_message(std::string_view data) = 0; - long on_message_async(std::string_view data, Waker &&waker) final; - }; + long on_message_async(std::string_view data, Waker &&waker) final; +}; } // namespace socket_manager -#endif //SOCKET_MANAGER_MSG_RECEIVER_H +#endif // SOCKET_MANAGER_MSG_RECEIVER_H diff --git a/include/socket_manager/msg_sender.h b/include/socket_manager/msg_sender.h index 7702b95..0f719a8 100644 --- a/include/socket_manager/msg_sender.h +++ b/include/socket_manager/msg_sender.h @@ -1,113 +1,109 @@ #ifndef SOCKET_MANAGER_MSG_SENDER_H #define SOCKET_MANAGER_MSG_SENDER_H -#include "socket_manager_c_api.h" #include "connection.h" #include "notifier.h" +#include "socket_manager_c_api.h" #include -#include #include -#include +#include namespace socket_manager { - class Connection; +/** + * Use MsgSender to send messages to the peer. + *

+ * Drop the MsgSender object to close the connection. + */ +class MsgSender { +public: /** - * Use MsgSender to send messages to the peer. + * Asynchronous message sending. *

- * Drop the MsgSender object to close the connection. + * To use the method, the user should pass a `notifier` object + * when calling `Connection::start()` in order to receive notification + * when the send buffer is ready. + * + *

Async control flow (IMPORTANT)

+ * This function is non-blocking, it returns `PENDING = -1` + * if the send buffer is full. So the caller should wait + * by passing a `Notifier` which will be called when the + * buffer is ready. + *

+ * When the buffer is ready, the function returns number of bytes sent. + *

+ * The caller is responsible for updating the buffer offset!! + * + * @param data the message to send + * @param notifier `notifier.wake()` is evoked when send_async + * could accept more data. + * @return return the number of bytes successfully sent, + * or return `PENDING = -1` if the send buffer is full. + * @throws std::runtime_error when the connection is closed. */ - class MsgSender { - - public: - /** - * Asynchronous message sending. - *

- * To use the method, the user should pass a `notifier` object - * when calling `Connection::start()` in order to receive notification - * when the send buffer is ready. - * - *

Async control flow (IMPORTANT)

- * This function is non-blocking, it returns `PENDING = -1` - * if the send buffer is full. So the caller should wait - * by passing a `Notifier` which will be called when the - * buffer is ready. - *

- * When the buffer is ready, the function returns number of bytes sent. - *

- * The caller is responsible for updating the buffer offset!! - * - * @param data the message to send - * @param notifier `notifier.wake()` is evoked when send_async - * could accept more data. - * @return return the number of bytes successfully sent, - * or return `PENDING = -1` if the send buffer is full. - * @throws std::runtime_error when the connection is closed. - */ - long send_async(std::string_view data); - - /** - * Non-blocking message sending (NO BACKPRESSURE). - *

- * This method is non-blocking. It returns immediately and - * caches all the data in the internal buffer. So it comes - * without back pressure. - * - * @param data the message to send - * @param notifier `notifier.wake()` is evoked when send_async - * could accept more data. - * @throws std::runtime_error when the connection is closed. - */ - void send_nonblock(std::string_view data); + long send_async(std::string_view data); - /** - * Blocking message sending (DO NOT USE IN ASYNC CALLBACK). - * - *

Blocking!!

- * This method might block, so it should never be used within any of - * the async callbacks. Otherwise the runtime can panic! - * - *

Thread Safety

- * This method is thread safe. - * - *

Errors

- * This method throws std::runtime_error when - * the connection is closed. - * - * @param data the message to send - * @throws std::runtime_error when the connection is closed. - */ - void send_block(std::string_view data); - - /** - * Manually flush the internal buffer. - * - *

Thread Safety

- * This method is thread safe. - * - */ - void flush(); + /** + * Non-blocking message sending (NO BACKPRESSURE). + *

+ * This method is non-blocking. It returns immediately and + * caches all the data in the internal buffer. So it comes + * without back pressure. + * + * @param data the message to send + * @param notifier `notifier.wake()` is evoked when send_async + * could accept more data. + * @throws std::runtime_error when the connection is closed. + */ + void send_nonblock(std::string_view data); - private: + /** + * Blocking message sending (DO NOT USE IN ASYNC CALLBACK). + * + *

Blocking!!

+ * This method might block, so it should never be used within any of + * the async callbacks. Otherwise the runtime can panic! + * + *

Thread Safety

+ * This method is thread safe. + * + *

Errors

+ * This method throws std::runtime_error when + * the connection is closed. + * + * @param data the message to send + * @throws std::runtime_error when the connection is closed. + */ + void send_block(std::string_view data); - friend class Connection; + /** + * Manually flush the internal buffer. + * + *

Thread Safety

+ * This method is thread safe. + * + */ + void flush(); - friend void::socket_manager_extern_on_conn( - struct SOCKET_MANAGER_C_API_OnConnObj this_, - SOCKET_MANAGER_C_API_ConnStates conn, - char **err); +private: + friend class Connection; - explicit MsgSender(SOCKET_MANAGER_C_API_MsgSender *inner, const std::shared_ptr &); + friend void ::socket_manager_extern_on_conn( + struct SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates states, char **err); - // keep a reference of connection for storing waker object - // in connection, to prevent dangling pointer of waker. - std::shared_ptr conn; + explicit MsgSender(SOCKET_MANAGER_C_API_MsgSender *inner, + const std::shared_ptr &); - std::unique_ptr> inner; + // keep a reference of connection for storing waker object + // in connection, to prevent dangling pointer of waker. + std::shared_ptr conn; - }; + std::unique_ptr> + inner; +}; } // namespace socket_manager diff --git a/include/socket_manager/socket_manager.h b/include/socket_manager/socket_manager.h index bf2a887..4ec7c11 100644 --- a/include/socket_manager/socket_manager.h +++ b/include/socket_manager/socket_manager.h @@ -2,154 +2,189 @@ #define SOCKET_MANAGER_H #include "conn_callback.h" +#include "msg_sender.h" #include "socket_manager_c_api.h" #include -#include #include +#include +#include namespace socket_manager { +/** + * @brief The log data structure. + * + * This structure is used to pass log data from C to C++. + */ +struct LogData { + // log level + SOCKET_MANAGER_C_API_TraceLevel level; + // module name + std::string_view target; + // file name (empty if not available) + std::string_view file; + // code line (-1 if not available) + int line; + // log message + std::string_view message; +}; + +/** + * Helper function to convert from C log data to C++ log data. + */ +LogData from_c_log_data(SOCKET_MANAGER_C_API_LogData log_data); + +/** + * @brief Initialize the logger for the socket manager. + * + * This function cannot be called more than once, + * otherwise it will throw an exception.. + * + * Tracer must be thread safe (as most loggers are thread safe). + */ +void init_logger(void (*tracer)(SOCKET_MANAGER_C_API_LogData), + SOCKET_MANAGER_C_API_TraceLevel tracer_max_level, + SOCKET_MANAGER_C_API_TraceLevel log_print_level); + +/** + * @brief Manages a set of sockets. + * + *

Usage

+ * The user needs to implement a `ConnCallback` object to handle + * connection events, and pass it to the constructor of `SocketManager`. + *

+ * When the connection is established, the `on_connect` callback + * returns a `MsgSender` object for sending messages to the peer, + * and a `Connection` object for receiving messages from the peer. + *

+ * To receive messages from the peer, the user needs to implement + * a `MsgReceiver` object and pass it to the `Connection::start` + * method. + * + *

Memory Management

+ * The system internally use shared pointers to manage the lifetime + * of the objects. Note the following dependency relationship to avoid + * memory leak (i.e., circular reference): + *
    + *
  • `SocketManager` ---strong ref--> `ConnCallback`
  • + *
  • `ConnCallback` ---strong ref--> (active) `Connection`s (drop on + * `connection_close`)
  • `Connection` ---strong ref--> `Notifier`
  • + *
  • `Connection` ---strong ref--> `(Async)MsgReceiver`
  • + *
  • `MsgSender` ---strong ref--> `Connection`
  • + *
+ * Notice that if `MsgSender` is strongly referenced by `Notifier`, + * or strongly referenced by `(Async)MsgReceiver`, then the connection + * will have a memory leak. The user could `reset()` the + * `shared_ptr\` on connection close event, and thus break the + * cycle.

In short, the user must guarantee that the `MsgSender` + * object is released for connection resources to be properly released. + * + *

Note on lifetime:

+ * + *
    + *
  • The `connection callback` object should have + * a longer lifetime than the socket manager.
  • + * + *
  • The `msg receiver` should live as long as + * connection is not closed.
  • + * + *
  • The `Notifier` object should live as long as + * connection is not closed.
  • + *
+ */ +class SocketManager { + +public: /** - * @brief Manages a set of sockets. + * Create a socket manager with n threads. * - *

Usage

- * The user needs to implement a `ConnCallback` object to handle - * connection events, and pass it to the constructor of `SocketManager`. - *

- * When the connection is established, the `on_connect` callback - * returns a `MsgSender` object for sending messages to the peer, - * and a `Connection` object for receiving messages from the peer. - *

- * To receive messages from the peer, the user needs to implement - * a `MsgReceiver` object and pass it to the `Connection::start` - * method. + * @param n_threads the number of threads to use. If n_threads is 0, then + * the number of threads is equal to the number of cores. + * Default to single-threaded runtime. + */ + explicit SocketManager(const std::shared_ptr &conn_cb, + size_t n_threads = 1); + + /** + * Listen on the given address. * - *

Memory Management

- * The system internally use shared pointers to manage the lifetime - * of the objects. Note the following dependency relationship to avoid - * memory leak (i.e., circular reference): - *
    - *
  • `SocketManager` ---strong ref--> `ConnCallback`
  • - *
  • `ConnCallback` ---strong ref--> (active) `Connection`s (drop on `connection_close`)
  • - *
  • `Connection` ---strong ref--> `Notifier`
  • - *
  • `Connection` ---strong ref--> `(Async)MsgReceiver`
  • - *
  • `MsgSender` ---strong ref--> `Connection`
  • - *
- * Notice that if `MsgSender` is strongly referenced by `Notifier`, - * or strongly referenced by `(Async)MsgReceiver`, then the connection - * will have a memory leak. The user could `reset()` the `shared_ptr\` - * on connection close event, and thus break the cycle. - *

- * In short, the user must guarantee that the `MsgSender` object - * is released for connection resources to be properly released. + *

Thread Safety

+ * Thread safe. * - *

Note on lifetime:

+ *

Errors

+ * Throws `std::runtime_error` if socket manager runtime has been aborted. + * Throws `std::runtime_error` if the address is invalid. * - *
    - *
  • The `connection callback` object should have - * a longer lifetime than the socket manager.
  • + * @param addr: the ip address to listen to (support both ipv4 and ipv6). + */ + void listen_on_addr(const std::string &addr); + + /** + * Connect to the given address. * - *
  • The `msg receiver` should live as long as - * connection is not closed.
  • + *

    Thread Safety

    + * Thread safe. * - *
  • The `Notifier` object should live as long as - * connection is not closed.
  • - *
+ *

Errors

+ * Throws `std::runtime_error` if socket manager runtime has been aborted. + * Throws `std::runtime_error` if the address is invalid. + * + * @param addr: the ip address to listen to (support both ipv4 and ipv6). + * @param delay: the delay in milliseconds before connecting to the address. */ - class SocketManager { - - public: - - /** - * Create a socket manager with n threads. - * - * @param n_threads the number of threads to use. If n_threads is 0, then - * the number of threads is equal to the number of cores. - * Default to single-threaded runtime. - */ - explicit SocketManager(const std::shared_ptr &conn_cb, size_t n_threads = 1); - - /** - * Listen on the given address. - * - *

Thread Safety

- * Thread safe. - * - *

Errors

- * Throws `std::runtime_error` if socket manager runtime has been aborted. - * Throws `std::runtime_error` if the address is invalid. - * - * @param addr: the ip address to listen to (support both ipv4 and ipv6). - */ - void listen_on_addr(const std::string &addr); + void connect_to_addr(const std::string &addr, uint64_t delay = 0); - /** - * Connect to the given address. - * - *

Thread Safety

- * Thread safe. - * - *

Errors

- * Throws `std::runtime_error` if socket manager runtime has been aborted. - * Throws `std::runtime_error` if the address is invalid. - * - * @param addr: the ip address to listen to (support both ipv4 and ipv6). - * @param delay: the delay in milliseconds before connecting to the address. - */ - void connect_to_addr(const std::string &addr, uint64_t delay = 0); - - /** - * Cancel listening on the given address. - * - *

Thread Safety

- * Thread safe. - * - *

Errors

- * Throws `std::runtime_error` if socket manager runtime has been aborted. - * Throw `std::runtime_error` if the address is invalid. - * - * @param addr cancel listening on this address. - */ - void cancel_listen_on_addr(const std::string &addr); - - /** - * Stop all background threads and drop all connections. - *

- * Calling a second time will return immediately (if `wait = false`). - * - *

Argument

- * - `wait`: if true, wait for all the background threads to finish. - * Default to true. - * - *

Thread Safety

- * Thread safe. - * - *

Errors

- * Throws `std::runtime_error` if `wait = true` and the background - * thread panicked. - */ - void abort(bool wait = true); - - /** - * Join and wait on the `SocketManager` background runtime. - *

- * Returns immediately on the second call. - * - *

Thread Safety

- * Thread safe. - * - *

Errors

- * Throws `std::runtime_error` if the background runtime panicked. - */ - void join(); + /** + * Cancel listening on the given address. + * + *

Thread Safety

+ * Thread safe. + * + *

Errors

+ * Throws `std::runtime_error` if socket manager runtime has been aborted. + * Throw `std::runtime_error` if the address is invalid. + * + * @param addr cancel listening on this address. + */ + void cancel_listen_on_addr(const std::string &addr); - private: + /** + * Stop all background threads and drop all connections. + *

+ * Calling a second time will return immediately (if `wait = false`). + * + *

Argument

+ * - `wait`: if true, wait for all the background threads to finish. + * Default to true. + * + *

Thread Safety

+ * Thread safe. + * + *

Errors

+ * Throws `std::runtime_error` if `wait = true` and the background + * thread panicked. + */ + void abort(bool wait = true); - std::unique_ptr> inner; - std::shared_ptr conn_cb; + /** + * Join and wait on the `SocketManager` background runtime. + *

+ * Returns immediately on the second call. + * + *

Thread Safety

+ * Thread safe. + * + *

Errors

+ * Throws `std::runtime_error` if the background runtime panicked. + */ + void join(); - }; +private: + std::unique_ptr> + inner; + std::shared_ptr conn_cb; +}; } // namespace socket_manager diff --git a/include/socket_manager_c_api.h b/include/socket_manager_c_api.h index bbb46eb..63b1d4f 100644 --- a/include/socket_manager_c_api.h +++ b/include/socket_manager_c_api.h @@ -15,6 +15,48 @@ enum class SOCKET_MANAGER_C_API_ConnStateCode { ConnectError = 3, }; +/** + * Trace Level + */ +enum class SOCKET_MANAGER_C_API_TraceLevel { + /** + * The "trace" level. + * + * Designates very low priority, often extremely verbose, information. + */ + Trace = 0, + /** + * The "debug" level. + * + * Designates lower priority information. + */ + Debug = 1, + /** + * The "info" level. + * + * Designates useful information. + */ + Info = 2, + /** + * The "warn" level. + * + * Designates hazardous situations. + */ + Warn = 3, + /** + * The "error" level. + * + * Designates very serious errors. + */ + Error = 4, + /** + * Turn off all levels. + * + * Disable log output. + */ + Off = 5, +}; + struct SOCKET_MANAGER_C_API_Connection; /** @@ -158,6 +200,26 @@ struct SOCKET_MANAGER_C_API_ConnMsg { size_t Len; }; +/** + * Log Data + */ +struct SOCKET_MANAGER_C_API_LogData { + SOCKET_MANAGER_C_API_TraceLevel Level; + const char *Target; + size_t TargetN; + const char *File; + size_t FileN; + /** + * -1 if not available + */ + int Line; + /** + * The `message` pointer is only valid for the duration of the callback. + */ + const char *Message; + size_t MessageN; +}; + extern "C" { /** @@ -329,7 +391,7 @@ void socket_manager_msg_sender_free(SOCKET_MANAGER_C_API_MsgSender *sender); * Set `err` to null_ptr if there is no error. */ extern void socket_manager_extern_on_conn(SOCKET_MANAGER_C_API_OnConnObj this_, - SOCKET_MANAGER_C_API_ConnStates conn, + SOCKET_MANAGER_C_API_ConnStates states, char **err); /** @@ -463,6 +525,19 @@ int socket_manager_join(SOCKET_MANAGER_C_API_SocketManager *manager, char **err) */ void socket_manager_free(SOCKET_MANAGER_C_API_SocketManager *manager); +/** + * Init logger. + * + * # Arguments + * - `tracer`: The tracer object. + * - `tracer_max_level`: The max level of the tracer. + * - `log_print_level`: The level of the log to print. + */ +void socket_manager_logger_init(void (*tracer)(SOCKET_MANAGER_C_API_LogData), + SOCKET_MANAGER_C_API_TraceLevel tracer_max_level, + SOCKET_MANAGER_C_API_TraceLevel log_print_level, + char **err); + } // extern "C" #endif // SOCKET_MANAGER_C_API_H diff --git a/justfile b/justfile index 52c5c3a..559ac6c 100644 --- a/justfile +++ b/justfile @@ -15,8 +15,7 @@ dev-docker: docker build -f ./dockerfile/dev-containers/jammy/Dockerfile -t congyuwang/socket-manager-dev:jammy . test: - cd build && SOCKET_LOG=debug ctest --output-on-failure && cd .. - + cd build && ctest --output-on-failure && cd .. time: /usr/bin/time -l -h -p ./build/tests/CommonCxxTests test_transfer_data_large /usr/bin/time -l -h -p ./build/tests/CommonCxxTests test_transfer_data_large_async @@ -29,21 +28,25 @@ test-linking: cd ../.. debug: - cmake -B build -DCMAKE_BUILD_TYPE=Debug - cmake --build build --config Debug + cmake -B build -DCMAKE_BUILD_TYPE=Debug \ + -DCMAKE_TOOLCHAIN_FILE=toolchain.cmake \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON + cmake --build build --parallel 4 --config Debug just test build: cmake -B build -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_TOOLCHAIN_FILE=toolchain.cmake \ - -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON + -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON cmake --build build --parallel 4 --config Release --verbose just test build-static: cmake -B build -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_TOOLCHAIN_FILE=toolchain.cmake \ - -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=true \ + -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DBUILD_SHARED_LIBS=OFF cmake --build build --parallel 4 --config Release --verbose just test diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..67d9a53 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly" +profile = "minimal" diff --git a/socket_manager/connection.cc b/socket_manager/connection.cc index 6ffdc85..8b65d12 100644 --- a/socket_manager/connection.cc +++ b/socket_manager/connection.cc @@ -1,51 +1,53 @@ #include "socket_manager/connection.h" - +#include "socket_manager/msg_sender.h" namespace socket_manager { - Connection::Connection(SOCKET_MANAGER_C_API_Connection *inner) - : notifier(std::make_shared()), - inner(inner, - [](SOCKET_MANAGER_C_API_Connection *ptr) { - socket_manager_connection_free(ptr); - }) {} - - void Connection::start( - std::shared_ptr msg_receiver, - std::shared_ptr send_notifier, - size_t msg_buffer_size, - unsigned long long read_msg_flush_interval, - unsigned long long write_flush_interval) { +Connection::Connection(SOCKET_MANAGER_C_API_Connection *inner) + : notifier(std::make_shared()), + inner(inner, [](SOCKET_MANAGER_C_API_Connection *ptr) { + socket_manager_connection_free(ptr); + }) {} - if (msg_receiver == nullptr) { - throw std::runtime_error("msg_receiver should not be nullptr"); - } - // keep the msg_receiver alive. - this->receiver = std::move(msg_receiver); - // keep the notifier alive. - if (send_notifier != nullptr) { - this->notifier = std::move(send_notifier); - } +void Connection::start(std::shared_ptr msg_receiver, + std::shared_ptr send_notifier, + size_t msg_buffer_size, + unsigned long long read_msg_flush_interval, + unsigned long long write_flush_interval) { - // start the connection. - // calling twice `connection_start` will throw exception. - char *err = nullptr; - if (socket_manager_connection_start(inner.get(), SOCKET_MANAGER_C_API_OnMsgObj{ - this->receiver.get(), - }, msg_buffer_size, read_msg_flush_interval, write_flush_interval, &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } + if (msg_receiver == nullptr) { + throw std::runtime_error("msg_receiver should not be nullptr"); + } + // keep the msg_receiver alive. + this->receiver = std::move(msg_receiver); + // keep the notifier alive. + if (send_notifier != nullptr) { + this->notifier = std::move(send_notifier); } - void Connection::close() { - char *err = nullptr; - if (socket_manager_connection_close(inner.get(), &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } + // start the connection. + // calling twice `connection_start` will throw exception. + char *err = nullptr; + if (0 != socket_manager_connection_start(inner.get(), + SOCKET_MANAGER_C_API_OnMsgObj{ + this->receiver.get(), + }, + msg_buffer_size, + read_msg_flush_interval, + write_flush_interval, &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); + } +} + +void Connection::close() { + char *err = nullptr; + if (0 != socket_manager_connection_close(inner.get(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} } // namespace socket_manager diff --git a/socket_manager/msg_receiver.cc b/socket_manager/msg_receiver.cc index bcd2434..7967451 100644 --- a/socket_manager/msg_receiver.cc +++ b/socket_manager/msg_receiver.cc @@ -1,8 +1,9 @@ #include "socket_manager/msg_receiver.h" namespace socket_manager { - long MsgReceiver::on_message_async(std::string_view data, Waker &&waker) { - on_message(data); - return (long) data.length(); - } +long MsgReceiver::on_message_async(std::string_view data, Waker &&waker) { + on_message(data); + Waker const drop_waker(std::move(waker)); + return (long)data.length(); } +} // namespace socket_manager diff --git a/socket_manager/msg_sender.cc b/socket_manager/msg_sender.cc index 3abf5da..71f1085 100644 --- a/socket_manager/msg_sender.cc +++ b/socket_manager/msg_sender.cc @@ -3,56 +3,52 @@ namespace socket_manager { - void MsgSender::send_block(std::string_view data) { - char *err = nullptr; - if (socket_manager_msg_sender_send_block( - inner.get(), data.data(), data.length(), &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void MsgSender::send_block(std::string_view data) { + char *err = nullptr; + if (0 != socket_manager_msg_sender_send_block(inner.get(), data.data(), + data.length(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - void MsgSender::send_nonblock(std::string_view data) { - char *err = nullptr; - if (socket_manager_msg_sender_send_nonblock( - inner.get(), data.data(), data.length(), &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void MsgSender::send_nonblock(std::string_view data) { + char *err = nullptr; + if (0 != socket_manager_msg_sender_send_nonblock(inner.get(), data.data(), + data.length(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - long MsgSender::send_async(std::string_view data) { - char *err = nullptr; - long n = socket_manager_msg_sender_send_async( - inner.get(), - data.data(), - data.length(), - SOCKET_MANAGER_C_API_Notifier{conn->notifier.get()}, - &err); - if (err) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } - return n; +long MsgSender::send_async(std::string_view data) { + char *err = nullptr; + long const bytes_sent = socket_manager_msg_sender_send_async( + inner.get(), data.data(), data.length(), + SOCKET_MANAGER_C_API_Notifier{conn->notifier.get()}, &err); + if (err != nullptr) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } + return bytes_sent; +} - void MsgSender::flush() { - char *err = nullptr; - if (socket_manager_msg_sender_flush(inner.get(), &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void MsgSender::flush() { + char *err = nullptr; + if (0 != socket_manager_msg_sender_flush(inner.get(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - MsgSender::MsgSender(SOCKET_MANAGER_C_API_MsgSender *inner, const std::shared_ptr &conn) - : conn(conn), - inner(inner, - [](SOCKET_MANAGER_C_API_MsgSender *ptr) { - socket_manager_msg_sender_free(ptr); - }) {} +MsgSender::MsgSender(SOCKET_MANAGER_C_API_MsgSender *inner, + const std::shared_ptr &conn) + : conn(conn), inner(inner, [](SOCKET_MANAGER_C_API_MsgSender *ptr) { + socket_manager_msg_sender_free(ptr); + }) {} } // namespace socket_manager diff --git a/socket_manager/socket_manager.cc b/socket_manager/socket_manager.cc index ede0c98..16db18f 100644 --- a/socket_manager/socket_manager.cc +++ b/socket_manager/socket_manager.cc @@ -1,68 +1,95 @@ #include "socket_manager/socket_manager.h" +#include "socket_manager_c_api.h" +#include namespace socket_manager { - SocketManager::SocketManager(const std::shared_ptr &conn_cb, size_t n_threads) - : conn_cb(conn_cb) { - char *err = nullptr; - auto inner_ptr = socket_manager_init(SOCKET_MANAGER_C_API_OnConnObj{ - conn_cb.get() - }, n_threads, &err); - inner = std::unique_ptr>( - inner_ptr, - [](SOCKET_MANAGER_C_API_SocketManager *ptr) { socket_manager_free(ptr); } - ); - if (err) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +LogData from_c_log_data(SOCKET_MANAGER_C_API_LogData log_data) { + return { + log_data.Level, + std::string_view(log_data.Target, log_data.TargetN), + std::string_view(log_data.File, log_data.FileN), + log_data.Line, + std::string_view(log_data.Message, log_data.MessageN), + }; +} + +void init_logger(void (*tracer)(SOCKET_MANAGER_C_API_LogData), + SOCKET_MANAGER_C_API_TraceLevel tracer_max_level, + SOCKET_MANAGER_C_API_TraceLevel log_print_level) { + char *err = nullptr; + socket_manager_logger_init(tracer, tracer_max_level, log_print_level, &err); + if (err != nullptr) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); + } +} + +SocketManager::SocketManager(const std::shared_ptr &conn_cb, + size_t n_threads) + : conn_cb(conn_cb) { + char *err = nullptr; + auto *inner_ptr = socket_manager_init( + SOCKET_MANAGER_C_API_OnConnObj{conn_cb.get()}, n_threads, &err); + inner = std::unique_ptr< + SOCKET_MANAGER_C_API_SocketManager, + std::function>( + inner_ptr, [](SOCKET_MANAGER_C_API_SocketManager *ptr) { + socket_manager_free(ptr); + }); + if (err != nullptr) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - void SocketManager::listen_on_addr(const std::string &addr) { - char *err = nullptr; - if (socket_manager_listen_on_addr(inner.get(), addr.c_str(), &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void SocketManager::listen_on_addr(const std::string &addr) { + char *err = nullptr; + if (0 != socket_manager_listen_on_addr(inner.get(), addr.c_str(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - void SocketManager::connect_to_addr(const std::string &addr, uint64_t delay) { - char *err = nullptr; - if (socket_manager_connect_to_addr(inner.get(), addr.c_str(), delay, &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void SocketManager::connect_to_addr(const std::string &addr, uint64_t delay) { + char *err = nullptr; + if (0 != + socket_manager_connect_to_addr(inner.get(), addr.c_str(), delay, &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - void SocketManager::cancel_listen_on_addr(const std::string &addr) { - char *err = nullptr; - if (socket_manager_cancel_listen_on_addr(inner.get(), addr.c_str(), &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void SocketManager::cancel_listen_on_addr(const std::string &addr) { + char *err = nullptr; + if (0 != + socket_manager_cancel_listen_on_addr(inner.get(), addr.c_str(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - void SocketManager::abort(bool wait) { - char *err = nullptr; - if (socket_manager_abort(inner.get(), wait, &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void SocketManager::abort(bool wait) { + char *err = nullptr; + if (0 != socket_manager_abort(inner.get(), wait, &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} - void SocketManager::join() { - char *err = nullptr; - if (socket_manager_join(inner.get(), &err)) { - const std::string err_str(err); - free(err); - throw std::runtime_error(err_str); - } +void SocketManager::join() { + char *err = nullptr; + if (0 != socket_manager_join(inner.get(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } +} } // namespace socket_manager diff --git a/socket_manager/socket_manager_c_api.cc b/socket_manager/socket_manager_c_api.cc index 441f611..4efb810 100644 --- a/socket_manager/socket_manager_c_api.cc +++ b/socket_manager/socket_manager_c_api.cc @@ -1,126 +1,107 @@ #include "socket_manager_c_api.h" -#include "socket_manager/msg_receiver.h" -#include "socket_manager/conn_callback.h" #include "socket_manager/common/waker.h" +#include "socket_manager/conn_callback.h" +#include "socket_manager/msg_receiver.h" +#include "socket_manager/msg_sender.h" -inline -static char *string_dup(const std::string &str) { +inline char *string_dup(const std::string &str) { auto size = str.size(); - char *buffer = (char *) malloc(size + 1); + char *buffer = static_cast(malloc(size + 1)); memcpy(buffer, str.c_str(), size + 1); return buffer; } +#define SOCKET_MANAGER_CATCH_ERROR(err, expr) \ + try { \ + *(err) = nullptr; \ + expr; \ + } catch (std::runtime_error & e) { \ + *(err) = string_dup(e.what()); \ + } catch (...) { \ + *(err) = string_dup("unknown error"); \ + } + /** * RecvWaker for the sender. */ -extern "C" void socket_manager_extern_notifier_wake(SOCKET_MANAGER_C_API_Notifier this_) { - auto wr = reinterpret_cast(this_.This); - wr->wake(); +extern "C" void +socket_manager_extern_notifier_wake(SOCKET_MANAGER_C_API_Notifier this_) { + auto *notifier = reinterpret_cast(this_.This); + notifier->wake(); } -extern "C" long socket_manager_extern_on_msg(SOCKET_MANAGER_C_API_OnMsgObj this_, - SOCKET_MANAGER_C_API_ConnMsg msg, - SOCKET_MANAGER_C_API_CWaker waker, - char **err) { - auto receiver = reinterpret_cast(this_.This); - try { - auto recv = receiver->on_message_async( - std::string_view(msg.Bytes, msg.Len), - socket_manager::Waker(waker) - ); - *err = nullptr; - return recv; - } catch (std::runtime_error &e) { - *err = string_dup(e.what()); - return 0; - } catch (...) { - *err = string_dup("unknown error"); - return 0; - } +extern "C" long +socket_manager_extern_on_msg(SOCKET_MANAGER_C_API_OnMsgObj this_, + SOCKET_MANAGER_C_API_ConnMsg msg, + SOCKET_MANAGER_C_API_CWaker waker, char **err) { + auto *receiver = + reinterpret_cast(this_.This); + SOCKET_MANAGER_CATCH_ERROR(err, return receiver->on_message_async( + std::string_view(msg.Bytes, msg.Len), + socket_manager::Waker(waker))) + // on error, return error and no-byte read, the runtime will close the + // connection. + return 0; } -extern "C" void socket_manager_extern_on_conn( - SOCKET_MANAGER_C_API_OnConnObj this_, - SOCKET_MANAGER_C_API_ConnStates states, - char **error) { +extern "C" void +socket_manager_extern_on_conn(SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates states, + char **error) { - auto conn_cb = static_cast(this_.This); + auto *conn_cb = static_cast(this_.This); switch (states.Code) { - case SOCKET_MANAGER_C_API_ConnStateCode::Connect: { - auto on_connect = states.Data.OnConnect; - auto local_addr = std::string(on_connect.Local); - auto peer_addr = std::string(on_connect.Peer); + case SOCKET_MANAGER_C_API_ConnStateCode::Connect: { + auto on_connect = states.Data.OnConnect; + auto local_addr = std::string(on_connect.Local); + auto peer_addr = std::string(on_connect.Peer); - std::shared_ptr conn(new socket_manager::Connection(on_connect.Conn)); - std::shared_ptr sender(new socket_manager::MsgSender(on_connect.Send, conn)); + std::shared_ptr conn( + new socket_manager::Connection(on_connect.Conn)); + std::shared_ptr sender( + new socket_manager::MsgSender(on_connect.Send, conn)); - // keep the connection alive - { - std::unique_lock lock(conn_cb->lock); - conn_cb->conns[local_addr + peer_addr] = conn; - } - try { - conn_cb->on_connect(local_addr, peer_addr, std::move(conn), std::move(sender)); - *error = nullptr; - } catch (std::runtime_error &e) { - *error = string_dup(e.what()); - } catch (...) { - *error = string_dup("unknown error"); - } - break; + // keep the connection alive + { + std::unique_lock const lock(conn_cb->lock); + conn_cb->conns[local_addr + peer_addr] = conn; } - case SOCKET_MANAGER_C_API_ConnStateCode::ConnectionClose: { - auto on_connection_close = states.Data.OnConnectionClose; - auto local_addr = std::string(on_connection_close.Local); - auto peer_addr = std::string(on_connection_close.Peer); + SOCKET_MANAGER_CATCH_ERROR(error, conn_cb->on_connect(local_addr, peer_addr, + std::move(conn), + std::move(sender))) + break; + } + case SOCKET_MANAGER_C_API_ConnStateCode::ConnectionClose: { + auto on_connection_close = states.Data.OnConnectionClose; + auto local_addr = std::string(on_connection_close.Local); + auto peer_addr = std::string(on_connection_close.Peer); - // remove the connection from the map - { - std::unique_lock lock(conn_cb->lock); - conn_cb->conns.erase(local_addr + peer_addr); - } - try { - conn_cb->on_connection_close(local_addr, peer_addr); - *error = nullptr; - } catch (std::runtime_error &e) { - *error = string_dup(e.what()); - } catch (...) { - *error = string_dup("unknown error"); - } - break; - } - case SOCKET_MANAGER_C_API_ConnStateCode::ListenError: { - auto listen_error = states.Data.OnListenError; - auto addr = std::string(listen_error.Addr); - auto err = std::string(listen_error.Err); - try { - conn_cb->on_listen_error(addr, err); - *error = nullptr; - } catch (std::runtime_error &e) { - *error = string_dup(e.what()); - } catch (...) { - *error = string_dup("unknown error"); - } - break; - } - case SOCKET_MANAGER_C_API_ConnStateCode::ConnectError: { - auto connect_error = states.Data.OnConnectError; - auto addr = std::string(connect_error.Addr); - auto err = std::string(connect_error.Err); - try { - conn_cb->on_connect_error(addr, err); - *error = nullptr; - } catch (std::runtime_error &e) { - *error = string_dup(e.what()); - } catch (...) { - *error = string_dup("unknown error"); - } - break; - } - default: { - // should never reach here - *error = nullptr; + // remove the connection from the map + { + std::unique_lock const lock(conn_cb->lock); + conn_cb->conns.erase(local_addr + peer_addr); } + SOCKET_MANAGER_CATCH_ERROR( + error, conn_cb->on_connection_close(local_addr, peer_addr)) + break; + } + case SOCKET_MANAGER_C_API_ConnStateCode::ListenError: { + auto listen_error = states.Data.OnListenError; + auto addr = std::string(listen_error.Addr); + auto err = std::string(listen_error.Err); + SOCKET_MANAGER_CATCH_ERROR(error, conn_cb->on_listen_error(addr, err)) + break; + } + case SOCKET_MANAGER_C_API_ConnStateCode::ConnectError: { + auto connect_error = states.Data.OnConnectError; + auto addr = std::string(connect_error.Addr); + auto err = std::string(connect_error.Err); + SOCKET_MANAGER_CATCH_ERROR(error, conn_cb->on_connect_error(addr, err)) + break; + } + default: { + // should never reach here + *error = nullptr; + } } } diff --git a/socket_manager/waker.cc b/socket_manager/waker.cc index 52ced26..9a831d6 100644 --- a/socket_manager/waker.cc +++ b/socket_manager/waker.cc @@ -2,36 +2,35 @@ namespace socket_manager { - Waker::Waker() - : waker(SOCKET_MANAGER_C_API_CWaker{nullptr, nullptr}) {} +Waker::Waker() : waker(SOCKET_MANAGER_C_API_CWaker{nullptr, nullptr}) {} - Waker::Waker(SOCKET_MANAGER_C_API_CWaker waker) : waker(waker) {} +Waker::Waker(SOCKET_MANAGER_C_API_CWaker waker) : waker(waker) {} - Waker::Waker(Waker &&other) noexcept: waker(other.waker) { - other.waker.Data = nullptr; - other.waker.Vtable = nullptr; - } +Waker::Waker(Waker &&other) noexcept : waker(other.waker) { + other.waker.Data = nullptr; + other.waker.Vtable = nullptr; +} - Waker &Waker::operator=(Waker &&other) noexcept { - if (waker.Data != nullptr && waker.Vtable != nullptr) { - socket_manager_waker_free(waker); - } - waker = other.waker; - other.waker.Data = nullptr; - other.waker.Vtable = nullptr; - return *this; +Waker &Waker::operator=(Waker &&other) noexcept { + if (waker.Data != nullptr && waker.Vtable != nullptr) { + socket_manager_waker_free(waker); } - - void Waker::wake() { - if (waker.Data != nullptr && waker.Vtable != nullptr) { - socket_manager_waker_wake(&waker); - } + waker = other.waker; + other.waker.Data = nullptr; + other.waker.Vtable = nullptr; + return *this; +} + +void Waker::wake() { + if (waker.Data != nullptr && waker.Vtable != nullptr) { + socket_manager_waker_wake(&waker); } +} - Waker::~Waker() { - if (waker.Data != nullptr && waker.Vtable != nullptr) { - socket_manager_waker_free(waker); - } +Waker::~Waker() { + if (waker.Data != nullptr && waker.Vtable != nullptr) { + socket_manager_waker_free(waker); } +} } // namespace socket_manager diff --git a/src/c_api/mod.rs b/src/c_api/mod.rs index 09d0318..2375954 100644 --- a/src/c_api/mod.rs +++ b/src/c_api/mod.rs @@ -5,4 +5,5 @@ mod msg_sender; pub(crate) mod on_conn; pub(crate) mod on_msg; mod socket_manager; +pub(crate) mod tracer; mod utils; diff --git a/src/c_api/on_conn.rs b/src/c_api/on_conn.rs index ea1020e..0ee46eb 100644 --- a/src/c_api/on_conn.rs +++ b/src/c_api/on_conn.rs @@ -36,7 +36,7 @@ extern "C" { /// Set `err` to null_ptr if there is no error. pub(crate) fn socket_manager_extern_on_conn( this: OnConnObj, - conn: ConnStates, + states: ConnStates, err: *mut *mut c_char, ); } @@ -44,9 +44,9 @@ extern "C" { impl OnConnObj { /// connection callback pub(crate) fn call_inner(&self, conn_states: crate::ConnState) -> Result<(), String> { - let on_conn = |conn| unsafe { + let on_conn = |states| unsafe { let mut err: *mut c_char = null_mut(); - socket_manager_extern_on_conn(*self, conn, &mut err); + socket_manager_extern_on_conn(*self, states, &mut err); parse_c_err_str(err) }; match conn_states { diff --git a/src/c_api/tracer.rs b/src/c_api/tracer.rs new file mode 100644 index 0000000..074fc49 --- /dev/null +++ b/src/c_api/tracer.rs @@ -0,0 +1,145 @@ +//! Define a layer to pass log message to foreign interface. +use super::utils::write_error_c_str; +use crate::init_logger; +use libc::size_t; +use std::{ + ffi::{c_char, c_int}, + fmt, + ptr::null_mut, +}; +use tracing::{field::Field, Level}; +use tracing_subscriber::{filter::LevelFilter, Layer}; + +const MESSAGE: &str = "message"; +const EMPTY: &str = ""; + +/// Trace Level +#[repr(C)] +pub enum TraceLevel { + /// The "trace" level. + /// + /// Designates very low priority, often extremely verbose, information. + Trace = 0, + /// The "debug" level. + /// + /// Designates lower priority information. + Debug = 1, + /// The "info" level. + /// + /// Designates useful information. + Info = 2, + /// The "warn" level. + /// + /// Designates hazardous situations. + Warn = 3, + /// The "error" level. + /// + /// Designates very serious errors. + Error = 4, + /// Turn off all levels. + /// + /// Disable log output. + Off = 5, +} + +/// Log Data +#[repr(C)] +pub struct LogData { + pub level: TraceLevel, + pub target: *const c_char, + pub target_n: size_t, + pub file: *const c_char, + pub file_n: size_t, + /// -1 if not available + pub line: c_int, + /// The `message` pointer is only valid for the duration of the callback. + pub message: *const c_char, + pub message_n: size_t, +} + +/// Init logger. +/// +/// # Arguments +/// - `tracer`: The tracer object. +/// - `tracer_max_level`: The max level of the tracer. +/// - `log_print_level`: The level of the log to print. +#[no_mangle] +pub unsafe extern "C" fn socket_manager_logger_init( + tracer: unsafe extern "C" fn(LogData) -> (), + tracer_max_level: TraceLevel, + log_print_level: TraceLevel, + err: *mut *mut c_char, +) { + let foreign_logger = ForeignLogger(tracer).with_filter(tracer_max_level.into()); + match init_logger(log_print_level.into(), foreign_logger) { + Ok(_) => *err = null_mut(), + Err(e) => write_error_c_str(e, err), + } +} + +pub struct ForeignLogger(unsafe extern "C" fn(LogData) -> ()); + +impl Layer for ForeignLogger +where + S: tracing::Subscriber, +{ + fn on_event( + &self, + event: &tracing::Event<'_>, + _ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + let mut get_msg = GetMsgVisitor(None); + event.record(&mut get_msg); + let file = event.metadata().file().unwrap_or(EMPTY); + let line = event.metadata().line().map(|l| l as c_int); + let message = get_msg.0.as_deref().unwrap_or(EMPTY); + let data = LogData { + level: event.metadata().level().into(), + target: event.metadata().target().as_ptr() as *const c_char, + target_n: event.metadata().target().len(), + file: file.as_ptr() as *const c_char, + file_n: file.len(), + line: line.unwrap_or(-1), + message: message.as_ptr() as *const c_char, + message_n: message.len(), + }; + unsafe { self.0(data) } + } +} + +// Helper methods and structs. + +struct GetMsgVisitor(Option); + +impl tracing::field::Visit for GetMsgVisitor { + fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) { + if field.name() == MESSAGE { + self.0 = Some(format!("{:?}", value)); + } + } +} + +impl Into for TraceLevel { + fn into(self) -> LevelFilter { + match self { + TraceLevel::Trace => LevelFilter::TRACE, + TraceLevel::Debug => LevelFilter::DEBUG, + TraceLevel::Info => LevelFilter::INFO, + TraceLevel::Warn => LevelFilter::WARN, + TraceLevel::Error => LevelFilter::ERROR, + TraceLevel::Off => LevelFilter::OFF, + } + } +} + +impl From<&Level> for TraceLevel { + fn from(value: &Level) -> Self { + match *value { + Level::TRACE => TraceLevel::Trace, + Level::DEBUG => TraceLevel::Debug, + Level::INFO => TraceLevel::Info, + Level::WARN => TraceLevel::Warn, + Level::ERROR => TraceLevel::Error, + } + } +} diff --git a/src/conn_handle.rs b/src/conn_handle.rs index 510816b..422087f 100644 --- a/src/conn_handle.rs +++ b/src/conn_handle.rs @@ -1,7 +1,6 @@ use crate::conn::{Conn, ConnConfig}; use crate::msg_sender::make_sender; use crate::{read, write, ConnState, ConnectionState, Msg}; -use futures::FutureExt; use std::net::SocketAddr; use std::sync::Arc; use std::task::{Poll, Waker}; @@ -81,7 +80,7 @@ pub(crate) fn handle_connection< /// On connection end, remove connection from connection state. async fn join_reader_writer( - (writer, reader): ( + (mut writer, mut reader): ( JoinHandle>, JoinHandle>, ), @@ -89,28 +88,22 @@ async fn join_reader_writer( ) { let writer_abort = writer.abort_handle(); let reader_abort = reader.abort_handle(); - let mut writer = writer.fuse(); - let mut reader = reader.fuse(); - loop { - tokio::select! { - w = &mut writer => { - if let Err(e) = w { - tracing::error!("writer stopped on error ({e}), local={local_addr}, peer={peer_addr}"); - } else { - tracing::debug!("writer stopped local={local_addr}, peer={peer_addr}"); - } - reader_abort.abort(); - break; + tokio::select! { + w = &mut writer => { + if let Err(e) = w { + tracing::error!("writer stopped on error ({e}), local={local_addr}, peer={peer_addr}"); + } else { + tracing::debug!("writer stopped local={local_addr}, peer={peer_addr}"); } - r = &mut reader => { - if let Err(e) = r { - tracing::error!("reader stopped on error ({e}), local={local_addr}, peer={peer_addr}"); - } else { - tracing::debug!("reader stopped local={local_addr}, peer={peer_addr}"); - } - writer_abort.abort(); - break; + reader_abort.abort(); + } + r = &mut reader => { + if let Err(e) = r { + tracing::error!("reader stopped on error ({e}), local={local_addr}, peer={peer_addr}"); + } else { + tracing::debug!("reader stopped local={local_addr}, peer={peer_addr}"); } + writer_abort.abort(); } } } diff --git a/src/lib.rs b/src/lib.rs index 34a640a..ea9580b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ mod read; mod utils; mod write; +use c_api::tracer::ForeignLogger; use dashmap::DashMap; use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, Ordering}; @@ -23,12 +24,15 @@ use tokio::net::{TcpListener, TcpStream}; use tokio::runtime::Handle; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::sync::{mpsc, oneshot}; -use tracing_subscriber::EnvFilter; +use tracing_subscriber::filter::{Filtered, LevelFilter}; +use tracing_subscriber::util::{SubscriberInitExt, TryInitError}; +use tracing_subscriber::Layer; +use tracing_subscriber::{prelude::*, Registry}; +pub use c_api::tracer::TraceLevel; pub use conn::*; pub use msg_sender::*; -const SOCKET_LOG: &str = "SOCKET_LOG"; const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); /// The Main Struct of the Library. @@ -104,6 +108,23 @@ impl ConnectionState { } } +/// Initialize socket manager logger. +/// Call this before anything else to prevent missing info. +/// This cannot be called twice, even accross multiple socket_manager instances. +pub fn init_logger( + log_print_level: LevelFilter, + foreign_logger: Filtered, +) -> Result<(), TryInitError> { + let fmt_layer = { + let fmt = tracing_subscriber::fmt::layer(); + fmt.with_filter(log_print_level) + }; + tracing_subscriber::registry() + .with(foreign_logger) + .with(fmt_layer) + .try_init() +} + /// Msg struct for the on_msg callback. pub struct Msg<'a> { bytes: &'a [u8], @@ -125,13 +146,6 @@ impl SocketManager { on_conn: OnConn, n_threads: usize, ) -> std::io::Result { - let _ = tracing_subscriber::fmt() - .with_env_filter( - EnvFilter::builder() - .with_env_var(SOCKET_LOG) - .from_env_lossy(), - ) - .try_init(); let runtime = utils::start_runtime(n_threads)?; let (cmd_send, cmd_recv) = mpsc::unbounded_channel::(); let connection_state = ConnectionState::new(); diff --git a/src/msg_sender.rs b/src/msg_sender.rs index da0e63b..919b08f 100644 --- a/src/msg_sender.rs +++ b/src/msg_sender.rs @@ -1,43 +1,39 @@ -use async_ringbuf::ring_buffer::AsyncRbWrite; -use async_ringbuf::{AsyncHeapConsumer, AsyncHeapProducer, AsyncHeapRb}; -use std::future::poll_fn; -use std::task::Poll::{self, Pending, Ready}; -use std::task::Waker; +use async_ringbuf::traits::{AsyncProducer, Producer, Split}; +use async_ringbuf::wrap::{AsyncCons, AsyncProd}; +use async_ringbuf::AsyncHeapRb; +use futures::AsyncWriteExt; +use std::sync::Arc; +use std::task::Poll::{Pending, Ready}; +use std::task::{Poll, Waker}; use tokio::runtime::Handle; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; /// 256KB ring buffer. pub const RING_BUFFER_SIZE: usize = 256 * 1024; -/// Sender Commands other than bytes. -pub(crate) enum SendCommand { - Flush, -} +pub type AsyncHeapProducer = AsyncProd>>; +pub type AsyncHeapConsumer = AsyncCons>>; pub(crate) fn make_sender(handle: Handle) -> (MsgSender, MsgRcv) { - let (cmd, cmd_recv) = unbounded_channel::(); let (rings_prd, rings) = unbounded_channel::>(); let (ring_buf, ring) = AsyncHeapRb::::new(RING_BUFFER_SIZE).split(); rings_prd.send(ring).unwrap(); ( MsgSender { - cmd, ring_buf, rings_prd, handle, }, - MsgRcv { cmd_recv, rings }, + MsgRcv { rings }, ) } pub(crate) struct MsgRcv { - pub(crate) cmd_recv: UnboundedReceiver, pub(crate) rings: UnboundedReceiver>, } /// Drop the sender to close the connection. pub struct MsgSender { - pub(crate) cmd: UnboundedSender, pub(crate) ring_buf: AsyncHeapProducer, pub(crate) rings_prd: UnboundedSender>, pub(crate) handle: Handle, @@ -55,7 +51,7 @@ fn burst_write( bytes: &[u8], ) -> BurstWriteState { loop { - let n = buf.as_mut_base().push_slice(&bytes[*offset..]); + let n = buf.push_slice(&bytes[*offset..]); if n == 0 { // no bytes read, return break BurstWriteState::Pending; @@ -82,30 +78,8 @@ impl MsgSender { return Ok(()); } // unfinished, enter into future - self.handle.clone().block_on(async { - loop { - if let BurstWriteState::Finished = - burst_write(&mut offset, &mut self.ring_buf, bytes) - { - return Ok(()); - } - poll_fn(|cx| { - unsafe { self.ring_buf.as_base().rb().register_head_waker(cx.waker()) }; - if self.ring_buf.is_closed() { - Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "connection closed", - ))) - } else if self.ring_buf.is_full() { - Pending::> - } else { - // continue to loop until pending - Ready(Ok(())) - } - }) - .await?; - } - }) + self.handle + .block_on(self.ring_buf.write_all(&bytes[offset..])) } /// The non-blocking API for sending bytes. @@ -125,14 +99,14 @@ impl MsgSender { // allocate new ring buffer if unable to write the entire message. let new_buf_size = RING_BUFFER_SIZE.max(bytes.len() - offset); let (mut ring_buf, ring) = AsyncHeapRb::::new(new_buf_size).split(); - ring_buf.as_mut_base().push_slice(&bytes[offset..]); + ring_buf.push_slice(&bytes[offset..]); self.rings_prd.send(ring).map_err(|e| { std::io::Error::new( std::io::ErrorKind::WriteZero, format!("connection closed: {e}"), ) })?; - // set head to new ring_buf + // set head to new ring_buf (must send before closing the old one) self.ring_buf = ring_buf; Ok(()) } @@ -145,32 +119,41 @@ impl MsgSender { return Ready(Ok(0)); } let mut offset = 0usize; + let mut waker_registered = false; loop { + // check if closed + if self.ring_buf.is_closed() { + break Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "connection closed", + ))); + } // attempt to write as much as possible burst_write(&mut offset, &mut self.ring_buf, bytes); if offset > 0 { break Ready(Ok(offset)); } // offset = 0, prepare to wait - unsafe { self.ring_buf.as_base().rb().register_head_waker(&waker) }; - // check the pending state ensues. - if self.ring_buf.is_closed() { - break Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "connection closed", - ))); - } else if self.ring_buf.is_full() { + if waker_registered { break Pending; } + // register waker + self.ring_buf.register_waker(&waker); + waker_registered = true; + // try again to ensure no missing wake } } pub fn flush(&mut self) -> std::io::Result<()> { - self.cmd.send(SendCommand::Flush).map_err(|_| { + let (ring_buf, ring) = AsyncHeapRb::::new(RING_BUFFER_SIZE).split(); + self.rings_prd.send(ring).map_err(|e| { std::io::Error::new( - std::io::ErrorKind::Other, - "failed to send flush command, connection closed", + std::io::ErrorKind::WriteZero, + format!("connection closed: {e}"), ) - }) + })?; + // set head to new ring_buf (must send before closing the old one) + self.ring_buf = ring_buf; + Ok(()) } } diff --git a/src/write.rs b/src/write.rs index c9d5cb1..970dc63 100644 --- a/src/write.rs +++ b/src/write.rs @@ -1,7 +1,8 @@ use crate::conn::ConnConfig; use crate::msg_sender::MsgRcv; use crate::read::MIN_MSG_BUFFER_SIZE; -use async_ringbuf::AsyncHeapConsumer; +use crate::AsyncHeapConsumer; +use async_ringbuf::traits::{AsyncConsumer, Consumer, Observer}; use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::net::tcp::OwnedWriteHalf; @@ -50,25 +51,16 @@ async fn handle_writer_auto_flush( biased; // !has_data => wait for has_data // has_data => wait for write_threshold - _ = ring.wait(if !has_data {1} else {MIN_MSG_BUFFER_SIZE}) => { + _ = ring.wait_occupied((has_data as usize) * MIN_MSG_BUFFER_SIZE + (!has_data as usize)) => { if ring.is_closed() { break 'ring; } has_data = true; - if ring.len() >= MIN_MSG_BUFFER_SIZE { + if ring.occupied_len() >= MIN_MSG_BUFFER_SIZE { flush(&mut ring, &mut write).await?; has_data = false } } - // flush - cmd = recv.cmd_recv.recv() => { - // always flush, including if sender is dropped - flush(&mut ring, &mut write).await?; - if cmd.is_none() { - break 'close; - } - has_data = false; - } // tick flush _ = flush_tick.tick(), if has_data => { flush(&mut ring, &mut write).await?; @@ -105,20 +97,12 @@ async fn handle_writer_no_auto_flush( tokio::select! { biased; // buf threshold - _ = ring.wait(MIN_MSG_BUFFER_SIZE) => { + _ = ring.wait_occupied(MIN_MSG_BUFFER_SIZE) => { if ring.is_closed() { break 'ring; } flush(&mut ring, &mut write).await?; } - // flush - cmd = recv.cmd_recv.recv() => { - // always flush, including if sender is dropped - flush(&mut ring, &mut write).await?; - if cmd.is_none() { - break 'close; - } - } _ = stop.closed() => break 'close, } } @@ -137,10 +121,10 @@ async fn flush( write: &mut OwnedWriteHalf, ) -> std::io::Result<()> { loop { - let (left, _) = ring_buf.as_mut_base().as_slices(); + let (left, _) = ring_buf.as_slices(); if !left.is_empty() { let count = write.write(left).await?; - unsafe { ring_buf.as_mut_base().advance(count) }; + unsafe { ring_buf.advance_read_index(count) }; } else { // both empty, break return Ok(()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 57ef01a..f8585c6 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,8 +17,9 @@ create_test_sourcelist(Tests CommonCxxTests.cxx ${test_files}) # build test driver add_executable(CommonCxxTests ${Tests}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/spdlog-repo/include/) target_link_libraries(CommonCxxTests - PUBLIC + PRIVATE socket_manager ${CMAKE_THREAD_LIBS_INIT}) diff --git a/tests/spdlog-repo b/tests/spdlog-repo new file mode 160000 index 0000000..ff205fd --- /dev/null +++ b/tests/spdlog-repo @@ -0,0 +1 @@ +Subproject commit ff205fd29a4a2f6ebcbecffd149153256d89a671 diff --git a/tests/test_abort_join.cpp b/tests/test_abort_join.cpp index d5a7ea8..39409d8 100644 --- a/tests/test_abort_join.cpp +++ b/tests/test_abort_join.cpp @@ -1,14 +1,13 @@ #undef NDEBUG -#include #include "test_utils.h" +#include using namespace socket_manager; -void abort_manager(SocketManager &manager) { - manager.abort(); -} +void abort_manager(SocketManager &manager) { manager.abort(); } int test_abort_join(int argc, char **argv) { + SpdLogger::init(); auto nothing_cb = std::make_shared(); SocketManager nothing(nothing_cb); diff --git a/tests/test_auto_flush.cpp b/tests/test_auto_flush.cpp index d2d9f42..2ef88ce 100644 --- a/tests/test_auto_flush.cpp +++ b/tests/test_auto_flush.cpp @@ -6,52 +6,56 @@ class ReceiverHelloWorld : public DoNothingReceiver { public: - ReceiverHelloWorld(std::mutex &mutex, - std::condition_variable &cond, - std::atomic_bool &received) - : mutex(mutex), cond(cond), received(received) {} + ReceiverHelloWorld(const std::shared_ptr &mutex, + const std::shared_ptr &cond, + const std::shared_ptr &received) + : mutex(mutex), cond(cond), received(received) {} void on_message(std::string_view data) override { if (data == "hello world") { - received.store(true); - std::unique_lock u_lock(mutex); - cond.notify_all(); + received->store(true); + std::unique_lock u_lock(*mutex); + cond->notify_all(); } } - std::mutex &mutex; - std::condition_variable &cond; - std::atomic_bool &received; + std::shared_ptr mutex; + std::shared_ptr cond; + std::shared_ptr received; }; class HelloWorldManager : public DoNothingConnCallback { public: + HelloWorldManager() + : mutex(std::make_shared()), + cond(std::make_shared()), + received(std::make_shared(false)) {} + void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto do_nothing = std::make_unique(mutex, cond, received); + std::shared_ptr conn, + std::shared_ptr send) override { + auto do_nothing = + std::make_unique(mutex, cond, received); conn->start(std::move(do_nothing)); this->sender = send; sender->send_block("hello world"); } - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool received{false}; - -private: + std::shared_ptr mutex; + std::shared_ptr cond; + std::shared_ptr received; std::shared_ptr sender; }; class SendHelloWorldDoNotClose : public DoNothingConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { auto do_nothing = std::make_unique(); conn->start(std::move(do_nothing)); this->sender = sender; - std::thread t([this] { - this->sender->send_block("hello world"); - }); - t.detach(); + std::thread sender_t([this] { this->sender->send_block("hello world"); }); + sender_t.detach(); } private: @@ -60,24 +64,25 @@ class SendHelloWorldDoNotClose : public DoNothingConnCallback { }; int test_auto_flush(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40101"; auto send_cb = std::make_shared(); SocketManager send(send_cb); send.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); auto recv_cb = std::make_shared(); SocketManager recv(recv_cb); recv.connect_to_addr(addr); while (true) { - std::unique_lock u_lock(recv_cb->mutex); - if (recv_cb->received.load()) { + std::unique_lock u_lock(*recv_cb->mutex); + if (recv_cb->received->load()) { break; } - recv_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + recv_cb->cond->wait_for(u_lock, std::chrono::milliseconds(WAIT_MILLIS)); } return 0; diff --git a/tests/test_bad_address_connect.cpp b/tests/test_bad_address_connect.cpp index 70a3bb6..c8638ea 100644 --- a/tests/test_bad_address_connect.cpp +++ b/tests/test_bad_address_connect.cpp @@ -4,6 +4,7 @@ using namespace socket_manager; int test_bad_address_connect(int argc, char **argv) { + SpdLogger::init(); auto nothing_cb = std::make_shared(); SocketManager nothing(nothing_cb); try { diff --git a/tests/test_bad_address_listen.cpp b/tests/test_bad_address_listen.cpp index 9643b07..02b1623 100644 --- a/tests/test_bad_address_listen.cpp +++ b/tests/test_bad_address_listen.cpp @@ -4,6 +4,7 @@ using namespace socket_manager; int test_bad_address_listen(int argc, char **argv) { + SpdLogger::init(); auto nothing_cb = std::make_shared(); SocketManager nothing(nothing_cb); try { diff --git a/tests/test_callback_throw_error.cpp b/tests/test_callback_throw_error.cpp index 08367b1..1e9f7a0 100644 --- a/tests/test_callback_throw_error.cpp +++ b/tests/test_callback_throw_error.cpp @@ -1,22 +1,24 @@ #undef NDEBUG #include "test_utils.h" -#include #include +#include #include using namespace socket_manager; class OnConnectErrorBeforeStartCallback : public DoNothingConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { throw std::runtime_error("throw some error before calling start"); } }; class OnConnectErrorAfterStartCallback : public DoNothingConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { conn->start(std::make_unique()); throw std::runtime_error("throw some error after calling start"); } @@ -30,42 +32,50 @@ class OnMsgErrorReceiver : public MsgReceiver { class OnMsgErrorCallback : public ConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { + std::shared_ptr conn, + std::shared_ptr send) override { conn->start(std::make_unique()); this->sender = send; sender.use_count(); } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override {} + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override {} - void on_listen_error(const std::string &addr, const std::string &err) override {} + void on_listen_error(const std::string &addr, + const std::string &err) override {} - void on_connect_error(const std::string &addr, const std::string &err) override {} + void on_connect_error(const std::string &addr, + const std::string &err) override {} std::shared_ptr sender; }; class StoreAllEventsConnHelloCallback : public StoreAllEventsConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - std::unique_lock lock(mutex); + std::shared_ptr conn, + std::shared_ptr sender) override { + std::unique_lock lock(*mutex); auto conn_id = local_addr + "->" + peer_addr; - events.emplace_back(CONNECTED, conn_id); - auto msg_storer = std::make_unique(conn_id, mutex, cond, buffer); + events->emplace_back(CONNECTED, conn_id); + auto msg_storer = + std::make_unique(conn_id, mutex, cond, buffer); conn->start(std::move(msg_storer)); - std::thread t1([sender]() { + std::thread sender_t([sender]() { try { sender->send_block("hello"); - } catch (std::runtime_error &e) { /* ignore */ } + } catch (std::runtime_error &e) { /* ignore */ + } }); - t1.detach(); + sender_t.detach(); senders.emplace(conn_id, sender); - connected_count.fetch_add(1, std::memory_order_seq_cst); - cond.notify_all(); + connected_count->fetch_add(1, std::memory_order_seq_cst); + cond->notify_all(); } }; int test_callback_throw_error(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40102"; auto err_before_cb = std::make_shared(); @@ -79,26 +89,29 @@ int test_callback_throw_error(int argc, char **argv) { SocketManager store_record(store_record_cb); store_record.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); err_before.connect_to_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); err_after.connect_to_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); err_on_msg.connect_to_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); + + const size_t EXPECTED = 6; while (true) { - std::unique_lock u_lock(store_record_cb->mutex); - if (store_record_cb->events.size() == 6) { - assert(std::get<0>(store_record_cb->events[0]) == CONNECTED); - assert(std::get<0>(store_record_cb->events[1]) == CONNECTION_CLOSED); - assert(std::get<0>(store_record_cb->events[2]) == CONNECTED); - assert(std::get<0>(store_record_cb->events[3]) == CONNECTION_CLOSED); - assert(std::get<0>(store_record_cb->events[4]) == CONNECTED); - assert(std::get<0>(store_record_cb->events[5]) == CONNECTION_CLOSED); + std::unique_lock u_lock(*store_record_cb->mutex); + if (store_record_cb->events->size() == EXPECTED) { + assert(std::get<0>(store_record_cb->events->at(0)) == CONNECTED); + assert(std::get<0>(store_record_cb->events->at(1)) == CONNECTION_CLOSED); + assert(std::get<0>(store_record_cb->events->at(2)) == CONNECTED); + assert(std::get<0>(store_record_cb->events->at(3)) == CONNECTION_CLOSED); + assert(std::get<0>(store_record_cb->events->at(4)) == CONNECTED); + assert(std::get<0>(store_record_cb->events->at(5)) == CONNECTION_CLOSED); break; } - store_record_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + store_record_cb->cond->wait_for(u_lock, + std::chrono::milliseconds(WAIT_MILLIS)); } return 0; diff --git a/tests/test_drop_sender.cpp b/tests/test_drop_sender.cpp index 4e3dc79..5586851 100644 --- a/tests/test_drop_sender.cpp +++ b/tests/test_drop_sender.cpp @@ -6,12 +6,14 @@ // this test is to test that dropping sender object closes remote connections int test_drop_sender(int argc, char **argv) { + SpdLogger::init(); const std::string local_addr = "127.0.0.1:40100"; - std::mutex lock; - std::condition_variable cond; - std::atomic_int sig(0); - std::vector>> buffer; + auto lock = std::make_shared(); + auto cond = std::make_shared(); + auto sig = std::make_shared(0); + auto buffer = std::make_shared< + std::vector>>>(); auto server_cb = std::make_shared(); // bit flag socket drop sender directly, which should close the connection. @@ -22,32 +24,32 @@ int test_drop_sender(int argc, char **argv) { server.listen_on_addr(local_addr); // wait 10ms for server to start listening - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); test.connect_to_addr(local_addr); // Wait for the connection to close while (true) { - std::unique_lock u_lock(server_cb->mutex); - if (server_cb->events.size() == 2) { - assert(std::get<0>(server_cb->events[0]) == CONNECTED); - assert(std::get<0>(server_cb->events[1]) == CONNECTION_CLOSED); + std::unique_lock u_lock(*server_cb->mutex); + if (server_cb->events->size() == 2) { + assert(std::get<0>(server_cb->events->at(0)) == CONNECTED); + assert(std::get<0>(server_cb->events->at(1)) == CONNECTION_CLOSED); break; } - server_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + server_cb->cond->wait_for(u_lock, std::chrono::milliseconds(WAIT_MILLIS)); } while (true) { - int load_sig = sig.load(std::memory_order_seq_cst); - if (load_sig & CONNECTION_CLOSED) { + int load_sig = sig->load(std::memory_order_seq_cst); + if (0 != (load_sig & CONNECTION_CLOSED)) { assert(load_sig & CONNECTED); assert(!(load_sig & CONNECT_ERROR)); assert(!(load_sig & LISTEN_ERROR)); - assert(buffer.empty()); + assert(buffer->empty()); return 0; } { - std::unique_lock u_lock(lock); - cond.wait_for(u_lock, std::chrono::milliseconds(10)); + std::unique_lock u_lock(*lock); + cond->wait_for(u_lock, std::chrono::milliseconds(WAIT_MILLIS)); } } } diff --git a/tests/test_error_call_after_abort.cpp b/tests/test_error_call_after_abort.cpp index 10a9af9..53715f5 100644 --- a/tests/test_error_call_after_abort.cpp +++ b/tests/test_error_call_after_abort.cpp @@ -4,6 +4,7 @@ using namespace socket_manager; int test_error_call_after_abort(int argc, char **argv) { + SpdLogger::init(); auto nothing_cb = std::make_shared(); SocketManager nothing(nothing_cb); @@ -17,7 +18,8 @@ int test_error_call_after_abort(int argc, char **argv) { // should not reach here return 1; } catch (std::runtime_error &e) { - std::cout << "connect_to_addr after abort-join throw error: " << e.what() << std::endl; + std::cout << "connect_to_addr after abort-join throw error: " << e.what() + << std::endl; } // should not throw error on abort after abort diff --git a/tests/test_error_connect.cpp b/tests/test_error_connect.cpp index 137f543..9c6d2f8 100644 --- a/tests/test_error_connect.cpp +++ b/tests/test_error_connect.cpp @@ -1,12 +1,14 @@ #undef NDEBUG -#include #include "test_utils.h" +#include int test_error_connect(int argc, char **argv) { - std::mutex lock; - std::condition_variable cond; - std::atomic_int sig(0); - std::vector>> buffer; + SpdLogger::init(); + auto lock = std::make_shared(); + auto cond = std::make_shared(); + auto sig = std::make_shared(0); + auto buffer = std::make_shared< + std::vector>>>(); auto test_cb = std::make_shared(lock, cond, sig, buffer); SocketManager test(test_cb); @@ -14,17 +16,17 @@ int test_error_connect(int argc, char **argv) { // Wait for the connection to fail while (true) { - int load_sig = sig.load(std::memory_order_seq_cst); - if (load_sig & CONNECT_ERROR) { + int load_sig = sig->load(std::memory_order_seq_cst); + if (0 != (load_sig & CONNECT_ERROR)) { assert(!(load_sig & CONNECTED)); assert(!(load_sig & CONNECTION_CLOSED)); assert(!(load_sig & LISTEN_ERROR)); - assert(buffer.empty()); + assert(buffer->empty()); return 0; } { - std::unique_lock u_lock(lock); - cond.wait_for(u_lock, std::chrono::milliseconds(10)); + std::unique_lock u_lock(*lock); + cond->wait_for(u_lock, std::chrono::milliseconds(WAIT_MILLIS)); } } } diff --git a/tests/test_error_listen.cpp b/tests/test_error_listen.cpp index a64cfa1..62e035b 100644 --- a/tests/test_error_listen.cpp +++ b/tests/test_error_listen.cpp @@ -1,12 +1,14 @@ #undef NDEBUG -#include #include "test_utils.h" +#include int test_error_listen(int argc, char **argv) { - std::mutex lock; - std::condition_variable cond; - std::atomic_int sig(0); - std::vector>> buffer; + SpdLogger::init(); + auto lock = std::make_shared(); + auto cond = std::make_shared(); + auto sig = std::make_shared(0); + auto buffer = std::make_shared< + std::vector>>>(); auto test_cb = std::make_shared(lock, cond, sig, buffer); SocketManager test(test_cb); @@ -15,17 +17,17 @@ int test_error_listen(int argc, char **argv) { // Wait for the connection to fail while (true) { - int load_sig = sig.load(std::memory_order_seq_cst); - if (load_sig & LISTEN_ERROR) { + int load_sig = sig->load(std::memory_order_seq_cst); + if (0 != (load_sig & LISTEN_ERROR)) { assert(!(load_sig & CONNECTED)); assert(!(load_sig & CONNECTION_CLOSED)); assert(!(load_sig & CONNECT_ERROR)); - assert(buffer.empty()); + assert(buffer->empty()); return 0; } { - std::unique_lock u_lock(lock); - cond.wait_for(u_lock, std::chrono::milliseconds(10)); + std::unique_lock u_lock(*lock); + cond->wait_for(u_lock, std::chrono::milliseconds(WAIT_MILLIS)); } } } diff --git a/tests/test_error_send_after_closed.cpp b/tests/test_error_send_after_closed.cpp index 9057e1a..cdcecf2 100644 --- a/tests/test_error_send_after_closed.cpp +++ b/tests/test_error_send_after_closed.cpp @@ -4,6 +4,7 @@ #include int test_error_send_after_closed(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40107"; auto server_cb = std::make_shared(); @@ -15,20 +16,20 @@ int test_error_send_after_closed(int argc, char **argv) { server.listen_on_addr(addr); // Wait for 100ms - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); client.connect_to_addr(addr); // wait for connection success (server) std::string s_conn_id; while (true) { - std::unique_lock u_lock(server_cb->mutex); - if (server_cb->events.size() == 1) { - assert(std::get<0>(server_cb->events[0]) == CONNECTED); - s_conn_id = std::get<1>(server_cb->events[0]); + std::unique_lock u_lock(*server_cb->mutex); + if (server_cb->events->size() == 1) { + assert(std::get<0>(server_cb->events->at(0)) == CONNECTED); + s_conn_id = std::get<1>(server_cb->events->at(0)); break; } - server_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + server_cb->cond->wait_for(u_lock, std::chrono::milliseconds(WAIT_MILLIS)); } // close connection from server (by dropping sender) @@ -37,19 +38,19 @@ int test_error_send_after_closed(int argc, char **argv) { // wait for connection closed (client) std::string c_conn_id; while (true) { - std::unique_lock u_lock(client_cb->mutex); - if (client_cb->events.size() == 2) { - assert(std::get<0>(client_cb->events[0]) == CONNECTED); - assert(std::get<0>(client_cb->events[1]) == CONNECTION_CLOSED); - c_conn_id = std::get<1>(client_cb->events[0]); - assert(std::get<1>(client_cb->events[1]) == c_conn_id); + std::unique_lock u_lock(*client_cb->mutex); + if (client_cb->events->size() == 2) { + assert(std::get<0>(client_cb->events->at(0)) == CONNECTED); + assert(std::get<0>(client_cb->events->at(1)) == CONNECTION_CLOSED); + c_conn_id = std::get<1>(client_cb->events->at(0)); + assert(std::get<1>(client_cb->events->at(1)) == c_conn_id); break; } - client_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + client_cb->cond->wait_for(u_lock, std::chrono::milliseconds(WAIT_MILLIS)); } // Wait for 100ms - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); // should emit runtime error if attempt to send from client after closed try { @@ -61,5 +62,4 @@ int test_error_send_after_closed(int argc, char **argv) { // should not reach here return 1; - } diff --git a/tests/test_error_twice_start.cpp b/tests/test_error_twice_start.cpp index 88fa0f4..3f99613 100644 --- a/tests/test_error_twice_start.cpp +++ b/tests/test_error_twice_start.cpp @@ -12,10 +12,11 @@ class DoNothingMsgReceiver : public MsgReceiver { class TwiceStartCallback : public ConnCallback { public: - TwiceStartCallback() : error_thrown(0) {}; + TwiceStartCallback() : error_thrown(0){}; void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { auto receiver = std::make_unique(); auto receiver2 = std::make_unique(); conn->start(std::move(receiver)); @@ -36,11 +37,14 @@ class TwiceStartCallback : public ConnCallback { cond.notify_all(); } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override {} + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override {} - void on_listen_error(const std::string &addr, const std::string &err) override {} + void on_listen_error(const std::string &addr, + const std::string &err) override {} - void on_connect_error(const std::string &addr, const std::string &err) override {} + void on_connect_error(const std::string &addr, + const std::string &err) override {} std::atomic_int error_thrown; std::mutex mutex; @@ -48,6 +52,7 @@ class TwiceStartCallback : public ConnCallback { }; int test_error_twice_start(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40108"; auto bad_cb = std::make_shared(); @@ -58,7 +63,7 @@ int test_error_twice_start(int argc, char **argv) { bad.listen_on_addr(addr); // wait 100ms - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); good.connect_to_addr(addr); // wait for error @@ -67,7 +72,7 @@ int test_error_twice_start(int argc, char **argv) { if (bad_cb->error_thrown.load(std::memory_order_acquire) == 2) { break; } - bad_cb->cond.wait_for(lock, std::chrono::milliseconds(10)); + bad_cb->cond.wait_for(lock, std::chrono::milliseconds(WAIT_MILLIS)); } return 0; diff --git a/tests/test_hello_world_greetings.cpp b/tests/test_hello_world_greetings.cpp deleted file mode 100644 index 7d6978d..0000000 --- a/tests/test_hello_world_greetings.cpp +++ /dev/null @@ -1,115 +0,0 @@ -#undef NDEBUG -#include "test_utils.h" -#include -#include -#include - -int test_hello_world_greetings(int argc, char **argv) { - - const std::string addr = "127.0.0.1:40109"; - - // create server - auto server_cb = std::make_shared(); - auto client_cb = std::make_shared(); - - SocketManager server(server_cb); - SocketManager client(client_cb); - - server.listen_on_addr(addr); - // Wait for 100ms - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - // create client - client.connect_to_addr(addr); - - std::cout << "Client connect" << std::endl; - - std::string c_conn_id; - std::string s_conn_id; - - // wait for connection success (client side) - while (true) { - std::cout << "client before lock" << std::endl; - std::unique_lock u_lock(client_cb->mutex); - std::cout << "client lock" << std::endl; - if (client_cb->events.size() == 1) { - std::cout << "Client connection established: " << std::get<1>(client_cb->events[0]) << std::endl; - assert(std::get<0>(client_cb->events[0]) == CONNECTED); - c_conn_id = std::get<1>(client_cb->events[0]); - break; - } - client_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - - // wait for connection success (server side) - while (true) { - std::unique_lock u_lock(server_cb->mutex); - if (server_cb->events.size() == 1) { - std::cout << "Server connection established: " << std::get<1>(server_cb->events[0]) << std::endl; - assert(std::get<0>(server_cb->events[0]) == CONNECTED); - s_conn_id = std::get<1>(server_cb->events[0]); - server.cancel_listen_on_addr(addr); - break; - } - server_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - - // send message - client_cb->send_to(c_conn_id, "hello world"); - - // wait for server receive - while (true) { - std::unique_lock u_lock(server_cb->mutex); - if (server_cb->buffer.size() == 1) { - std::cout << "Server received: " << *std::get<1>(server_cb->buffer[0]) - << " from connection=" << std::get<0>(server_cb->buffer[0]) << std::endl; - assert(std::get<0>(server_cb->buffer[0]) == s_conn_id); - assert(*std::get<1>(server_cb->buffer[0]) == "hello world"); - break; - } - server_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - - server_cb->send_to(s_conn_id, "hello world"); - - // wait for client receive - while (true) { - std::unique_lock u_lock(client_cb->mutex); - if (client_cb->buffer.size() == 1) { - std::cout << "Client received: " << *std::get<1>(client_cb->buffer[0]) - << " from connection=" << std::get<0>(client_cb->buffer[0]) << std::endl; - assert(std::get<0>(client_cb->buffer[0]) == c_conn_id); - assert(*std::get<1>(client_cb->buffer[0]) == "hello world"); - break; - } - client_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - - // drop sender - server_cb->drop_connection(s_conn_id); - - // wait for connection close - while (true) { - std::unique_lock u_lock(server_cb->mutex); - if (server_cb->events.size() == 2) { - std::cout << "Connection closed: " << std::get<1>(server_cb->events[1]) << std::endl; - assert(std::get<0>(server_cb->events[1]) == CONNECTION_CLOSED); - break; - } - server_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - - // wait for connection close - while (true) { - std::unique_lock u_lock(client_cb->mutex); - if (client_cb->events.size() == 2) { - assert(std::get<0>(client_cb->events[1]) == CONNECTION_CLOSED); - std::cout << "Connection closed: " << std::get<1>(client_cb->events[1]) << std::endl; - break; - } - client_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - - client_cb->drop_connection(c_conn_id); - - return 0; -} diff --git a/tests/test_manual_flush.cpp b/tests/test_manual_flush.cpp index 933e044..9002754 100644 --- a/tests/test_manual_flush.cpp +++ b/tests/test_manual_flush.cpp @@ -1,107 +1,132 @@ #undef NDEBUG #include "test_utils.h" -#include #include -#include #include +#include +#include class FinalReceiver : public MsgReceiver { public: - FinalReceiver(bool &hasReceived, std::mutex &mutex, std::condition_variable &cond) - : has_received(hasReceived), mutex(mutex), cond(cond) {} + FinalReceiver(const std::shared_ptr &hasReceived, + const std::shared_ptr &mutex, + const std::shared_ptr &cond) + : has_received(hasReceived), mutex(mutex), cond(cond) {} private: - void on_message(std::string_view data) override { assert(data == "hello world"); - std::unique_lock lk(mutex); - has_received = true; - cond.notify_one(); + std::unique_lock lock(*mutex); + *has_received = true; + cond->notify_one(); std::cout << "final received" << std::endl; } - bool &has_received; - std::mutex &mutex; - std::condition_variable &cond; + std::shared_ptr has_received; + std::shared_ptr mutex; + std::shared_ptr cond; }; class EchoReceiver : public MsgReceiver { public: - EchoReceiver(bool &hasReceived, std::string &data, std::mutex &mutex, std::condition_variable &cond) - : has_received(hasReceived), _data(data), mutex(mutex), cond(cond) {} + EchoReceiver(const std::shared_ptr &hasReceived, + const std::shared_ptr &data, + const std::shared_ptr &mutex, + const std::shared_ptr &cond) + : has_received(hasReceived), _data(data), mutex(mutex), cond(cond) {} private: void on_message(std::string_view data) override { - std::unique_lock lk(mutex); - has_received = true; - _data.append(data); - cond.notify_one(); + std::unique_lock lock(*mutex); + *has_received = true; + _data->append(data); + cond->notify_one(); std::cout << "echo received" << std::endl; } - bool &has_received; - std::string &_data; - std::mutex &mutex; - std::condition_variable &cond; + std::shared_ptr has_received; + std::shared_ptr _data; + std::shared_ptr mutex; + std::shared_ptr cond; }; class HelloCallback : public ConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { auto rcv = std::make_unique(has_received, mutex, cond); // disable write auto flush - conn->start(std::move(rcv), nullptr, DEFAULT_MSG_BUF_SIZE, 1, 0); - std::thread t([sender] { + conn->start(std::move(rcv), nullptr, DEFAULT_MSG_BUF_SIZE, + DEFAULT_READ_MSG_FLUSH_MILLI_SEC, 0); + std::thread sender_t([sender] { sender->send_block("hello world"); sender->flush(); }); - t.detach(); + sender_t.detach(); _sender = sender; std::cout << "hello world sent" << std::endl; } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override {} + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override {} - void on_listen_error(const std::string &addr, const std::string &err) override {} + void on_listen_error(const std::string &addr, + const std::string &err) override {} - void on_connect_error(const std::string &addr, const std::string &err) override {} + void on_connect_error(const std::string &addr, + const std::string &err) override {} public: + HelloCallback() + : has_received(std::make_shared(false)), + mutex(std::make_shared()), + cond(std::make_shared()) {} std::shared_ptr _sender; - bool has_received{false}; - std::mutex mutex; - std::condition_variable cond; + std::shared_ptr has_received; + std::shared_ptr mutex; + std::shared_ptr cond; }; class EchoCallback : public ConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { auto rcv = std::make_unique(has_received, _data, mutex, cond); // disable write auto flush conn->start(std::move(rcv), nullptr, DEFAULT_MSG_BUF_SIZE, 1, 0); - std::thread t([sender, this]() { - std::unique_lock lk(mutex); - cond.wait(lk, [this]() { return has_received; }); - sender->send_block(_data); + std::thread sender_t([sender, this]() { + std::unique_lock lock(*mutex); + cond->wait(lock, [this]() { return has_received; }); + sender->send_block(*_data); std::cout << "echo received and sent back" << std::endl; }); - t.detach(); + sender_t.detach(); } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override {} + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override {} - void on_listen_error(const std::string &addr, const std::string &err) override {} + void on_listen_error(const std::string &addr, + const std::string &err) override {} - void on_connect_error(const std::string &addr, const std::string &err) override {} + void on_connect_error(const std::string &addr, + const std::string &err) override {} - bool has_received{false}; - std::string _data; - std::mutex mutex; - std::condition_variable cond; +public: + EchoCallback() + : has_received(std::make_shared(false)), + _data(std::make_shared()), + mutex(std::make_shared()), + cond(std::make_shared()) {} + + std::shared_ptr has_received; + std::shared_ptr _data; + std::shared_ptr mutex; + std::shared_ptr cond; }; int test_manual_flush(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40201"; @@ -113,7 +138,7 @@ int test_manual_flush(int argc, char **argv) { hello.listen_on_addr(addr); // Wait for 100ms - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); // create client echo.connect_to_addr(addr); @@ -121,8 +146,9 @@ int test_manual_flush(int argc, char **argv) { // wait for message { - std::unique_lock lk(hello_cb->mutex); - hello_cb->cond.wait(lk, [&hello_cb]() { return hello_cb->has_received; }); + std::unique_lock lock(*hello_cb->mutex); + hello_cb->cond->wait(lock, + [&hello_cb]() { return hello_cb->has_received; }); } return 0; diff --git a/tests/test_multiple_connections.cpp b/tests/test_multiple_connections.cpp index f2e67ba..0c14bed 100644 --- a/tests/test_multiple_connections.cpp +++ b/tests/test_multiple_connections.cpp @@ -4,11 +4,14 @@ #include int test_multiple_connections(int argc, char **argv) { + SpdLogger::init(); // establish 3 connections from p0 (client) -> p1 (server) port 0 // and 2 connections from p0 (client) -> p1 (server) port 1 // and 2 connections from p1 (client) -> p0 (server) + const int TOTAL_CONN = 7; + const std::string p1_addr_0 = "127.0.0.1:40010"; const std::string p1_addr_1 = "127.0.0.1:40011"; const std::string p0_addr_0 = "127.0.0.1:40012"; @@ -16,121 +19,93 @@ int test_multiple_connections(int argc, char **argv) { auto p0_cb = std::make_shared(); auto p1_cb = std::make_shared(); - SocketManager p0(p0_cb); - SocketManager p1(p1_cb); + SocketManager party0(p0_cb); + SocketManager party1(p1_cb); // listen - p1.listen_on_addr(p1_addr_0); - p1.listen_on_addr(p1_addr_1); - p0.listen_on_addr(p0_addr_0); + party1.listen_on_addr(p1_addr_0); + party1.listen_on_addr(p1_addr_1); + party0.listen_on_addr(p0_addr_0); // wait for 100ms - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); // establish connections - for (int i = 0; i < 3; i++) { - p0.connect_to_addr(p1_addr_0); - } - for (int i = 0; i < 2; i++) { - p0.connect_to_addr(p1_addr_1); - } - for (int i = 0; i < 2; i++) { - p1.connect_to_addr(p0_addr_0); - } + party0.connect_to_addr(p1_addr_0); + party0.connect_to_addr(p1_addr_0); + party0.connect_to_addr(p1_addr_0); + + party0.connect_to_addr(p1_addr_1); + party0.connect_to_addr(p1_addr_1); + + party1.connect_to_addr(p0_addr_0); + party1.connect_to_addr(p0_addr_0); // wait for all connections established (7 in total) - while (true) { - std::unique_lock u_lock(p0_cb->mutex); - if (p0_cb->events.size() == 7) { - for (auto &e: p0_cb->events) { - std::cout << "p0 connection established: " << std::get<1>(e) << std::endl; - assert(std::get<0>(e) == CONNECTED); - } - break; - } - p0_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + { + std::unique_lock u_lock(*p0_cb->mutex); + p0_cb->cond->wait_for( + u_lock, std::chrono::milliseconds(WAIT_MILLIS), + [p0_cb]() { return p0_cb->events->size() == TOTAL_CONN; }); } - while (true) { - std::unique_lock u_lock(p1_cb->mutex); - if (p1_cb->events.size() == 7) { - for (auto &e: p1_cb->events) { - std::cout << "p0 connection established: " << std::get<1>(e) << std::endl; - assert(std::get<0>(e) == CONNECTED); - } - break; - } - p1_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + + { + std::unique_lock u_lock(*p1_cb->mutex); + p1_cb->cond->wait_for( + u_lock, std::chrono::milliseconds(WAIT_MILLIS), + [p1_cb]() { return p1_cb->events->size() == TOTAL_CONN; }); } // send messages from p0 to p1 and vice versa - for (auto &e: p0_cb->events) { - p0_cb->send_to(std::get<1>(e), "hello world"); + for (auto &event : *p0_cb->events) { + p0_cb->send_to(std::get<1>(event), "hello world"); } - for (auto &e: p1_cb->events) { - p1_cb->send_to(std::get<1>(e), "hello world"); + for (auto &event : *p1_cb->events) { + p1_cb->send_to(std::get<1>(event), "hello world"); } // check messages received in buffer - while (true) { - std::unique_lock u_lock(p0_cb->mutex); - if (p0_cb->buffer.size() == 7) { - for (auto &m: p0_cb->buffer) { - std::cout << "p0 received from connection " << std::get<0>(m) - << ": " << *std::get<1>(m) << std::endl; - assert(*std::get<1>(m) == "hello world"); - } - break; - } - p0_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + // confirm all connections from p1 are closed + { + std::unique_lock u_lock(*p0_cb->mutex); + p0_cb->cond->wait_for( + u_lock, std::chrono::milliseconds(WAIT_MILLIS), + [p0_cb]() { return p0_cb->buffer->size() == TOTAL_CONN; }); } - while (true) { - std::unique_lock u_lock(p1_cb->mutex); - if (p1_cb->buffer.size() == 7) { - for (auto &m: p1_cb->buffer) { - std::cout << "p1 received from connection " << std::get<0>(m) - << ": " << *std::get<1>(m) << std::endl; - assert(*std::get<1>(m) == "hello world"); - } - break; - } - p1_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + { + std::unique_lock u_lock(*p1_cb->mutex); + p1_cb->cond->wait_for( + u_lock, std::chrono::milliseconds(WAIT_MILLIS), + [p1_cb]() { return p1_cb->buffer->size() == TOTAL_CONN; }); } // shutdown all connections from p0 std::vector connections; { - std::unique_lock u_lock(p0_cb->mutex); - for (auto &e: p0_cb->events) { - if (std::get<0>(e) == CONNECTED) { - connections.push_back(std::get<1>(e)); + std::unique_lock u_lock(*p0_cb->mutex); + for (auto &event : *p0_cb->events) { + if (std::get<0>(event) == CONNECTED) { + connections.push_back(std::get<1>(event)); } } } assert(connections.size() == 7); - for (auto &c: connections) { - p0_cb->drop_connection(c); + for (auto &connect : connections) { + p0_cb->drop_connection(connect); } // confirm all connections from p1 are closed - while (true) { - std::cout << "p1 connected count: " << p1_cb->connected_count.load() << std::endl; - if (p1_cb->connected_count.load() == 0) { - break; - } - { - std::unique_lock u_lock(p1_cb->mutex); - p1_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(*p0_cb->mutex); + p0_cb->cond->wait_for( + u_lock, std::chrono::milliseconds(WAIT_MILLIS), + [p0_cb]() { return p0_cb->connected_count->load() == 0; }); } - while (true) { - std::cout << "p0 connected count: " << p0_cb->connected_count.load() << std::endl; - if (p0_cb->connected_count.load() == 0) { - break; - } - { - std::unique_lock u_lock(p0_cb->mutex); - p0_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(*p1_cb->mutex); + p1_cb->cond->wait_for( + u_lock, std::chrono::milliseconds(WAIT_MILLIS), + [p1_cb]() { return p1_cb->connected_count->load() == 0; }); } return 0; diff --git a/tests/test_transfer_data_large.cpp b/tests/test_transfer_data_large.cpp deleted file mode 100644 index 4c5716b..0000000 --- a/tests/test_transfer_data_large.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#undef NDEBUG - -#include "test_utils.h" -#include -#include - -class SendLargeDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - data.reserve(1024 * 1000); - for (int i = 0; i < 100 * 1024; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 1024; ++i) { - sender->send_block(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataLarge : public MsgReceiver { -public: - explicit StoreAllDataLarge(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - -int test_transfer_data_large(int argc, char **argv) { - const std::string addr = "127.0.0.1:40013"; - - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); - SocketManager send(send_cb); - SocketManager store(store_cb); - - send.listen_on_addr(addr); - - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - store.connect_to_addr(addr); - - // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - } -} diff --git a/tests/test_transfer_data_large_async.cpp b/tests/test_transfer_data_large_async.cpp index 7f745df..f6005c0 100644 --- a/tests/test_transfer_data_large_async.cpp +++ b/tests/test_transfer_data_large_async.cpp @@ -1,137 +1,34 @@ #undef NDEBUG -#include "test_utils.h" -#include -#include "concurrentqueue/concurrentqueue.h" -#include "concurrentqueue/lightweightsemaphore.h" +#include "transfer_common.h" +#include #include - -class CondWaker : public Notifier { -public: - explicit CondWaker(const std::shared_ptr &sem) : sem(sem) {} - - void wake() override { - sem->signal(); - } - -private: - std::shared_ptr sem; -}; - -class SendLargeDataConnCallbackAsync : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - auto sem = std::make_shared(); - auto waker = std::make_shared(sem); - conn->start(std::move(rcv), waker); - - std::string data; - data.reserve(1024 * 1000); - for (int i = 0; i < 100 * 1024; i++) { - data.append("helloworld"); - } - - std::thread t([=]() { - // send 1000MB data - int progress = 0; - size_t offset = 0; - std::string_view data_view(data); - while (progress < 1024) { - auto sent = sender->send_async(data_view.substr(offset)); - if (sent < 0) { - sem->wait(1000); - } else { - offset += sent; - } - if (offset == data.size()) { - offset = 0; - progress += 1; - } - } - }); - - t.detach(); - } -}; - -class StoreAllDataAsync : public MsgReceiver { -public: - explicit StoreAllDataAsync(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataNotifyOnCloseCallbackAsync : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - int test_transfer_data_large_async(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40013"; - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); + auto send_cb = std::make_shared(SMALL_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); SocketManager send(send_cb); SocketManager store(store_cb); send.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); store.connect_to_addr(addr); // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; } diff --git a/tests/test_transfer_data_large_block.cpp b/tests/test_transfer_data_large_block.cpp new file mode 100644 index 0000000..b0c5ef6 --- /dev/null +++ b/tests/test_transfer_data_large_block.cpp @@ -0,0 +1,34 @@ +#undef NDEBUG + +#include "transfer_common.h" +#include +#include + +int test_transfer_data_large_block(int argc, char **argv) { + SpdLogger::init(); + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(LARGE_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); + } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; +} diff --git a/tests/test_transfer_data_large_manual_flush.cpp b/tests/test_transfer_data_large_manual_flush.cpp deleted file mode 100644 index e811685..0000000 --- a/tests/test_transfer_data_large_manual_flush.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#undef NDEBUG - -#include "test_utils.h" -#include -#include - - -class SendLargeManualDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - data.reserve(1024 * 1000); - for (int i = 0; i < 100 * 1024; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 1024; ++i) { - sender->send_block(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataLargeManual : public MsgReceiver { -public: - explicit StoreAllDataLargeManual(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - -int test_transfer_data_large_manual_flush(int argc, char **argv) { - const std::string addr = "127.0.0.1:40013"; - - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); - SocketManager send(send_cb); - SocketManager store(store_cb); - - send.listen_on_addr(addr); - - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - store.connect_to_addr(addr); - - // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - } -} diff --git a/tests/test_transfer_data_large_no_flush.cpp b/tests/test_transfer_data_large_no_flush.cpp new file mode 100644 index 0000000..31f9ac8 --- /dev/null +++ b/tests/test_transfer_data_large_no_flush.cpp @@ -0,0 +1,34 @@ +#undef NDEBUG + +#include "transfer_common.h" +#include +#include + +int test_transfer_data_large_no_flush(int argc, char **argv) { + SpdLogger::init(); + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(LARGE_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); + } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; +} diff --git a/tests/test_transfer_data_large_nonblock.cpp b/tests/test_transfer_data_large_nonblock.cpp index 246be5c..2fa4b76 100644 --- a/tests/test_transfer_data_large_nonblock.cpp +++ b/tests/test_transfer_data_large_nonblock.cpp @@ -1,107 +1,34 @@ #undef NDEBUG -#include "test_utils.h" +#include "transfer_common.h" #include #include - -class SendLargeNonBlockDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - data.reserve(1024 * 1000); - for (int i = 0; i < 100 * 1024; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 1024; ++i) { - sender->send_nonblock(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataLargeNonBlock : public MsgReceiver { -public: - explicit StoreAllDataLargeNonBlock(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - int test_transfer_data_large_nonblock(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40013"; - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); + auto send_cb = std::make_shared(LARGE_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); SocketManager send(send_cb); SocketManager store(store_cb); send.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); store.connect_to_addr(addr); // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; } diff --git a/tests/test_transfer_data_mid.cpp b/tests/test_transfer_data_mid.cpp deleted file mode 100644 index b35cc26..0000000 --- a/tests/test_transfer_data_mid.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#undef NDEBUG - -#include "test_utils.h" -#include -#include - - -class SendMidDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - for (int i = 0; i < 10 * 1024; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 10 * 1024; ++i) { - sender->send_block(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataMid : public MsgReceiver { -public: - explicit StoreAllDataMid(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataMidNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - -int test_transfer_data_mid(int argc, char **argv) { - const std::string addr = "127.0.0.1:40013"; - - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); - SocketManager send(send_cb); - SocketManager store(store_cb); - - send.listen_on_addr(addr); - - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - store.connect_to_addr(addr); - - // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - } -} diff --git a/tests/test_transfer_data_mid_async.cpp b/tests/test_transfer_data_mid_async.cpp index fbdcbbf..72277b5 100644 --- a/tests/test_transfer_data_mid_async.cpp +++ b/tests/test_transfer_data_mid_async.cpp @@ -1,136 +1,34 @@ #undef NDEBUG -#include "test_utils.h" -#include -#include "concurrentqueue/concurrentqueue.h" -#include "concurrentqueue/lightweightsemaphore.h" +#include "transfer_common.h" +#include #include - -class CondWaker : public Notifier { -public: - explicit CondWaker(const std::shared_ptr &sem) : sem(sem) {} - - void wake() override { - sem->signal(); - } - -private: - std::shared_ptr sem; -}; - -class SendLargeDataConnCallbackAsyncMid : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - auto sem = std::make_shared(); - auto waker = std::make_shared(sem); - conn->start(std::move(rcv), waker); - - std::string data; - for (int i = 0; i < 10 * 1024; i++) { - data.append("helloworld"); - } - - std::thread t([=]() { - // send 1000MB data - int progress = 0; - size_t offset = 0; - std::string_view data_view(data); - while (progress < 10 * 1024) { - auto sent = sender->send_async(data_view.substr(offset)); - if (sent < 0) { - sem->wait(1000); - } else { - offset += sent; - } - if (offset == data.size()) { - offset = 0; - progress += 1; - } - } - }); - - t.detach(); - } -}; - -class StoreAllDataAsyncMid : public MsgReceiver { -public: - explicit StoreAllDataAsyncMid(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataNotifyOnCloseCallbackAsyncMid : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - int test_transfer_data_mid_async(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40013"; - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); + auto send_cb = std::make_shared(MID_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); SocketManager send(send_cb); SocketManager store(store_cb); send.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); store.connect_to_addr(addr); // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; } diff --git a/tests/test_transfer_data_mid_block.cpp b/tests/test_transfer_data_mid_block.cpp new file mode 100644 index 0000000..c3f00ca --- /dev/null +++ b/tests/test_transfer_data_mid_block.cpp @@ -0,0 +1,34 @@ +#undef NDEBUG + +#include "transfer_common.h" +#include +#include + +int test_transfer_data_mid_block(int argc, char **argv) { + SpdLogger::init(); + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(MID_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); + } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; +} diff --git a/tests/test_transfer_data_mid_manual_flush.cpp b/tests/test_transfer_data_mid_manual_flush.cpp deleted file mode 100644 index b53e587..0000000 --- a/tests/test_transfer_data_mid_manual_flush.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#undef NDEBUG - -#include "test_utils.h" -#include -#include - - -class SendMidManualDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - for (int i = 0; i < 10 * 1024; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 10 * 1024; ++i) { - sender->send_block(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataMidManual : public MsgReceiver { -public: - explicit StoreAllDataMidManual(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataMidManualNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - -int test_transfer_data_mid_manual_flush(int argc, char **argv) { - const std::string addr = "127.0.0.1:40013"; - - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); - SocketManager send(send_cb); - SocketManager store(store_cb); - - send.listen_on_addr(addr); - - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - store.connect_to_addr(addr); - - // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - } -} diff --git a/tests/test_transfer_data_mid_no_flush.cpp b/tests/test_transfer_data_mid_no_flush.cpp new file mode 100644 index 0000000..24e6356 --- /dev/null +++ b/tests/test_transfer_data_mid_no_flush.cpp @@ -0,0 +1,34 @@ +#undef NDEBUG + +#include "transfer_common.h" +#include +#include + +int test_transfer_data_mid_no_flush(int argc, char **argv) { + SpdLogger::init(); + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(MID_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); + } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; +} diff --git a/tests/test_transfer_data_mid_nonblock.cpp b/tests/test_transfer_data_mid_nonblock.cpp index f65af7b..a5914f1 100644 --- a/tests/test_transfer_data_mid_nonblock.cpp +++ b/tests/test_transfer_data_mid_nonblock.cpp @@ -1,106 +1,34 @@ #undef NDEBUG -#include "test_utils.h" +#include "transfer_common.h" #include #include - -class SendMidNonBlockDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - for (int i = 0; i < 10 * 1024; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 10 * 1024; ++i) { - sender->send_nonblock(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataMidNonBlock : public MsgReceiver { -public: - explicit StoreAllDataMidNonBlock(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataMidNonBlockNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - int test_transfer_data_mid_nonblock(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40013"; - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); + auto send_cb = std::make_shared(MID_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); SocketManager send(send_cb); SocketManager store(store_cb); send.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); store.connect_to_addr(addr); // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; } diff --git a/tests/test_transfer_data_small.cpp b/tests/test_transfer_data_small.cpp deleted file mode 100644 index 40dcf23..0000000 --- a/tests/test_transfer_data_small.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#undef NDEBUG - -#include "test_utils.h" -#include -#include - - -class SendSmallDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - for (int i = 0; i < 10; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 10 * 1024 * 1024; ++i) { - sender->send_block(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataSmall : public MsgReceiver { -public: - explicit StoreAllDataSmall(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataSmallNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - -int test_transfer_data_small(int argc, char **argv) { - const std::string addr = "127.0.0.1:40013"; - - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); - SocketManager send(send_cb); - SocketManager store(store_cb); - - send.listen_on_addr(addr); - - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - store.connect_to_addr(addr); - - // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - } -} diff --git a/tests/test_transfer_data_small_async.cpp b/tests/test_transfer_data_small_async.cpp index 64d2385..81185a8 100644 --- a/tests/test_transfer_data_small_async.cpp +++ b/tests/test_transfer_data_small_async.cpp @@ -1,136 +1,34 @@ #undef NDEBUG -#include "test_utils.h" -#include -#include "concurrentqueue/concurrentqueue.h" -#include "concurrentqueue/lightweightsemaphore.h" +#include "transfer_common.h" +#include #include - -class CondWaker : public Notifier { -public: - explicit CondWaker(const std::shared_ptr &sem) : sem(sem) {} - - void wake() override { - sem->signal(); - } - -private: - std::shared_ptr sem; -}; - -class SendLargeDataConnCallbackAsyncSmall : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - auto sem = std::make_shared(); - auto waker = std::make_shared(sem); - conn->start(std::move(rcv), waker); - - std::string data; - for (int i = 0; i < 10; i++) { - data.append("helloworld"); - } - - std::thread t([=]() { - // send 1000MB data - int progress = 0; - size_t offset = 0; - std::string_view data_view(data); - while (progress < 10 * 1024 * 1024) { - auto sent = sender->send_async(data_view.substr(offset)); - if (sent < 0) { - sem->wait(1000); - } else { - offset += sent; - } - if (offset == data.size()) { - offset = 0; - progress += 1; - } - } - }); - - t.detach(); - } -}; - -class StoreAllDataAsyncSmall : public MsgReceiver { -public: - explicit StoreAllDataAsyncSmall(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataNotifyOnCloseCallbackAsyncSmall : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - int test_transfer_data_small_async(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40013"; - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); + auto send_cb = std::make_shared(SMALL_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); SocketManager send(send_cb); SocketManager store(store_cb); send.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); store.connect_to_addr(addr); // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; } diff --git a/tests/test_transfer_data_small_block.cpp b/tests/test_transfer_data_small_block.cpp new file mode 100644 index 0000000..2c535c8 --- /dev/null +++ b/tests/test_transfer_data_small_block.cpp @@ -0,0 +1,34 @@ +#undef NDEBUG + +#include "transfer_common.h" +#include +#include + +int test_transfer_data_small_block(int argc, char **argv) { + SpdLogger::init(); + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(SMALL_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); + } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; +} diff --git a/tests/test_transfer_data_small_manual_flush.cpp b/tests/test_transfer_data_small_manual_flush.cpp deleted file mode 100644 index 2382cf9..0000000 --- a/tests/test_transfer_data_small_manual_flush.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#undef NDEBUG - -#include "test_utils.h" -#include -#include - - -class SendSmallManualDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - for (int i = 0; i < 10; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 10 * 1024 * 1024; ++i) { - sender->send_block(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataSmallManual : public MsgReceiver { -public: - explicit StoreAllDataSmallManual(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataSmallManualNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - -int test_transfer_data_small_manual_flush(int argc, char **argv) { - const std::string addr = "127.0.0.1:40013"; - - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); - SocketManager send(send_cb); - SocketManager store(store_cb); - - send.listen_on_addr(addr); - - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - store.connect_to_addr(addr); - - // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } - } -} diff --git a/tests/test_transfer_data_small_no_flush.cpp b/tests/test_transfer_data_small_no_flush.cpp new file mode 100644 index 0000000..6ea982c --- /dev/null +++ b/tests/test_transfer_data_small_no_flush.cpp @@ -0,0 +1,34 @@ +#undef NDEBUG + +#include "transfer_common.h" +#include +#include + +int test_transfer_data_small_no_flush(int argc, char **argv) { + SpdLogger::init(); + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(SMALL_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); + } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; +} diff --git a/tests/test_transfer_data_small_nonblock.cpp b/tests/test_transfer_data_small_nonblock.cpp index e379d29..90eeeaa 100644 --- a/tests/test_transfer_data_small_nonblock.cpp +++ b/tests/test_transfer_data_small_nonblock.cpp @@ -1,106 +1,34 @@ #undef NDEBUG -#include "test_utils.h" +#include "transfer_common.h" #include #include - -class SendSmallNonBlockDataConnCallback : public DoNothingConnCallback { -public: - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - auto rcv = std::make_unique(); - conn->start(std::move(rcv)); - std::thread t([sender]() { - // send 1000MB data - std::string data; - for (int i = 0; i < 10; i++) { - data.append("helloworld"); - } - for (int i = 0; i < 10 * 1024 * 1024; ++i) { - sender->send_nonblock(data); - } - // close connection after sender finished. - }); - t.detach(); - } -}; - -class StoreAllDataSmallNonBlock : public MsgReceiver { -public: - explicit StoreAllDataSmallNonBlock(size_t &buffer, int &count) : buffer(buffer), count(count) {} - - void on_message(std::string_view data) override { - if (count % 100 == 0) { - std::cout << "received " << count << " messages " - << ",size = " << buffer << std::endl; - } - buffer += data.length(); - count += 1; - } - - size_t &buffer; - int &count; -}; - -class StoreAllDataSmallNonBlockNotifyOnCloseCallback : public ConnCallback { -public: - - void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr send) override { - auto rcv = std::make_unique(add_data, count); - // store sender so connection is not dropped. - this->sender = send; - conn->start(std::move(rcv)); - } - - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lk(mutex); - has_closed.store(true); - std::cout << "on_connection_close" << std::endl; - cond.notify_all(); - } - - void on_listen_error(const std::string &addr, const std::string &err) override {} - - void on_connect_error(const std::string &addr, const std::string &err) override {} - - std::mutex mutex; - std::condition_variable cond; - std::atomic_bool has_closed{false}; - size_t add_data{0}; - int count{0}; - std::shared_ptr sender; -}; - int test_transfer_data_small_nonblock(int argc, char **argv) { + SpdLogger::init(); const std::string addr = "127.0.0.1:40013"; - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); + auto send_cb = std::make_shared(SMALL_MSG_SIZE, TOTAL_SIZE); + auto store_cb = std::make_shared(); SocketManager send(send_cb); SocketManager store(store_cb); send.listen_on_addr(addr); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); + std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_MILLIS)); store.connect_to_addr(addr); // Wait for the connection to close - while (true) { - if (store_cb->has_closed.load()) { - auto avg_size = store_cb->add_data / store_cb->count; - std::cout << "received " << store_cb->count << " messages ," - << "total size = " << store_cb->add_data << " bytes, " - << "average size = " << avg_size << " bytes" - << std::endl; - assert(store_cb->add_data == 1024 * 1024 * 1000); - return 0; - } - { - std::unique_lock u_lock(store_cb->mutex); - store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); - } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait(u_lock, + [store_cb]() { return store_cb->has_closed.load(); }); } + auto avg_size = *store_cb->add_data / *store_cb->count; + std::cout << "received " << *store_cb->count << " messages ," + << "total size = " << *store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" << std::endl; + assert(*store_cb->add_data == TOTAL_SIZE); + return 0; } diff --git a/tests/test_utils.h b/tests/test_utils.h index 3d11366..3385999 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -2,17 +2,42 @@ #ifndef SOCKET_MANAGER_TEST_UTILS_H #define SOCKET_MANAGER_TEST_UTILS_H -#include +#include "spdlog/spdlog.h" +#include #include +#include +#include #include -#include +#include #include -#include #include -#include -using namespace socket_manager; +class SpdLogger { +public: + static void + init(spdlog::level::level_enum level = spdlog::level::level_enum::debug) { + spdlog::set_level(level); + SOCKET_MANAGER_C_API_TraceLevel log_level = + SOCKET_MANAGER_C_API_TraceLevel::Off; + // Socket Manager level is the same as spdlog level from trace to err + if (level <= spdlog::level::err) { + log_level = static_cast(level); + } + socket_manager::init_logger(print_log, log_level, + SOCKET_MANAGER_C_API_TraceLevel::Off); + }; +private: + static void print_log(SOCKET_MANAGER_C_API_LogData log_data) { + socket_manager::LogData data = socket_manager::from_c_log_data(log_data); + spdlog::log(static_cast(data.level), + "{}: {}:{} {}", data.target, data.file, data.line, + data.message); + } +}; + +using namespace socket_manager; +const long long WAIT_MILLIS = 10; /// Flag Signal @@ -24,7 +49,6 @@ enum EventType { SEND_ERROR = 1 << 4 }; - /// /// Message Receivers /// @@ -35,23 +59,27 @@ class DoNothingReceiver : public MsgReceiver { class MsgStoreReceiver : public MsgReceiver { public: - MsgStoreReceiver(std::string conn_id, - std::mutex &mutex, - std::condition_variable &cond, - std::vector>> &buffer) - : conn_id(std::move(conn_id)), mutex(mutex), cond(cond), buffer(buffer) {} + MsgStoreReceiver( + std::string conn_id, const std::shared_ptr &mutex, + const std::shared_ptr &cond, + const std::shared_ptr< + std::vector>>> + &buffer) + : conn_id(std::move(conn_id)), mutex(mutex), cond(cond), buffer(buffer) {} void on_message(std::string_view data) override { - std::unique_lock lock(mutex); - buffer.emplace_back(conn_id, std::make_shared(data)); - cond.notify_all(); + std::unique_lock lock(*mutex); + buffer->emplace_back(conn_id, std::make_shared(data)); + cond->notify_all(); } private: std::string conn_id; - std::mutex &mutex; - std::condition_variable &cond; - std::vector>> &buffer; + std::shared_ptr mutex; + std::shared_ptr cond; + std::shared_ptr< + std::vector>>> + buffer; }; /// @@ -62,57 +90,71 @@ class MsgStoreReceiver : public MsgReceiver { class DoNothingConnCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { conn->close(); } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override {} + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override {} - void on_listen_error(const std::string &addr, const std::string &err) override {} + void on_listen_error(const std::string &addr, + const std::string &err) override {} - void on_connect_error(const std::string &addr, const std::string &err) override {} + void on_connect_error(const std::string &addr, + const std::string &err) override {} }; class BitFlagCallback : public ConnCallback { public: - BitFlagCallback(std::mutex &mutex, std::condition_variable &cond, std::atomic_int &sig, - std::vector>> &buffer) - : mutex(mutex), cond(cond), sig(sig), buffer(buffer) {} + BitFlagCallback( + const std::shared_ptr &mutex, + const std::shared_ptr &cond, + const std::shared_ptr &sig, + const std::shared_ptr< + std::vector>>> + &buffer) + : mutex(mutex), cond(cond), sig(sig), buffer(buffer) {} void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { + std::shared_ptr conn, + std::shared_ptr sender) override { set_sig(CONNECTED); auto conn_id = local_addr + "->" + peer_addr; - auto msg_storer = std::make_unique(conn_id, mutex, cond, buffer); + auto msg_storer = + std::make_unique(conn_id, mutex, cond, buffer); conn->start(std::move(msg_storer)); } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override { set_sig(CONNECTION_CLOSED); } - void on_listen_error(const std::string &addr, const std::string &err) override { + void on_listen_error(const std::string &addr, + const std::string &err) override { set_sig(LISTEN_ERROR); } - void on_connect_error(const std::string &addr, const std::string &err) override { + void on_connect_error(const std::string &addr, + const std::string &err) override { set_sig(CONNECT_ERROR); } protected: - void set_sig(int flag) { - std::lock_guard lock(mutex); - sig.fetch_or(flag, std::memory_order_seq_cst); - cond.notify_all(); + std::lock_guard lock(*mutex); + sig->fetch_or(flag, std::memory_order_seq_cst); + cond->notify_all(); } private: - - std::mutex &mutex; - std::condition_variable &cond; - std::atomic_int &sig; - std::vector>> &buffer; + std::shared_ptr mutex; + std::shared_ptr cond; + std::shared_ptr sig; + std::shared_ptr< + std::vector>>> + buffer; }; /// Store all events in order @@ -120,69 +162,84 @@ class BitFlagCallback : public ConnCallback { class StoreAllEventsConnCallback : public ConnCallback { public: - explicit StoreAllEventsConnCallback(bool clean_sender_on_close = true) - : connected_count(0), clean_sender_on_close(clean_sender_on_close) {} + : mutex(std::make_shared()), + connected_count(std::make_shared()), + cond(std::make_shared()), + events(std::make_shared< + std::vector>>()), + buffer(std::make_shared>>>()), + clean_sender_on_close(clean_sender_on_close) {} void on_connect(const std::string &local_addr, const std::string &peer_addr, - std::shared_ptr conn, std::shared_ptr sender) override { - std::unique_lock lock(mutex); + std::shared_ptr conn, + std::shared_ptr sender) override { + std::unique_lock lock(*mutex); auto conn_id = local_addr + "->" + peer_addr; - events.emplace_back(CONNECTED, conn_id); - auto msg_storer = std::make_unique(conn_id, mutex, cond, buffer); + events->emplace_back(CONNECTED, conn_id); + auto msg_storer = + std::make_unique(conn_id, mutex, cond, buffer); conn->start(std::move(msg_storer)); senders.emplace(conn_id, sender); - connected_count.fetch_add(1, std::memory_order_seq_cst); - cond.notify_all(); + connected_count->fetch_add(1, std::memory_order_seq_cst); + cond->notify_all(); } - void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { - std::unique_lock lock(mutex); + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override { + std::unique_lock lock(*mutex); auto conn_id = local_addr + "->" + peer_addr; - events.emplace_back(CONNECTION_CLOSED, conn_id); + events->emplace_back(CONNECTION_CLOSED, conn_id); if (clean_sender_on_close) { senders.erase(conn_id); } - connected_count.fetch_sub(1, std::memory_order_seq_cst); - cond.notify_all(); + connected_count->fetch_sub(1, std::memory_order_seq_cst); + cond->notify_all(); } - void on_listen_error(const std::string &addr, const std::string &err) override { - std::unique_lock lock(mutex); - events.emplace_back(LISTEN_ERROR, "listening to " + addr + " failed: " + err); - cond.notify_all(); + void on_listen_error(const std::string &addr, + const std::string &err) override { + std::unique_lock lock(*mutex); + events->emplace_back(LISTEN_ERROR, + "listening to " + addr + " failed: " + err); + cond->notify_all(); } - void on_connect_error(const std::string &addr, const std::string &err) override { - std::unique_lock lock(mutex); - events.emplace_back(CONNECT_ERROR, "connecting to " + addr + " failed: " + err); - cond.notify_all(); + void on_connect_error(const std::string &addr, + const std::string &err) override { + std::unique_lock lock(*mutex); + events->emplace_back(CONNECT_ERROR, + "connecting to " + addr + " failed: " + err); + cond->notify_all(); } void send_to(std::string &conn_id, const std::string &data) { - std::unique_lock lock(mutex); + std::unique_lock lock(*mutex); try { auto sender = senders.at(conn_id); sender->send_block(data); sender->flush(); } catch (std::out_of_range &e) { - std::cout << "connection " << conn_id << " not found during send" << std::endl; + std::cout << "connection " << conn_id << " not found during send" + << std::endl; } } void drop_connection(std::string &conn_id) { - std::unique_lock lock(mutex); + std::unique_lock lock(*mutex); senders.erase(conn_id); } - std::mutex mutex; - std::atomic_int connected_count; - std::condition_variable cond; - std::vector> events; - std::vector>> buffer; + std::shared_ptr mutex; + std::shared_ptr connected_count; + std::shared_ptr cond; + std::shared_ptr>> events; + std::shared_ptr< + std::vector>>> + buffer; std::unordered_map> senders; bool clean_sender_on_close; - }; -#endif //SOCKET_MANAGER_TEST_UTILS_H +#endif // SOCKET_MANAGER_TEST_UTILS_H diff --git a/tests/transfer_common.h b/tests/transfer_common.h new file mode 100644 index 0000000..866c721 --- /dev/null +++ b/tests/transfer_common.h @@ -0,0 +1,221 @@ +#undef NDEBUG + +#ifndef SOCKET_MANAGER_TEST_TRANSFER_COMMON_H +#define SOCKET_MANAGER_TEST_TRANSFER_COMMON_H + +#include "concurrentqueue/concurrentqueue.h" +#include "concurrentqueue/lightweightsemaphore.h" +#include "test_utils.h" +#include +#include + +const int PRINT_INTERVAL = 100; +const size_t TOTAL_SIZE = static_cast(1024) * 1024 * 1000; +const size_t LARGE_MSG_SIZE = static_cast(10) * 1024 * 1024; +const size_t MID_MSG_SIZE = static_cast(100) * 1024; +const size_t SMALL_MSG_SIZE = static_cast(100); + +class transfer_private { +public: + static std::string make_test_message(size_t msg_size) { + const std::string TEST_MSG = "helloworld"; + std::string data; + data.reserve(msg_size); + for (int i = 0; i < msg_size / TEST_MSG.size(); i++) { + data.append(TEST_MSG); + } + return data; + } +}; + +class SendBlockCB : public DoNothingConnCallback { +public: + explicit SendBlockCB(size_t msg_size, size_t total_size) + : msg_size(msg_size), total_size(total_size) {} + + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, + std::shared_ptr sender) override { + auto rcv = std::make_unique(); + + std::string data = transfer_private::make_test_message(msg_size); + size_t msg_count = total_size / msg_size; + + conn->start(std::move(rcv)); + std::thread([sender, data, msg_count]() { + for (int i = 0; i < msg_count; ++i) { + sender->send_block(data); + } + // close connection after sender finished. + }).detach(); + } + +private: + size_t msg_size; + size_t total_size; +}; + +class CondWaker : public Notifier { +public: + explicit CondWaker( + const std::shared_ptr &sem) + : sem(sem) {} + + void wake() override { sem->signal(); } + +private: + std::shared_ptr sem; +}; + +class SendAsyncCB : public DoNothingConnCallback { +public: + explicit SendAsyncCB(size_t msg_size, size_t total_size) + : msg_size(msg_size), total_size(total_size) {} + + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, + std::shared_ptr sender) override { + auto rcv = std::make_unique(); + auto sem = std::make_shared(); + auto waker = std::make_shared(sem); + + std::string data = transfer_private::make_test_message(msg_size); + size_t msg_count = total_size / msg_size; + + conn->start(std::move(rcv), waker); + std::thread([sender, msg_count, data, sem]() { + int progress = 0; + size_t offset = 0; + std::string_view data_view(data); + while (progress < msg_count) { + auto sent = sender->send_async(data_view.substr(offset)); + if (sent == PENDING) { + sem->wait(); + } else { + offset += sent; + } + if (offset == data.size()) { + offset = 0; + progress += 1; + } + } + // close connection after sender finished. + }).detach(); + } + +private: + size_t msg_size; + size_t total_size; +}; + +class SendNoFlushCB : public DoNothingConnCallback { +public: + explicit SendNoFlushCB(size_t msg_size, size_t total_size) + : msg_size(msg_size), total_size(total_size) {} + + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, + std::shared_ptr sender) override { + auto rcv = std::make_unique(); + + std::string data = transfer_private::make_test_message(msg_size); + size_t msg_count = total_size / msg_size; + + conn->start(std::move(rcv), nullptr, DEFAULT_MSG_BUF_SIZE, + DEFAULT_READ_MSG_FLUSH_MILLI_SEC, 0); + std::thread([sender, data, msg_count]() { + for (int i = 0; i < msg_count; ++i) { + sender->send_block(data); + } + // close connection after sender finished. + }).detach(); + } + +private: + size_t msg_size; + size_t total_size; +}; + +class SendNonBlockCB : public DoNothingConnCallback { +public: + explicit SendNonBlockCB(size_t msg_size, size_t total_size) + : msg_size(msg_size), total_size(total_size) {} + + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, + std::shared_ptr sender) override { + auto rcv = std::make_unique(); + + std::string data = transfer_private::make_test_message(msg_size); + size_t msg_count = total_size / msg_size; + + conn->start(std::move(rcv)); + std::thread([sender, data, msg_count]() { + for (int i = 0; i < msg_count; ++i) { + sender->send_nonblock(data); + } + // close connection after sender finished. + }).detach(); + } + +private: + size_t msg_size; + size_t total_size; +}; + +class CountReceived : public MsgReceiver { +public: + explicit CountReceived(const std::shared_ptr &buffer, + const std::shared_ptr &count) + : buffer(buffer), count(count) {} + + void on_message(std::string_view data) override { + if (*count % PRINT_INTERVAL == 0) { + std::cout << "received " << *count << " messages " + << ",size = " << *buffer << std::endl; + } + *buffer += data.length(); + *count += 1; + } + + std::shared_ptr buffer; + std::shared_ptr count; +}; + +class CountDataNotifyOnCloseCallback : public ConnCallback { +public: + CountDataNotifyOnCloseCallback() + : has_closed(false), add_data(std::make_shared(0)), + count(std::make_shared(0)) {} + + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, + std::shared_ptr send) override { + auto rcv = std::make_unique(add_data, count); + // store sender so connection is not dropped. + this->sender = send; + conn->start(std::move(rcv)); + } + + void on_connection_close(const std::string &local_addr, + const std::string &peer_addr) override { + has_closed.store(true); + std::cout << "on_connection_close" << std::endl; + cond.notify_all(); + } + + void on_listen_error(const std::string &addr, + const std::string &err) override {} + + void on_connect_error(const std::string &addr, + const std::string &err) override {} + + std::mutex mutex; + std::condition_variable cond; + std::atomic_bool has_closed; + std::shared_ptr add_data; + std::shared_ptr count; + std::shared_ptr sender; +}; + +#endif // SOCKET_MANAGER_TEST_TRANSFER_COMMON_H