diff --git a/.circleci/config.yml b/.circleci/config.yml index e15cd63..da8efe4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -47,11 +47,12 @@ jobs: cd tests/test_find_package cmake -B build -DCMAKE_BUILD_TYPE=Release cmake --build build --config Release - cd build - ./helloworld_server & + ./build/helloworld_server & + SERVER_PID=$! # Give the server time to start sleep 1 rsp=$(curl http://127.0.0.1:49999) + kill $SERVER_PID if [[ "$rsp" == "Hello, world" ]]; then echo "Test passed" exit 0 @@ -59,6 +60,12 @@ jobs: echo "Test failed" exit 1 fi + - run: + name: "Test Build Examples" + command: | + cd examples/echo_server + cmake -B build -DCMAKE_BUILD_TYPE=Release + cmake --build build --config Release # Orchestrate jobs using workflows # See: https://circleci.com/docs/configuration-reference/#workflows @@ -68,5 +75,5 @@ workflows: - build: matrix: parameters: - os: ["jammy", "focal"] - shared: ["ON", "OFF"] + os: [ "jammy", "focal" ] + shared: [ "ON", "OFF" ] diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..e8d486a --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "cargo" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "weekly" diff --git a/.github/workflows/PR.yml b/.github/workflows/PR.yml index 394b4d5..7f0b9d3 100644 --- a/.github/workflows/PR.yml +++ b/.github/workflows/PR.yml @@ -64,11 +64,12 @@ jobs: eval "$(brew shellenv)" cmake -B build -DCMAKE_BUILD_TYPE=Release cmake --build build --config Release - cd build - ./helloworld_server & + ./build/helloworld_server & + SERVER_PID=$! # Give the server time to start sleep 1 rsp=$(curl http://127.0.0.1:49999) + kill $SERVER_PID if [[ "$rsp" == "Hello, world" ]]; then echo "Test passed" exit 0 @@ -77,6 +78,13 @@ jobs: exit 1 fi + - name: Test Build Examples + working-directory: ${{github.workspace}}/examples/echo_server + run: | + eval "$(brew shellenv)" + cmake -B build -DCMAKE_BUILD_TYPE=Release + cmake --build build --config Release + build-linux: runs-on: ubuntu-latest @@ -133,11 +141,12 @@ jobs: run: | cmake -B build -DCMAKE_BUILD_TYPE=Release cmake --build build --config Release - cd build - ./helloworld_server & + ./build/helloworld_server & + SERVER_PID=$! # Give the server time to start - rsp=$(curl http://127.0.0.1:49999) sleep 1 + rsp=$(curl http://127.0.0.1:49999) + kill $SERVER_PID if [[ "$rsp" == "Hello, world" ]]; then echo "Test passed" exit 0 @@ -145,3 +154,9 @@ jobs: echo "Test failed" exit 1 fi + + - name: Test Build Examples + working-directory: ${{github.workspace}}/examples/echo_server + run: | + cmake -B build -DCMAKE_BUILD_TYPE=Release + cmake --build build --config Release diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..c658dbd --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,98 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ "main" ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ "main" ] + schedule: + - cron: '33 15 * * 5' + +jobs: + analyze: + name: Analyze + # Runner size impacts CodeQL analysis time. To learn more, please see: + # - https://gh.io/recommended-hardware-resources-for-running-codeql + # - https://gh.io/supported-runners-and-hardware-resources + # - https://gh.io/using-larger-runners + # Consider using larger runners for possible analysis time improvements. + runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} + timeout-minutes: ${{ (matrix.language == 'swift' && 120) || 360 }} + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'cpp' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby', 'swift' ] + # Use only 'java' to analyze code written in Java, Kotlin or both + # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both + # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + with: + submodules: true + + - name: Install LLVM and Clang + run: | + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh 16 all + chmod +x update-alternatives-clang.sh + sudo ./update-alternatives-clang.sh 16 9999 + + - name: Install Rust toolchain + uses: actions-rs/toolchain@v1 + with: + toolchain: nightly + default: true + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + + # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs + # queries: security-extended,security-and-quality + + + # Autobuild attempts to build any compiled languages (C/C++, C#, Go, Java, or Swift). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + # â„šī¸ Command-line programs to run using the OS shell. + # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun + + # If the Autobuild fails above, remove it and uncomment the following three lines. + # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. + + # - run: | + # echo "Run, Build Application using script" + # ./location_of_script_within_repo/buildscript.sh + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 + with: + category: "/language:${{matrix.language}}" diff --git a/.gitignore b/.gitignore index 6e0c6b0..3ea6d48 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ /target /Cargo.lock -.idea +**/.idea build build-debug cmake-build-debug diff --git a/.gitmodules b/.gitmodules index 7f835e6..98dca17 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "corrosion"] path = corrosion url = https://github.com/corrosion-rs/corrosion.git +[submodule "concurrentqueue"] + path = tests/concurrentqueue + url = https://github.com/cameron314/concurrentqueue.git diff --git a/CMakeLists.txt b/CMakeLists.txt index f3f61b3..23799b9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,11 +1,11 @@ cmake_minimum_required(VERSION 3.9) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) # read toolchain file before project set(CMAKE_TOOLCHAIN_FILE ${CMAKE_SOURCE_DIR}/toolchain.cmake) # define project -project(socket_manager LANGUAGES C CXX VERSION 0.2.0) +project(socket_manager LANGUAGES C CXX VERSION 0.3.0) # set default build type as shared option(BUILD_SHARED_LIBS "Build using shared libraries" ON) diff --git a/README.md b/README.md index 7bfbc5f..36deac1 100644 --- a/README.md +++ b/README.md @@ -78,3 +78,28 @@ To enable lto, add: ```cmake set_property(TARGET PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) ``` + +## Memory (resource) Model + +Dropping `Sender` will close the connection, and drop its +reference to `Connection`. +`ConnCallback` will drop its internal reference to `Connection` when +the connection is closed. +Thus `Connection` will free any reference to `Notifier` +or `MsgReceiver`. +Note to drop all related resources, no reference to the +returned `Waker` should be kept. + +```mermaid +flowchart TD + M(SocketManager) -->|strong ref| CB(ConnCallback) + CB -->|strong ref, drop on close| CON(Connection) + CON -->|strong ref| NF(Notifier) + CON -->|strong ref| RCV(MsgReceiver) + SEND(Sender) -->|strong ref| CON + RCV -.->|returns| WK + WK(Waker) -.->|drop to release| READ(Read Task) + SEND -.->|drop to close| CB +``` + +See the example folder for more complicated case. diff --git a/cbindgen.toml b/cbindgen.toml index 08e4fea..87b869f 100644 --- a/cbindgen.toml +++ b/cbindgen.toml @@ -3,7 +3,7 @@ # possible values: "C", "C++", "Cython" # # default: "C++" -language = "C" +language = "C++" # A list of sys headers to #include (with angle brackets) # default: [] @@ -52,7 +52,7 @@ documentation = true # * "auto": "c++" if that's the language, "doxy" otherwise # # default: "auto" -documentation_style = "auto" +documentation_style = "doxy" # How much of the documentation for each item is output. # @@ -85,6 +85,7 @@ usize_is_size_t = true [export] item_types = ["enums", "structs", "opaque", "functions", "unions"] +prefix = "SOCKET_MANAGER_C_API_" [struct] diff --git a/dockerfile/dev-containers/focal/Dockerfile b/dockerfile/dev-containers/focal/Dockerfile index 413136b..6ca17fd 100644 --- a/dockerfile/dev-containers/focal/Dockerfile +++ b/dockerfile/dev-containers/focal/Dockerfile @@ -17,12 +17,13 @@ RUN apt-get update && \ git \ vim \ wget \ - curl + curl \ + screen # install llvm@16 WORKDIR /root -COPY ./dockerfile/dev-containers/llvm.sh /root COPY ./update-alternatives-clang.sh /root +RUN wget https://apt.llvm.org/llvm.sh RUN chmod +x llvm.sh RUN ./llvm.sh 16 all # use tsinghua mirror in China, comment out if you are not in China @@ -38,4 +39,6 @@ RUN rm update-alternatives-clang.sh RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly ENV PATH="/root/.cargo/bin:${PATH}" +RUN chsh -s /bin/bash + ENTRYPOINT ["/bin/bash"] diff --git a/dockerfile/dev-containers/jammy/Dockerfile b/dockerfile/dev-containers/jammy/Dockerfile index 76c26a4..7a28d90 100644 --- a/dockerfile/dev-containers/jammy/Dockerfile +++ b/dockerfile/dev-containers/jammy/Dockerfile @@ -17,12 +17,13 @@ RUN apt-get update && \ git \ vim \ wget \ - curl + curl \ + screen # install llvm@16 WORKDIR /root -COPY ./dockerfile/dev-containers/llvm.sh /root COPY ./update-alternatives-clang.sh /root +RUN wget https://apt.llvm.org/llvm.sh RUN chmod +x llvm.sh RUN ./llvm.sh 16 all # use tsinghua mirror in China, comment out if you are not in China @@ -38,4 +39,6 @@ RUN rm update-alternatives-clang.sh RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly ENV PATH="/root/.cargo/bin:${PATH}" +RUN chsh -s /bin/bash + ENTRYPOINT ["/bin/bash"] diff --git a/dockerfile/dev-containers/llvm.sh b/dockerfile/dev-containers/llvm.sh deleted file mode 100644 index 37ad095..0000000 --- a/dockerfile/dev-containers/llvm.sh +++ /dev/null @@ -1,173 +0,0 @@ -#!/bin/bash -################################################################################ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -################################################################################ -# -# This script will install the llvm toolchain on the different -# Debian and Ubuntu versions - -set -eux - -usage() { - set +x - echo "Usage: $0 [llvm_major_version] [all] [OPTIONS]" 1>&2 - echo -e "all\t\t\tInstall all packages." 1>&2 - echo -e "-n=code_name\t\tSpecifies the distro codename, for example bionic" 1>&2 - echo -e "-h\t\t\tPrints this help." 1>&2 - echo -e "-m=repo_base_url\tSpecifies the base URL from which to download." 1>&2 - exit 1; -} - -CURRENT_LLVM_STABLE=16 -BASE_URL="http://apt.llvm.org" - -# Check for required tools -needed_binaries=(lsb_release wget add-apt-repository gpg) -missing_binaries=() -for binary in "${needed_binaries[@]}"; do - if ! which $binary &>/dev/null ; then - missing_binaries+=($binary) - fi -done -if [[ ${#missing_binaries[@]} -gt 0 ]] ; then - echo "You are missing some tools this script requires: ${missing_binaries[@]}" - echo "(hint: apt install lsb-release wget software-properties-common gnupg)" - exit 4 -fi - -# Set default values for commandline arguments -# We default to the current stable branch of LLVM -LLVM_VERSION=$CURRENT_LLVM_STABLE -ALL=0 -DISTRO=$(lsb_release -is) -VERSION=$(lsb_release -sr) -UBUNTU_CODENAME="" -CODENAME_FROM_ARGUMENTS="" -# Obtain VERSION_CODENAME and UBUNTU_CODENAME (for Ubuntu and its derivatives) -source /etc/os-release -DISTRO=${DISTRO,,} -case ${DISTRO} in - debian) - if [[ "${VERSION}" == "unstable" ]] || [[ "${VERSION}" == "testing" ]]; then - CODENAME=unstable - LINKNAME= - else - # "stable" Debian release - CODENAME=${VERSION_CODENAME} - LINKNAME=-${CODENAME} - fi - ;; - *) - # ubuntu and its derivatives - if [[ -n "${UBUNTU_CODENAME}" ]]; then - CODENAME=${UBUNTU_CODENAME} - if [[ -n "${CODENAME}" ]]; then - LINKNAME=-${CODENAME} - fi - fi - ;; -esac - -# read optional command line arguments -if [ "$#" -ge 1 ] && [ "${1::1}" != "-" ]; then - if [ "$1" != "all" ]; then - LLVM_VERSION=$1 - else - # special case for ./llvm.sh all - ALL=1 - fi - OPTIND=2 - if [ "$#" -ge 2 ]; then - if [ "$2" == "all" ]; then - # Install all packages - ALL=1 - OPTIND=3 - fi - fi -fi - -while getopts ":hm:n:" arg; do - case $arg in - h) - usage - ;; - m) - BASE_URL=${OPTARG} - ;; - n) - CODENAME=${OPTARG} - if [[ "${CODENAME}" == "unstable" ]]; then - # link name does not apply to unstable repository - LINKNAME= - else - LINKNAME=-${CODENAME} - fi - CODENAME_FROM_ARGUMENTS="true" - ;; - esac -done - -if [[ $EUID -ne 0 ]]; then - echo "This script must be run as root!" - exit 1 -fi - -declare -A LLVM_VERSION_PATTERNS -LLVM_VERSION_PATTERNS[9]="-9" -LLVM_VERSION_PATTERNS[10]="-10" -LLVM_VERSION_PATTERNS[11]="-11" -LLVM_VERSION_PATTERNS[12]="-12" -LLVM_VERSION_PATTERNS[13]="-13" -LLVM_VERSION_PATTERNS[14]="-14" -LLVM_VERSION_PATTERNS[15]="-15" -LLVM_VERSION_PATTERNS[16]="-16" -LLVM_VERSION_PATTERNS[17]="" - -if [ ! ${LLVM_VERSION_PATTERNS[$LLVM_VERSION]+_} ]; then - echo "This script does not support LLVM version $LLVM_VERSION" - exit 3 -fi - -LLVM_VERSION_STRING=${LLVM_VERSION_PATTERNS[$LLVM_VERSION]} - -# join the repository name -if [[ -n "${CODENAME}" ]]; then - REPO_NAME="deb ${BASE_URL}/${CODENAME}/ llvm-toolchain${LINKNAME}${LLVM_VERSION_STRING} main" - - # check if the repository exists for the distro and version - if ! wget -q --method=HEAD ${BASE_URL}/${CODENAME} &> /dev/null; then - if [[ -n "${CODENAME_FROM_ARGUMENTS}" ]]; then - echo "Specified codename '${CODENAME}' is not supported by this script." - else - echo "Distribution '${DISTRO}' in version '${VERSION}' is not supported by this script." - fi - exit 2 - fi -fi - - -# install everything - -if [[ ! -f /etc/apt/trusted.gpg.d/apt.llvm.org.asc ]]; then - # download GPG key once - wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc -fi - -if [[ -z "`apt-key list 2> /dev/null | grep -i llvm`" ]]; then - # Delete the key in the old format - apt-key del AF4F7421 -fi -add-apt-repository "${REPO_NAME}" -y -apt-get update -PKG="clang-$LLVM_VERSION lldb-$LLVM_VERSION lld-$LLVM_VERSION clangd-$LLVM_VERSION" -if [[ $ALL -eq 1 ]]; then - # same as in test-install.sh - # No worries if we have dups - PKG="$PKG clang-tidy-$LLVM_VERSION clang-format-$LLVM_VERSION clang-tools-$LLVM_VERSION llvm-$LLVM_VERSION-dev lld-$LLVM_VERSION lldb-$LLVM_VERSION llvm-$LLVM_VERSION-tools libomp-$LLVM_VERSION-dev libc++-$LLVM_VERSION-dev libc++abi-$LLVM_VERSION-dev libclang-common-$LLVM_VERSION-dev libclang-$LLVM_VERSION-dev libclang-cpp$LLVM_VERSION-dev libunwind-$LLVM_VERSION-dev" - if test $LLVM_VERSION -gt 14; then - PKG="$PKG libclang-rt-$LLVM_VERSION-dev libpolly-$LLVM_VERSION-dev" - fi -fi -apt-get install -y $PKG diff --git a/examples/echo_server/CMakeLists.txt b/examples/echo_server/CMakeLists.txt new file mode 100644 index 0000000..eb2e1e1 --- /dev/null +++ b/examples/echo_server/CMakeLists.txt @@ -0,0 +1,12 @@ +cmake_minimum_required(VERSION 3.9) + +set(CMAKE_TOOLCHAIN_FILE ${CMAKE_SOURCE_DIR}/toolchain.cmake) + +project(echo_server) + +set(CMAKE_CXX_STANDARD 20) +add_executable(echo_server src/echo_server.cpp) +set_property(TARGET echo_server PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + +find_package(socket_manager 0.3.0 REQUIRED) +target_link_libraries(echo_server PUBLIC socket_manager) diff --git a/examples/echo_server/README.md b/examples/echo_server/README.md new file mode 100644 index 0000000..ade4ec9 --- /dev/null +++ b/examples/echo_server/README.md @@ -0,0 +1,82 @@ +# An Echo Server Implemented Using SocketManager + +Implementing Echo using SocketManager is a bit tricky to +make the memory model correct. + +## The Memory Model Before Echo +Dropping `Sender` will close the connection, and drop its +reference to `Connection`. +`ConnCallback` will drop its internal reference to `Connection` when +the connection is closed. +Thus `Connection` will free any reference to `Notifier` +or `MsgReceiver`. +Note to drop all related resources, no reference to the +returned `Waker` should be kept. + +```mermaid +flowchart TD + M(SocketManager) -->|strong ref| CB(ConnCallback) + CB -->|strong ref, drop on close| CON(Connection) + CON -->|strong ref| NF(Notifier) + CON -->|strong ref| RCV(MsgReceiver) + SEND(Sender) -->|strong ref| CON +``` + +## The Memory Model Of Echo + +`Sender` is now referenced by `MsgReceiver` for sending +back the message, and cannot be dropped easily, +thus a cycle of reference is created. + +New links are marked in red. + +```mermaid +flowchart TD + NF ==>|strong ref| WK(Waker ref=1) + RCV ==>|strong ref| SEND + RCV ==>|strong ref| NF + CB ==> |strong ref| RCV + M(SocketManager) -->|strong ref| CB(ConnCallback) + CB -->|strong ref, drop on close| CON(Connection ref=2) + CON -->|strong ref| NF(Notifier ref=2) + CON -->|strong ref| RCV(MsgReceiver ref=2) + SEND(Sender ref=1) -->|strong ref| CON + WK -.-> TOKIO(tokio task) + linkStyle 0,1,2,3 color:red; + +``` + +The echo connection can be closed only be the remote, +thus we receive a `on_close` event from `ConnCallback`, +which will drop its reference to `Connection`. +Creating the following: + +```mermaid +flowchart TD + NF ==>|strong ref| WK(Waker ref=1) + RCV ==>|strong ref| SEND + RCV ==>|strong ref| NF + CB(ConnCallback) ==> |strong ref| RCV + CON -->|strong ref| NF(Notifier ref=2) + CON -->|strong ref| RCV(MsgReceiver ref=2) + SEND(Sender ref=1) -->|strong ref| CON(Connection ref=1) + WK -.-> TOKIO(tokio task) + linkStyle 0,1,2,3 color:red; + +``` + +Now, using the reference of `MsgReceiver` in `ConnCallback`, +we manually call erase the reference to `Sender` and `Notifer` +in `MsgReceiver`, and then drop the reference of `MsgReceiver`: + +```mermaid +flowchart TD + NF ==>|strong ref| WK(Waker ref=1) + CON -->|strong ref| NF(Notifier ref=1) + CON -->|strong ref| RCV(MsgReceiver ref=1) + SEND(Sender ref=0) -->|strong ref| CON(Connection ref=1) + WK -.-> TOKIO(tokio task) + linkStyle 0,1 color:red; +``` + +Now everything will be cleared up. diff --git a/examples/echo_server/justfile b/examples/echo_server/justfile new file mode 100644 index 0000000..86b3be5 --- /dev/null +++ b/examples/echo_server/justfile @@ -0,0 +1,16 @@ +clean: + rm -rf build + +debug: + cmake -B build -DCMAKE_TOOLCHAIN_FILE=toolchain.cmake \ + -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON + cmake --build build --parallel 4 --config Release --verbose + +build: + cmake -B build -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=toolchain.cmake \ + -DCMAKE_INTERPROCEDURAL_OPTIMIZATION=ON + cmake --build build --parallel 4 --config Release --verbose + +run: + SOCKET_LOG=info ./build/echo_server diff --git a/examples/echo_server/src/echo_server.cpp b/examples/echo_server/src/echo_server.cpp new file mode 100644 index 0000000..9dc0875 --- /dev/null +++ b/examples/echo_server/src/echo_server.cpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include +#include +#include + +/** + * UniqueWaker is a wrapper of `socket_manager::Waker` + * that also implements `socket_manager::Notifier`. + */ +class WrapWaker : public socket_manager::Notifier { +public: + explicit WrapWaker(socket_manager::Waker &&wake) : waker(std::move(wake)) {} + + void set_waker(socket_manager::Waker &&wake) { + waker = std::move(wake); + } + +private: + void wake() override { + waker.wake(); + } + + socket_manager::Waker waker; +}; + +/** + * When the receiver receives, + * it tries to send back the message, + * unless the sender returns `PENDING`, + * it sleeps until the sender wakes it up. + */ +class EchoReceiver : public socket_manager::MsgReceiverAsync { +public: + explicit EchoReceiver( + std::shared_ptr &&sender, + const std::shared_ptr &waker + ) : waker(waker), sender(std::move(sender)) {}; + + /** + * Release resources to break potential ref cycles. + */ + void close() { + waker.reset(); + sender.reset(); + } + +private: + long on_message_async(std::string_view data, socket_manager::Waker &&wake) override { + waker->set_waker(std::move(wake)); + return sender->send_async(data); + }; + std::shared_ptr waker; + std::shared_ptr sender; +}; + +/** + * The callbacks for connection events. + */ +class EchoCallback : public socket_manager::ConnCallback { +private: + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, + std::shared_ptr sender) override { + auto waker = std::make_shared(socket_manager::Waker()); + auto recv = std::make_shared(std::move(sender), waker); + { + // add the receiver to the map for cleanup + std::lock_guard lock(mutex); + receivers[local_addr + peer_addr] = recv; + } + conn->start(std::move(recv), std::move(waker)); + } + + void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { + { + std::lock_guard lock(mutex); + auto find = receivers.find(local_addr + peer_addr); + if (find != receivers.end()) { + // release receiver resources + find->second->close(); + receivers.erase(find); + } else { + throw std::runtime_error("connection not found: " + local_addr + " -> " + peer_addr); + } + } + std::cout << "connection closed: " << local_addr << " -> " << peer_addr << std::endl; + } + + void on_listen_error(const std::string &addr, const std::string &err) override { + throw std::runtime_error("listen error: addr=" + addr + ", " + err); + } + + void on_connect_error(const std::string &addr, const std::string &err) override { + throw std::runtime_error("connect error: addr=" + addr + ", " + err); + } + + std::mutex mutex; + std::unordered_map> receivers; +}; + +int main() { + // start the server + auto callback = std::make_shared(); + auto manager = socket_manager::SocketManager(callback); + manager.listen_on_addr("127.0.0.1:10101"); + manager.join(); +} diff --git a/examples/echo_server/toolchain.cmake b/examples/echo_server/toolchain.cmake new file mode 100644 index 0000000..475dbcf --- /dev/null +++ b/examples/echo_server/toolchain.cmake @@ -0,0 +1,31 @@ +SET(CMAKE_C_FLAGS "-Wall -std=c99") +SET(CMAKE_C_FLAGS_DEBUG "-g") +SET(CMAKE_C_FLAGS_MINSIZEREL "-Os -DNDEBUG") +SET(CMAKE_C_FLAGS_RELEASE "-O3 -DNDEBUG") +SET(CMAKE_C_FLAGS_RELWITHDEBINFO "-O2 -g") + +SET(CMAKE_CXX_FLAGS "-Wall") +SET(CMAKE_CXX_FLAGS_DEBUG "-g") +SET(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG") +SET(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG") +SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g") + +if (APPLE) + MESSAGE(STATUS "Using LLVM/Clang from Homebrew") + SET(CMAKE_C_COMPILER $ENV{HOMEBREW_PREFIX}/opt/llvm/bin/clang) + SET(CMAKE_CXX_COMPILER $ENV{HOMEBREW_PREFIX}/opt/llvm/bin/clang++) + SET(CMAKE_AR $ENV{HOMEBREW_PREFIX}/opt/llvm/bin/llvm-ar) + SET(CMAKE_LINKER $ENV{HOMEBREW_PREFIX}/opt/llvm/bin/llvm-ld) + SET(CMAKE_NM $ENV{HOMEBREW_PREFIX}/opt/llvm/bin/llvm-nm) + SET(CMAKE_OBJDUMP $ENV{HOMEBREW_PREFIX}/opt/llvm/bin/llvm-objdump) + SET(CMAKE_RANLIB $ENV{HOMEBREW_PREFIX}/opt/llvm/bin/llvm-ranlib) +else () + MESSAGE(STATUS "Using LLVM/Clang from Linux") + SET(CMAKE_C_COMPILER /usr/bin/clang) + SET(CMAKE_CXX_COMPILER /usr/bin/clang++) + SET(CMAKE_AR /usr/bin/llvm-ar) + SET(CMAKE_LINKER /usr/bin/llvm-ld) + SET(CMAKE_NM /usr/bin/llvm-nm) + SET(CMAKE_OBJDUMP /usr/bin/llvm-objdump) + SET(CMAKE_RANLIB /usr/bin/llvm-ranlib) +endif () diff --git a/include/socket_manager/common/notifier.h b/include/socket_manager/common/notifier.h new file mode 100644 index 0000000..29627d8 --- /dev/null +++ b/include/socket_manager/common/notifier.h @@ -0,0 +1,37 @@ +#ifndef SOCKET_MANAGER_NOTIFIER_H +#define SOCKET_MANAGER_NOTIFIER_H + +#include "socket_manager_c_api.h" + +namespace socket_manager { + /** + * @brief The Notifier class is used to receive notification + * from rust runtime to request c/c++ task for further + * execution. + * + * @note Our current implementation does not require ref count + * since we bound the lifetime of the notifier to the + * lifetime of the connection, and guarantee that the passed + * notifier is valid by keeping a reference of the notifier + * in the related connection. + */ + class Notifier { + + public: + virtual ~Notifier() = default; + + private: + + virtual void wake() = 0; + + friend void::socket_manager_extern_notifier_wake(struct SOCKET_MANAGER_C_API_Notifier this_); + + }; + + class NoopNotifier : public Notifier { + public: + void wake() override {} + }; +} + +#endif //SOCKET_MANAGER_NOTIFIER_H diff --git a/include/socket_manager/common/waker.h b/include/socket_manager/common/waker.h new file mode 100644 index 0000000..341b64b --- /dev/null +++ b/include/socket_manager/common/waker.h @@ -0,0 +1,61 @@ +#ifndef SOCKET_MANAGER_WAKER_H +#define SOCKET_MANAGER_WAKER_H + +#include "socket_manager_c_api.h" + +namespace socket_manager { + + /** + * Return `PENDING` to interrupt runtime task. + */ + const long PENDING = -1; + + /** + * @brief Waker is used to wake up a pending runtime task. + * + * The implementation of `MsgReceiverAsync::on_message_async` + * can return `PENDING = -1` to interrupt a message receiving + * task in the runtime (e.g. when the caller buffer is full). + * + * And use `waker.wake()` to resume the message receiving task + * when the caller buffer is ready. + * + *

Resource Leak Note

+ * The `Waker` must be properly destroyed to avoid resource leak. + */ + class Waker { + public: + /** + * Call wake() to wake up the receiver process. + */ + void wake(); + + ~Waker(); + + /** + * Create an empty noop waker. + */ + explicit Waker(); + + Waker(const Waker &) = delete; + + Waker &operator=(const Waker &) = delete; + + Waker(Waker &&) noexcept; + + Waker &operator=(Waker &&) noexcept; + + private: + explicit Waker(SOCKET_MANAGER_C_API_CWaker waker); + + friend long::socket_manager_extern_on_msg( + struct SOCKET_MANAGER_C_API_OnMsgObj this_, + SOCKET_MANAGER_C_API_ConnMsg msg, + SOCKET_MANAGER_C_API_CWaker waker, + char **err); + + SOCKET_MANAGER_C_API_CWaker waker; + }; +} // namespace socket_manager + +#endif //SOCKET_MANAGER_WAKER_H diff --git a/include/socket_manager/conn_callback.h b/include/socket_manager/conn_callback.h index 760c9e5..f1a2598 100644 --- a/include/socket_manager/conn_callback.h +++ b/include/socket_manager/conn_callback.h @@ -15,32 +15,12 @@ namespace socket_manager { /** * The callback object for handling connection events. - * - * # Error Handling + *

* Throwing error in the callback will cause the runtime * to abort. * - * # Thread Safety + *

Thread Safety

* All callback methods must be thread safe and non-blocking. - * - * # Note on safety: - * - * - The `connection callback` object should have - * a longer lifetime than the socket manager. - * - * - The `msg receiver` callbacks should all have - * longer lifetimes than the `connection callback`. - * - * - The design stores a shared pointer to the - * `ConnCallback` in `SocketManager`, and shared - * pointers of `Connection`s in the `ConnCallback` - * objects, and store unique pointers of `MsgReceiver` - * in `Connection`. - * - * Thus establish a dependency relationship as follows: - * `SocketManager` -> shared `ConnCallback` -> shared `Connection`s - * -> unique `MsgReceiver`, where the later object has a longer - * lifetime than the former. */ class ConnCallback { public: @@ -49,32 +29,39 @@ namespace socket_manager { private: - friend char* ::socket_manager_extern_on_conn(struct OnConnObj this_, ConnStates conn); + friend void::socket_manager_extern_on_conn( + struct SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates conn, + char **err); /** * Called when a new connection is established. * - * # Error handling + *

Error handling

* Throwing error in `on_connect` callback will close the connection * and a `on_connection_close` callback will be evoked. - * + *

* It should be non-blocking. + *

+ * Drop the returned `MsgSender` to close the connection. * * @param local_addr the local address of the connection. * @param peer_addr the peer address of the connection. - * @param conn a `Connection` object for sending and receiving data. + * @param conn a `Connection` object for starting the connection. + * @param sender a `Sender` object for sending data. */ virtual void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) = 0; + std::shared_ptr conn, + std::shared_ptr sender) = 0; /** * Called when a connection is closed. * - * # Error handling + *

Error handling

* Throwing error in `on_connection_close` callback is logged as error, * but ignored. - * + *

* It should be non-blocking. * * @param local_addr the local address of the connection. @@ -86,10 +73,10 @@ namespace socket_manager { /** * Called when an error occurs when listening on the given address. * - * # Error handling + *

Error handling

* Throwing error in `on_listen_error` callback is logged as error, * but ignored. - * + *

* Should be non-blocking. * * @param addr the address that failed to listen on. @@ -101,10 +88,10 @@ namespace socket_manager { /** * Called when an error occurs when connecting to the given address. * - * # Error handling + *

Error handling

* Throwing error in `on_connect_error` callback is logged as error, * but ignored. - * + *

* Should be non-blocking. * * @param addr the address that failed to connect to. diff --git a/include/socket_manager/connection.h b/include/socket_manager/connection.h index b5c2f22..04229fc 100644 --- a/include/socket_manager/connection.h +++ b/include/socket_manager/connection.h @@ -4,49 +4,52 @@ #include #include #include +#include #include "msg_receiver.h" #include "msg_sender.h" #include "socket_manager_c_api.h" namespace socket_manager { - static unsigned long long DEFAULT_WRITE_FLUSH_MILLI_SEC = 5; // 5 millisecond - static unsigned long long DEFAULT_READ_MSG_FLUSH_MILLI_SEC = 5; // 5 millisecond - static size_t DEFAULT_MSG_BUF_SIZE = 64 * 1024; // 64KB + const unsigned long long DEFAULT_WRITE_FLUSH_MILLI_SEC = 5; // 5 millisecond + const unsigned long long DEFAULT_READ_MSG_FLUSH_MILLI_SEC = 5; // 5 millisecond + const size_t DEFAULT_MSG_BUF_SIZE = 64 * 1024; // 64KB class MsgSender; - class Waker; + class Notifier; + + class NoopNotifier; /** * Use Connection to send and receive messages from * established connections. */ - class Connection : public std::enable_shared_from_this { + class Connection { public: /** * Start a connection. * - * # Start / Close + *

Start / Close

* Exactly one of `start` or `close` should be called! * Calling more than once will throw runtime exception. * Not calling any of them might result in resource leak. * - * # Close started connection + *

Close started connection

* Drop the returned MsgSender object to close the connection * after starting it. * - * # Thread Safety + *

Thread Safety

* Thread safe, but should be called exactly once, * otherwise throws error. * - * To close the connection, drop the returned - * MsgSender object. - * - * @param msg_receiver the message receiver callback to - * receive messages from the peer. + * @param msg_receiver the message receiver callback to receive + * messages from the peer. Non-null. + * @param send_notifier the notifier for getting notified when the + * send buffer is ready. Pass nullptr to use a noop notifier. + * This parameter is needed only for async sending. * @param msg_buffer_size The size of the message buffer in bytes. * Set to 0 to use no buffer (i.e., call `on_msg` immediately on receiving * any data, expecting the user to implement buffer if needed). @@ -60,45 +63,45 @@ namespace socket_manager { * manual flush, and small messages might get stuck in buffer). * Default to 1 millisecond. */ - std::shared_ptr start( - std::unique_ptr msg_receiver, + void start( + std::shared_ptr msg_receiver, + std::shared_ptr send_notifier = nullptr, size_t msg_buffer_size = DEFAULT_MSG_BUF_SIZE, unsigned long long read_msg_flush_interval = DEFAULT_READ_MSG_FLUSH_MILLI_SEC, unsigned long long write_flush_interval = DEFAULT_WRITE_FLUSH_MILLI_SEC); /** * Close the connection without using it. - * + *

* `on_connection_close` callback will be called. * - * # Start / Close + *

Start / Close

* Exactly one of `start` or `close` should be called! * Calling more than once will throw runtime exception. * Not calling any of them might result in resource leak. */ void close(); - Connection(const Connection &) = delete; - - void operator=(const Connection &) = delete; - - ~Connection(); - private: friend class MsgSender; - friend char* ::socket_manager_extern_on_conn(struct OnConnObj this_, ConnStates conn); + friend void::socket_manager_extern_on_conn( + struct SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates conn, + char **err); // keep the msg_receiver alive - std::unique_ptr receiver; + std::shared_ptr receiver; - // keep the waker alive - std::shared_ptr waker; + // keep the notifier alive + std::shared_ptr notifier; - explicit Connection(CConnection *inner); + explicit Connection(SOCKET_MANAGER_C_API_Connection *inner); - CConnection *inner; + std::unique_ptr< + SOCKET_MANAGER_C_API_Connection, + std::function> inner; }; diff --git a/include/socket_manager/msg_receiver.h b/include/socket_manager/msg_receiver.h index af15f45..c121a89 100644 --- a/include/socket_manager/msg_receiver.h +++ b/include/socket_manager/msg_receiver.h @@ -1,46 +1,106 @@ #ifndef SOCKET_MANAGER_MSG_RECEIVER_H #define SOCKET_MANAGER_MSG_RECEIVER_H -#include +#include #include #include #include #include #include "socket_manager_c_api.h" +#include "waker.h" namespace socket_manager { /** * Implement this class to receive messages from Connection. + *

+ * Must read the following details to implement correctly! * - * # Thread Safety + *

Asynchronous Message Receiving

+ * The caller should return the exact number of bytes written + * to the runtime if some bytes are written. The runtime + * will increment the read offset accordingly. + *

+ * If the caller is unable to receive any bytes, + * it should return `PENDING = -1` to the runtime + * to interrupt message receiving task. The read offset + * will not be incremented. + *

+ * When the caller is able to receive bytes again, + * it should call `waker.wake()` to wake up the runtime. + * + *

Thread Safety

* The callback should be thread safe. */ - class MsgReceiver { + class MsgReceiverAsync { public: - virtual ~MsgReceiver() = default; + virtual ~MsgReceiverAsync() = default; private: /** * Called when a message is received. * - * # Thread Safety + *

Asynchronous Message Receiving

+ * The caller should return the exact number of bytes written + * to the runtime if some bytes are written. The runtime + * will increment the read offset accordingly. + *

+ * If the caller is unable to receive any bytes, + * it should return `PENDING = -1` to the runtime + * to interrupt message receiving task. The read offset + * will not be incremented. + *

+ * When the caller is able to receive bytes again, + * it should call `waker.wake()` to wake up the runtime. + * + *

MEMORY SAFETY

+ * The `data` is only valid during the call of this function. + * If you want to keep the data, you should copy it. + * + *

Thread Safety

* This callback must be thread safe. * It should also be non-blocking. * - * # Error Handling - * Throwing error in `on_message` callback will cause + *

Error Handling

+ * Throwing runtime_error in `on_message` callback will cause * the connection to close. * * @param data the message received. */ - virtual void on_message(const std::shared_ptr &data) = 0; + virtual long on_message_async(std::string_view data, Waker &&waker) = 0; + + friend long::socket_manager_extern_on_msg( + struct SOCKET_MANAGER_C_API_OnMsgObj this_, + SOCKET_MANAGER_C_API_ConnMsg msg, + SOCKET_MANAGER_C_API_CWaker waker, + char **err); + + }; + + /** + * If the caller has unlimited buffer implementation, + * it can use this simplified class to receive messages. + *

+ * The caller should implement `on_message` method to + * store the received message in buffer or queue and immediately + * return `on_message` method, and should not block the runtime. + * + *

Thread Safety

+ * This callback must be thread safe. + */ + class MsgReceiver : public MsgReceiverAsync { + public: + + ~MsgReceiver() override = default; + + private: - friend char* ::socket_manager_extern_on_msg(struct OnMsgObj this_, ConnMsg msg); + virtual void on_message(std::string_view data) = 0; + long on_message_async(std::string_view data, Waker &&waker) override; }; } // namespace socket_manager diff --git a/include/socket_manager/msg_sender.h b/include/socket_manager/msg_sender.h index 7d32db9..7702b95 100644 --- a/include/socket_manager/msg_sender.h +++ b/include/socket_manager/msg_sender.h @@ -3,6 +3,8 @@ #include "socket_manager_c_api.h" #include "connection.h" +#include "notifier.h" +#include #include #include #include @@ -11,112 +13,99 @@ namespace socket_manager { class Connection; - /** - * Used for receiving writable notification for - * `try_send` method. - * - * Each `try_test()` call releases the waker, - * when `wake()` is actually invoked - * (i.e., the number of calls of `release` and `clone` - * are equal). - */ - class Waker { - - public: - virtual ~Waker() = default; - - private: - - virtual void wake() = 0; - - virtual void release() = 0; - - virtual void clone() = 0; - - friend void ::socket_manager_extern_sender_waker_wake(struct WakerObj this_); - - friend void ::socket_manager_extern_sender_waker_release(struct WakerObj this_); - - friend void ::socket_manager_extern_sender_waker_clone(struct WakerObj this_); - }; - /** * Use MsgSender to send messages to the peer. - * + *

* Drop the MsgSender object to close the connection. */ class MsgSender { public: - /** - * Send a message to the peer. + * Asynchronous message sending. + *

+ * To use the method, the user should pass a `notifier` object + * when calling `Connection::start()` in order to receive notification + * when the send buffer is ready. * - * # Blocking!! - * This method might block, so it should - * never be used within the callbacks. - * - * # Thread Safety - * This method is thread safe. - * This method does not implement backpressure - * (i.e., it caches all the messages in memory). + *

Async control flow (IMPORTANT)

+ * This function is non-blocking, it returns `PENDING = -1` + * if the send buffer is full. So the caller should wait + * by passing a `Notifier` which will be called when the + * buffer is ready. + *

+ * When the buffer is ready, the function returns number of bytes sent. + *

+ * The caller is responsible for updating the buffer offset!! * - * # Errors - * This method throws std::runtime_error when - * the connection is closed. + * @param data the message to send + * @param notifier `notifier.wake()` is evoked when send_async + * could accept more data. + * @return return the number of bytes successfully sent, + * or return `PENDING = -1` if the send buffer is full. + * @throws std::runtime_error when the connection is closed. + */ + long send_async(std::string_view data); + + /** + * Non-blocking message sending (NO BACKPRESSURE). + *

+ * This method is non-blocking. It returns immediately and + * caches all the data in the internal buffer. So it comes + * without back pressure. * * @param data the message to send + * @param notifier `notifier.wake()` is evoked when send_async + * could accept more data. + * @throws std::runtime_error when the connection is closed. */ - void send(const std::string &data); + void send_nonblock(std::string_view data); /** - * Non blocking message sending. + * Blocking message sending (DO NOT USE IN ASYNC CALLBACK). + * + *

Blocking!!

+ * This method might block, so it should never be used within any of + * the async callbacks. Otherwise the runtime can panic! * - * DO NOT USE THIS METHOD in mixture with `send` method. - * Since send method is blocking, it preserves the order, - * while this method must be used with the waker class. + *

Thread Safety

+ * This method is thread safe. + * + *

Errors

+ * This method throws std::runtime_error when + * the connection is closed. * * @param data the message to send - * @param offset the offset of the message to send. - * That is data[offset..] is the message to send. - * Increment the offset based on the return value. - * @param waker `waker.wake()` is evoked when try_send - * could accept more data. Pass nullptr to disable wake - * notification. - * @return If waker is provided, returns the number of bytes sent on success, - * and 0 on connection closed, -1 on pending. - * If waker is not provided, returns the number of bytes sent. - * 0 might indicate the connection is closed, or the message buffer is full. + * @throws std::runtime_error when the connection is closed. */ - long try_send(const std::string &data, size_t offset, const std::shared_ptr &waker = nullptr); + void send_block(std::string_view data); /** * Manually flush the internal buffer. * - * # Thread Safety + *

Thread Safety

* This method is thread safe. * */ void flush(); - /** - * Drop the sender to close the connection. - */ - ~MsgSender(); - - MsgSender(const MsgSender &) = delete; - - void operator=(const MsgSender &) = delete; - private: friend class Connection; - explicit MsgSender(CMsgSender *inner, const std::shared_ptr&); + friend void::socket_manager_extern_on_conn( + struct SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates conn, + char **err); + + explicit MsgSender(SOCKET_MANAGER_C_API_MsgSender *inner, const std::shared_ptr &); + // keep a reference of connection for storing waker object + // in connection, to prevent dangling pointer of waker. std::shared_ptr conn; - CMsgSender *inner; + std::unique_ptr> inner; }; diff --git a/include/socket_manager/socket_manager.h b/include/socket_manager/socket_manager.h index 7b24a30..5190071 100644 --- a/include/socket_manager/socket_manager.h +++ b/include/socket_manager/socket_manager.h @@ -3,20 +3,58 @@ #include "conn_callback.h" #include "socket_manager_c_api.h" +#include #include #include namespace socket_manager { /** - * Manages a set of sockets. + * @brief Manages a set of sockets. * - * Inherit from SocketManager and implement the virtual methods: - * `on_connect`, `on_connection_close`, `on_listen_error`, `on_connect_error`, - * `on_message` as callbacks. Note that these callbacks shouldn't block. + *

Usage

+ * The user needs to implement a `ConnCallback` object to handle + * connection events, and pass it to the constructor of `SocketManager`. + *

+ * When the connection is established, the `on_connect` callback + * returns a `MsgSender` object for sending messages to the peer, + * and a `Connection` object for receiving messages from the peer. + *

+ * To receive messages from the peer, the user needs to implement + * a `MsgReceiver` object and pass it to the `Connection::start` + * method. * - * Dropping this object will close all the connections and wait for all the - * threads to finish. + *

Memory Management

+ * The system internally use shared pointers to manage the lifetime + * of the objects. Note the following dependency relationship to avoid + * memory leak (i.e., circular reference): + *
    + *
  • `SocketManager` ---strong ref--> `ConnCallback`
  • + *
  • `ConnCallback` ---strong ref--> (active) `Connection`s (drop on `connection_close`)
  • + *
  • `Connection` ---strong ref--> `Notifier`
  • + *
  • `Connection` ---strong ref--> `(Async)MsgReceiver`
  • + *
  • `MsgSender` ---strong ref--> `Connection`
  • + *
+ * Notice that if `MsgSender` is strongly referenced by `Notifier`, + * or strongly referenced by `(Async)MsgReceiver`, then the connection + * will have a memory leak. The user could `reset()` the `shared_ptr\` + * on connection close event, and thus break the cycle. + *

+ * In short, the user must guarantee that the `MsgSender` object + * is released for connection resources to be properly released. + * + *

Note on lifetime:

+ * + *
    + *
  • The `connection callback` object should have + * a longer lifetime than the socket manager.
  • + * + *
  • The `msg receiver` should live as long as + * connection is not closed.
  • + * + *
  • The `Notifier` object should live as long as + * connection is not closed.
  • + *
*/ class SocketManager { @@ -34,10 +72,10 @@ namespace socket_manager { /** * Listen on the given address. * - * # Thread Safety + *

Thread Safety

* Thread safe. * - * # Errors + *

Errors

* Throws `std::runtime_error` if socket manager runtime has been aborted. * Throws `std::runtime_error` if the address is invalid. * @@ -48,10 +86,10 @@ namespace socket_manager { /** * Connect to the given address. * - * # Thread Safety + *

Thread Safety

* Thread safe. * - * # Errors + *

Errors

* Throws `std::runtime_error` if socket manager runtime has been aborted. * Throws `std::runtime_error` if the address is invalid. * @@ -62,10 +100,10 @@ namespace socket_manager { /** * Cancel listening on the given address. * - * # Thread Safety + *

Thread Safety

* Thread safe. * - * # Errors + *

Errors

* Throws `std::runtime_error` if socket manager runtime has been aborted. * Throw `std::runtime_error` if the address is invalid. * @@ -75,17 +113,17 @@ namespace socket_manager { /** * Stop all background threads and drop all connections. - * + *

* Calling a second time will return immediately (if `wait = false`). * - * # Argument + *

Argument

* - `wait`: if true, wait for all the background threads to finish. * Default to true. * - * # Thread Safety + *

Thread Safety

* Thread safe. * - * # Errors + *

Errors

* Throws `std::runtime_error` if `wait = true` and the background * thread panicked. */ @@ -93,26 +131,21 @@ namespace socket_manager { /** * Join and wait on the `SocketManager` background runtime. - * + *

* Returns immediately on the second call. * - * # Thread Safety + *

Thread Safety

* Thread safe. * - * # Errors + *

Errors

* Throws `std::runtime_error` if the background runtime panicked. */ void join(); - ~SocketManager(); - - SocketManager(const SocketManager &) = delete; - - void operator=(const SocketManager &) = delete; - private: - CSocketManager *inner; + std::unique_ptr> inner; std::shared_ptr conn_cb; }; diff --git a/include/socket_manager_c_api.h b/include/socket_manager_c_api.h index 4e9328a..523a100 100644 --- a/include/socket_manager_c_api.h +++ b/include/socket_manager_c_api.h @@ -1,32 +1,73 @@ #ifndef SOCKET_MANAGER_C_API_H #define SOCKET_MANAGER_C_API_H -#include -#include -#include -#include -#include - -typedef enum ConnStateCode { +#include +#include +#include +#include +#include +#include + +enum class SOCKET_MANAGER_C_API_ConnStateCode { Connect = 0, ConnectionClose = 1, ListenError = 2, ConnectError = 3, -} ConnStateCode; +}; -typedef struct CConnection CConnection; +struct SOCKET_MANAGER_C_API_Connection; /** * Drop the sender to close the connection. */ -typedef struct CMsgSender CMsgSender; +struct SOCKET_MANAGER_C_API_MsgSender; /** * The Main Struct of the Library. * * This struct is thread safe. */ -typedef struct CSocketManager CSocketManager; +struct SOCKET_MANAGER_C_API_SocketManager; + +/** + * The Notifier is constructed by the c/c++ code, + * and passed to the rust code. + * + * # Task Resume + * When `wake` callback is called by rust, the c/c++ task + * should resume its execution. + * + * # Lifetime Management. + * The Notifier has `clone` and `release` callbacks. + * Say a Notifier start with ref_count = 1, + * and when `clone` is called, increment its ref_count, + * and when `release` is called, decrement its ref_count. + * + * The notifier can be released when its ref_count falls back to 1. + * + * The c/c++ code must carefully manage the lifetime of the waker. + * to ensure that the waker is not dropped before the rust code + * is done with it. + */ +struct SOCKET_MANAGER_C_API_Notifier { + void *This; +}; + +/** + * # Safety + * Do not use this struct directly. + * Properly wrap it in c++ class. + * + * This struct is equivalent to a raw pointer. + * Manager with care. + * + * Note that the CWaker must be properly dropped. + * Otherwise, the associated task will leak. + */ +struct SOCKET_MANAGER_C_API_CWaker { + const void *Data; + const void *Vtable; +}; /** * Callback function for receiving messages. @@ -44,17 +85,9 @@ typedef struct CSocketManager CSocketManager; * # Thread Safety * Must be thread safe! */ -typedef struct OnMsgObj { +struct SOCKET_MANAGER_C_API_OnMsgObj { void *This; -} OnMsgObj; - -/** - * The data pointer is only valid for the duration of the callback. - */ -typedef struct ConnMsg { - const char *Bytes; - size_t Len; -} ConnMsg; +}; /** * Callback function for connection state changes. @@ -73,37 +106,38 @@ typedef struct ConnMsg { * # Thread Safety * Must be thread safe! */ -typedef struct OnConnObj { +struct SOCKET_MANAGER_C_API_OnConnObj { void *This; -} OnConnObj; +}; -typedef struct OnConnect { +struct SOCKET_MANAGER_C_API_OnConnect { const char *Local; const char *Peer; - struct CConnection *Conn; -} OnConnect; + SOCKET_MANAGER_C_API_MsgSender *Send; + SOCKET_MANAGER_C_API_Connection *Conn; +}; -typedef struct OnConnectionClose { +struct SOCKET_MANAGER_C_API_OnConnectionClose { const char *Local; const char *Peer; -} OnConnectionClose; +}; -typedef struct OnListenError { +struct SOCKET_MANAGER_C_API_OnListenError { const char *Addr; const char *Err; -} OnListenError; +}; -typedef struct OnConnectError { +struct SOCKET_MANAGER_C_API_OnConnectError { const char *Addr; const char *Err; -} OnConnectError; +}; -typedef union ConnStateData { - struct OnConnect OnConnect; - struct OnConnectionClose OnConnectionClose; - struct OnListenError OnListenError; - struct OnConnectError OnConnectError; -} ConnStateData; +union SOCKET_MANAGER_C_API_ConnStateData { + SOCKET_MANAGER_C_API_OnConnect OnConnect; + SOCKET_MANAGER_C_API_OnConnectionClose OnConnectionClose; + SOCKET_MANAGER_C_API_OnListenError OnListenError; + SOCKET_MANAGER_C_API_OnConnectError OnConnectError; +}; /** * All data is only valid for the duration of the callback @@ -111,25 +145,38 @@ typedef union ConnStateData { * * Do not manually free any of the data except `sender`!! */ -typedef struct ConnStates { - enum ConnStateCode Code; - union ConnStateData Data; -} ConnStates; +struct SOCKET_MANAGER_C_API_ConnStates { + SOCKET_MANAGER_C_API_ConnStateCode Code; + SOCKET_MANAGER_C_API_ConnStateData Data; +}; /** - * Send the msg sender obj to receive - * writable notification. + * The data pointer is only valid for the duration of the callback. */ -typedef struct WakerObj { - void *This; -} WakerObj; +struct SOCKET_MANAGER_C_API_ConnMsg { + const char *Bytes; + size_t Len; +}; -#ifdef __cplusplus extern "C" { -#endif // __cplusplus /** - * Start a connection with the given `OnMsgCallback`, and return a pointer to a `CMsgSender`. + * Waker for the try_send method. + */ +extern void socket_manager_extern_notifier_wake(SOCKET_MANAGER_C_API_Notifier this_); + +/** + * Call the waker to wake the relevant task of context. + */ +void socket_manager_waker_wake(const SOCKET_MANAGER_C_API_CWaker *waker); + +/** + * Release the waker. + */ +void socket_manager_waker_free(SOCKET_MANAGER_C_API_CWaker waker); + +/** + * Start a connection with the given `OnMsgCallback`, and return a pointer to a `MsgSender`. * * Only one of `connection_start` or `connection_close` should be called, * or it will result in runtime error. @@ -154,15 +201,12 @@ extern "C" { * Set to 0 to disable auto flush. * * `err` - A pointer to a pointer to a C string allocated by `malloc` on error. * - * # Returns - * A pointer to a `CMsgSender` on success, null on error. - * * # Errors + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`, - * and the returned pointer will be null. */ -struct CMsgSender *connection_start(struct CConnection *conn, - struct OnMsgObj on_msg, +int socket_manager_connection_start(SOCKET_MANAGER_C_API_Connection *conn, + SOCKET_MANAGER_C_API_OnMsgObj on_msg, size_t msg_buffer_size, unsigned long long read_msg_flush_interval, unsigned long long write_flush_interval, @@ -178,84 +222,87 @@ struct CMsgSender *connection_start(struct CConnection *conn, * Thread safe. * * # Errors - * Returns -1 on error, 0 on success. + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -int connection_close(struct CConnection *conn, char **err); +int socket_manager_connection_close(SOCKET_MANAGER_C_API_Connection *conn, char **err); /** * Destructor of `Connection`. */ -void connection_free(struct CConnection *conn); +void socket_manager_connection_free(SOCKET_MANAGER_C_API_Connection *conn); /** - * Callback function for receiving messages. - */ -extern char *socket_manager_extern_on_msg(struct OnMsgObj this_, struct ConnMsg msg); - -/** - * Callback function for connection state changes. - */ -extern char *socket_manager_extern_on_conn(struct OnConnObj this_, struct ConnStates conn); - -/** - * Waker for the try_send method. - */ -extern void socket_manager_extern_sender_waker_wake(struct WakerObj this_); - -/** - * Decrement ref count of the waker. - */ -extern void socket_manager_extern_sender_waker_release(struct WakerObj this_); - -/** - * Increment ref count of the waker. + * Send a message via the given `MsgSender` synchronously. + * This is a blocking API. + * + * # Thread Safety + * Thread safe. + * + * This function should never be called within the context of the async callbacks + * since it might block. + * + * # Errors + * If the connection is closed, the function will return 1 and set `err` to a pointer + * with WriteZero error. + * + * Returns 1 on error, 0 on success. + * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -extern void socket_manager_extern_sender_waker_clone(struct WakerObj this_); +int socket_manager_msg_sender_send_block(SOCKET_MANAGER_C_API_MsgSender *sender, + const char *msg, + size_t len, + char **err); /** - * Send a message via the given `CMsgSender`. + * Send a message via the given `MsgSender` . + * This is a non-blocking API. + * All sent data is buffered in a chain of ring buffer. + * This method does not implement back pressure since it + * caches all received data. + * Use `send_async` or `send_block` for back pressure. * * # Thread Safety * Thread safe. * - * This function should never be called within the context of the async callbacks - * since it might block. + * This function can be called within the context of the async callbacks. * * # Errors - * If the connection is closed, the function will return -1 and set `err` to a pointer + * If the connection is closed, the function will return 1 and set `err` to a pointer * with WriteZero error. * - * Returns -1 on error, 0 on success. + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -int msg_sender_send(struct CMsgSender *sender, const char *msg, size_t len, char **err); +int socket_manager_msg_sender_send_nonblock(SOCKET_MANAGER_C_API_MsgSender *sender, + const char *msg, + size_t len, + char **err); /** - * Try to send a message via the given `CMsgSender`. + * Try to send a message via the given `MsgSender` asynchronously. * * # Thread Safety * Thread safe. * - * This function is non-blocking, pass the MsgSender class - * to the waker_obj to receive notification to continue - * sending the message. + * # Async control flow (IMPORTANT) * - * # Return - * If waker is provided, returns the number of bytes sent on success, - * and 0 on connection closed, -1 on pending. + * This function is non-blocking, it returns `PENDING = -1` + * if the send buffer is full. So the caller should wait + * by passing a `Notifier` which will be called when the + * buffer is ready. * - * If waker is not provided, returns the number of bytes sent. - * 0 might indicate the connection is closed, or the message buffer is full. + * When the buffer is ready, the function returns number of bytes sent. * * # Errors + * Use `err` pointer to check for error. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -long msg_sender_try_send(struct CMsgSender *sender, - const char *msg, - size_t len, - struct WakerObj waker_obj, - char **err); +long socket_manager_msg_sender_send_async(SOCKET_MANAGER_C_API_MsgSender *sender, + const char *msg, + size_t len, + SOCKET_MANAGER_C_API_Notifier notifier, + char **err); /** * Manually flush the message sender. @@ -264,16 +311,51 @@ long msg_sender_try_send(struct CMsgSender *sender, * Thread safe. * * # Errors - * Returns -1 on error, 0 on success. + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -int msg_sender_flush(struct CMsgSender *sender, char **err); +int socket_manager_msg_sender_flush(SOCKET_MANAGER_C_API_MsgSender *sender, char **err); /** * Destructor of `MsgSender`. * Drop sender to actively close the connection. */ -void msg_sender_free(struct CMsgSender *sender); +void socket_manager_msg_sender_free(SOCKET_MANAGER_C_API_MsgSender *sender); + +/** + * Rust calls this function to send `conn: ConnStates` + * to the `this: OnConnObj`. If the process has any error, + * pass error to `err` pointer. + * Set `err` to null_ptr if there is no error. + */ +extern void socket_manager_extern_on_conn(SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates conn, + char **err); + +/** + * Rust calls this function to send `msg: ConnMsg` + * to `OnMsgObj`. If the process has any error, + * pass error to `err` pointer. + * Set `err` to null_ptr if there is no error. + * + * # Async control flow (IMPORTANT) + * + * The caller should return the exact number of bytes written + * to the runtime if some bytes are written. The runtime + * will increment the read offset accordingly. + * + * If the caller is unable to receive any bytes, + * it should return `PENDING = -1` to the runtime + * to interrupt message receiving task. The read offset + * will not be incremented. + * + * When the caller is able to receive bytes again, + * it should call `waker.wake()` to wake up the runtime. + */ +extern long socket_manager_extern_on_msg(SOCKET_MANAGER_C_API_OnMsgObj this_, + SOCKET_MANAGER_C_API_ConnMsg msg, + SOCKET_MANAGER_C_API_CWaker waker, + char **err); /** * Initialize a new `SocketManager` and return a pointer to it. @@ -296,7 +378,9 @@ void msg_sender_free(struct CMsgSender *sender); * On Error, `err` will be set to a pointer to a C string allocated by `malloc`, * and the returned pointer will be null. */ -struct CSocketManager *socket_manager_init(struct OnConnObj on_conn, size_t n_threads, char **err); +SOCKET_MANAGER_C_API_SocketManager *socket_manager_init(SOCKET_MANAGER_C_API_OnConnObj on_conn, + size_t n_threads, + char **err); /** * Listen on the given address. @@ -305,10 +389,12 @@ struct CSocketManager *socket_manager_init(struct OnConnObj on_conn, size_t n_th * Thread safe. * * # Errors - * Returns -1 on error, 0 on success. + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -int socket_manager_listen_on_addr(struct CSocketManager *manager, const char *addr, char **err); +int socket_manager_listen_on_addr(SOCKET_MANAGER_C_API_SocketManager *manager, + const char *addr, + char **err); /** * Connect to the given address. @@ -317,10 +403,12 @@ int socket_manager_listen_on_addr(struct CSocketManager *manager, const char *ad * Thread safe. * * # Errors - * Returns -1 on error, 0 on success. + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -int socket_manager_connect_to_addr(struct CSocketManager *manager, const char *addr, char **err); +int socket_manager_connect_to_addr(SOCKET_MANAGER_C_API_SocketManager *manager, + const char *addr, + char **err); /** * Cancel listening on the given address. @@ -329,10 +417,10 @@ int socket_manager_connect_to_addr(struct CSocketManager *manager, const char *a * Thread safe. * * # Errors - * Returns -1 on error, 0 on success. + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -int socket_manager_cancel_listen_on_addr(struct CSocketManager *manager, +int socket_manager_cancel_listen_on_addr(SOCKET_MANAGER_C_API_SocketManager *manager, const char *addr, char **err); @@ -346,10 +434,10 @@ int socket_manager_cancel_listen_on_addr(struct CSocketManager *manager, * - `wait`: if true, wait for the background runtime to finish. * * # Errors - * Returns -1 on error, 0 on success. + * Returns 1 on error, 0 on success. * On Error, `err` will be set to a pointer to a C string allocated by `malloc`. */ -int socket_manager_abort(struct CSocketManager *manager, bool wait, char **err); +int socket_manager_abort(SOCKET_MANAGER_C_API_SocketManager *manager, bool wait, char **err); /** * Join and wait on the `SocketManager`. @@ -363,16 +451,14 @@ int socket_manager_abort(struct CSocketManager *manager, bool wait, char **err); * # Errors * Join returns error if the runtime panicked. */ -int socket_manager_join(struct CSocketManager *manager, char **err); +int socket_manager_join(SOCKET_MANAGER_C_API_SocketManager *manager, char **err); /** * Calling this function will abort all background runtime and join on them, * and free the `SocketManager`. */ -void socket_manager_free(struct CSocketManager *manager); +void socket_manager_free(SOCKET_MANAGER_C_API_SocketManager *manager); -#ifdef __cplusplus } // extern "C" -#endif // __cplusplus -#endif /* SOCKET_MANAGER_C_API_H */ +#endif // SOCKET_MANAGER_C_API_H diff --git a/justfile b/justfile index 9eb86d0..52c5c3a 100644 --- a/justfile +++ b/justfile @@ -1,6 +1,12 @@ cbind: cbindgen -q --config cbindgen.toml --crate tokio-socket-manager --output include/socket_manager_c_api.h +fmt: + cargo fmt + +clippy: + cargo clippy + clean: rm -rf build diff --git a/socket_manager/CMakeLists.txt b/socket_manager/CMakeLists.txt index 415735b..ac9325d 100644 --- a/socket_manager/CMakeLists.txt +++ b/socket_manager/CMakeLists.txt @@ -30,16 +30,22 @@ set(header ${CMAKE_SOURCE_DIR}/include/socket_manager_c_api.h ${header_path}/connection.h ${header_path}/conn_callback.h ${header_path}/msg_sender.h + ${header_path}/common/waker.h + ${header_path}/common/notifier.h ${header_path}/msg_receiver.h) set(src socket_manager_c_api.cc msg_sender.cc connection.cc + waker.cc + msg_receiver.cc socket_manager.cc) target_sources(${PROJECT_NAME} PRIVATE ${src}) target_include_directories(${PROJECT_NAME} PUBLIC $ + $ + $ $) set_target_properties(${PROJECT_NAME} PROPERTIES diff --git a/socket_manager/connection.cc b/socket_manager/connection.cc index c034b11..6ffdc85 100644 --- a/socket_manager/connection.cc +++ b/socket_manager/connection.cc @@ -3,44 +3,49 @@ namespace socket_manager { - Connection::Connection(CConnection *inner) : inner(inner) {} - - std::shared_ptr Connection::start( - std::unique_ptr msg_receiver, + Connection::Connection(SOCKET_MANAGER_C_API_Connection *inner) + : notifier(std::make_shared()), + inner(inner, + [](SOCKET_MANAGER_C_API_Connection *ptr) { + socket_manager_connection_free(ptr); + }) {} + + void Connection::start( + std::shared_ptr msg_receiver, + std::shared_ptr send_notifier, size_t msg_buffer_size, unsigned long long read_msg_flush_interval, unsigned long long write_flush_interval) { + if (msg_receiver == nullptr) { + throw std::runtime_error("msg_receiver should not be nullptr"); + } + // keep the msg_receiver alive. + this->receiver = std::move(msg_receiver); + // keep the notifier alive. + if (send_notifier != nullptr) { + this->notifier = std::move(send_notifier); + } + // start the connection. // calling twice `connection_start` will throw exception. char *err = nullptr; - CMsgSender *sender = connection_start(inner, OnMsgObj{ - msg_receiver.get(), - }, msg_buffer_size, read_msg_flush_interval, write_flush_interval, &err); - if (sender == nullptr) { + if (socket_manager_connection_start(inner.get(), SOCKET_MANAGER_C_API_OnMsgObj{ + this->receiver.get(), + }, msg_buffer_size, read_msg_flush_interval, write_flush_interval, &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); } - - // keep the msg_receiver alive. - receiver = std::move(msg_receiver); - - // return the sender - return std::shared_ptr(new MsgSender(sender, shared_from_this())); } void Connection::close() { char *err = nullptr; - if (connection_close(inner, &err)) { + if (socket_manager_connection_close(inner.get(), &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); } } - Connection::~Connection() { - connection_free(inner); - } - } // namespace socket_manager diff --git a/socket_manager/msg_receiver.cc b/socket_manager/msg_receiver.cc new file mode 100644 index 0000000..bcd2434 --- /dev/null +++ b/socket_manager/msg_receiver.cc @@ -0,0 +1,8 @@ +#include "socket_manager/msg_receiver.h" + +namespace socket_manager { + long MsgReceiver::on_message_async(std::string_view data, Waker &&waker) { + on_message(data); + return (long) data.length(); + } +} diff --git a/socket_manager/msg_sender.cc b/socket_manager/msg_sender.cc index 6664abc..3abf5da 100644 --- a/socket_manager/msg_sender.cc +++ b/socket_manager/msg_sender.cc @@ -3,51 +3,56 @@ namespace socket_manager { - void MsgSender::send(const std::string &data) { + void MsgSender::send_block(std::string_view data) { char *err = nullptr; - if (msg_sender_send(inner, data.data(), data.length(), &err)) { + if (socket_manager_msg_sender_send_block( + inner.get(), data.data(), data.length(), &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); } } - long MsgSender::try_send(const std::string &data, size_t offset, const std::shared_ptr &waker) { - // check length - if (offset >= data.length()) { - throw std::runtime_error("offset >= data.length()"); + void MsgSender::send_nonblock(std::string_view data) { + char *err = nullptr; + if (socket_manager_msg_sender_send_nonblock( + inner.get(), data.data(), data.length(), &err)) { + const std::string err_str(err); + free(err); + throw std::runtime_error(err_str); } + } + + long MsgSender::send_async(std::string_view data) { char *err = nullptr; - long n = msg_sender_try_send( - inner, - data.data() + offset, - data.length() - offset, - WakerObj{waker.get()}, + long n = socket_manager_msg_sender_send_async( + inner.get(), + data.data(), + data.length(), + SOCKET_MANAGER_C_API_Notifier{conn->notifier.get()}, &err); if (err) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); } - // keep waker alive - conn->waker = waker; return n; } void MsgSender::flush() { char *err = nullptr; - if (msg_sender_flush(inner, &err)) { + if (socket_manager_msg_sender_flush(inner.get(), &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); } } - MsgSender::MsgSender(CMsgSender *inner, const std::shared_ptr &conn) - : conn(conn), inner(inner) {} - - MsgSender::~MsgSender() { - msg_sender_free(inner); - } + MsgSender::MsgSender(SOCKET_MANAGER_C_API_MsgSender *inner, const std::shared_ptr &conn) + : conn(conn), + inner(inner, + [](SOCKET_MANAGER_C_API_MsgSender *ptr) { + socket_manager_msg_sender_free(ptr); + }) {} } // namespace socket_manager diff --git a/socket_manager/socket_manager.cc b/socket_manager/socket_manager.cc index c9a328f..341f1f9 100644 --- a/socket_manager/socket_manager.cc +++ b/socket_manager/socket_manager.cc @@ -2,12 +2,17 @@ namespace socket_manager { - SocketManager::SocketManager(const std::shared_ptr &conn_cb, size_t n_threads) : conn_cb( - conn_cb) { + SocketManager::SocketManager(const std::shared_ptr &conn_cb, size_t n_threads) + : conn_cb(conn_cb) { char *err = nullptr; - inner = socket_manager_init(OnConnObj{ + auto inner_ptr = socket_manager_init(SOCKET_MANAGER_C_API_OnConnObj{ conn_cb.get() }, n_threads, &err); + inner = std::unique_ptr>( + inner_ptr, + [](SOCKET_MANAGER_C_API_SocketManager *ptr) { socket_manager_free(ptr); } + ); if (err) { const std::string err_str(err); free(err); @@ -17,7 +22,7 @@ namespace socket_manager { void SocketManager::listen_on_addr(const std::string &addr) { char *err = nullptr; - if (socket_manager_listen_on_addr(inner, addr.c_str(), &err)) { + if (socket_manager_listen_on_addr(inner.get(), addr.c_str(), &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); @@ -26,7 +31,7 @@ namespace socket_manager { void SocketManager::connect_to_addr(const std::string &addr) { char *err = nullptr; - if (socket_manager_connect_to_addr(inner, addr.c_str(), &err)) { + if (socket_manager_connect_to_addr(inner.get(), addr.c_str(), &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); @@ -35,7 +40,7 @@ namespace socket_manager { void SocketManager::cancel_listen_on_addr(const std::string &addr) { char *err = nullptr; - if (socket_manager_cancel_listen_on_addr(inner, addr.c_str(), &err)) { + if (socket_manager_cancel_listen_on_addr(inner.get(), addr.c_str(), &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); @@ -44,7 +49,7 @@ namespace socket_manager { void SocketManager::abort(bool wait) { char *err = nullptr; - if (socket_manager_abort(inner, wait, &err)) { + if (socket_manager_abort(inner.get(), wait, &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); @@ -53,15 +58,11 @@ namespace socket_manager { void SocketManager::join() { char *err = nullptr; - if (socket_manager_join(inner, &err)) { + if (socket_manager_join(inner.get(), &err)) { const std::string err_str(err); free(err); throw std::runtime_error(err_str); } } - SocketManager::~SocketManager() { - socket_manager_free(inner); - } - } // namespace socket_manager diff --git a/socket_manager/socket_manager_c_api.cc b/socket_manager/socket_manager_c_api.cc index b9c1462..441f611 100644 --- a/socket_manager/socket_manager_c_api.cc +++ b/socket_manager/socket_manager_c_api.cc @@ -1,6 +1,7 @@ #include "socket_manager_c_api.h" #include "socket_manager/msg_receiver.h" #include "socket_manager/conn_callback.h" +#include "socket_manager/common/waker.h" inline static char *string_dup(const std::string &str) { @@ -11,46 +12,48 @@ static char *string_dup(const std::string &str) { } /** - * Waker for the sender. + * RecvWaker for the sender. */ -extern void socket_manager_extern_sender_waker_wake(struct WakerObj this_) { - auto wr = reinterpret_cast(this_.This); +extern "C" void socket_manager_extern_notifier_wake(SOCKET_MANAGER_C_API_Notifier this_) { + auto wr = reinterpret_cast(this_.This); wr->wake(); } -extern void socket_manager_extern_sender_waker_release(struct WakerObj this_) { - auto wr = reinterpret_cast(this_.This); - wr->release(); -} - -extern void socket_manager_extern_sender_waker_clone(struct WakerObj this_) { - auto wr = reinterpret_cast(this_.This); - wr->clone(); -} - -extern char *socket_manager_extern_on_msg(struct OnMsgObj this_, ConnMsg msg) { - auto receiver = reinterpret_cast(this_.This); - auto data_ptr = std::make_shared(msg.Bytes, msg.Len); +extern "C" long socket_manager_extern_on_msg(SOCKET_MANAGER_C_API_OnMsgObj this_, + SOCKET_MANAGER_C_API_ConnMsg msg, + SOCKET_MANAGER_C_API_CWaker waker, + char **err) { + auto receiver = reinterpret_cast(this_.This); try { - receiver->on_message(data_ptr); + auto recv = receiver->on_message_async( + std::string_view(msg.Bytes, msg.Len), + socket_manager::Waker(waker) + ); + *err = nullptr; + return recv; } catch (std::runtime_error &e) { - return string_dup(e.what()); + *err = string_dup(e.what()); + return 0; } catch (...) { - return string_dup("unknown error"); + *err = string_dup("unknown error"); + return 0; } - return nullptr; } -extern char *socket_manager_extern_on_conn(struct OnConnObj this_, ConnStates states) { +extern "C" void socket_manager_extern_on_conn( + SOCKET_MANAGER_C_API_OnConnObj this_, + SOCKET_MANAGER_C_API_ConnStates states, + char **error) { auto conn_cb = static_cast(this_.This); switch (states.Code) { - case ConnStateCode::Connect: { + case SOCKET_MANAGER_C_API_ConnStateCode::Connect: { auto on_connect = states.Data.OnConnect; auto local_addr = std::string(on_connect.Local); auto peer_addr = std::string(on_connect.Peer); std::shared_ptr conn(new socket_manager::Connection(on_connect.Conn)); + std::shared_ptr sender(new socket_manager::MsgSender(on_connect.Send, conn)); // keep the connection alive { @@ -58,15 +61,16 @@ extern char *socket_manager_extern_on_conn(struct OnConnObj this_, ConnStates st conn_cb->conns[local_addr + peer_addr] = conn; } try { - conn_cb->on_connect(local_addr, peer_addr, conn); + conn_cb->on_connect(local_addr, peer_addr, std::move(conn), std::move(sender)); + *error = nullptr; } catch (std::runtime_error &e) { - return string_dup(e.what()); + *error = string_dup(e.what()); } catch (...) { - return string_dup("unknown error"); + *error = string_dup("unknown error"); } - return nullptr; + break; } - case ConnStateCode::ConnectionClose: { + case SOCKET_MANAGER_C_API_ConnStateCode::ConnectionClose: { auto on_connection_close = states.Data.OnConnectionClose; auto local_addr = std::string(on_connection_close.Local); auto peer_addr = std::string(on_connection_close.Peer); @@ -78,42 +82,45 @@ extern char *socket_manager_extern_on_conn(struct OnConnObj this_, ConnStates st } try { conn_cb->on_connection_close(local_addr, peer_addr); + *error = nullptr; } catch (std::runtime_error &e) { - return string_dup(e.what()); + *error = string_dup(e.what()); } catch (...) { - return string_dup("unknown error"); + *error = string_dup("unknown error"); } - return nullptr; + break; } - case ConnStateCode::ListenError: { + case SOCKET_MANAGER_C_API_ConnStateCode::ListenError: { auto listen_error = states.Data.OnListenError; auto addr = std::string(listen_error.Addr); auto err = std::string(listen_error.Err); try { conn_cb->on_listen_error(addr, err); + *error = nullptr; } catch (std::runtime_error &e) { - return string_dup(e.what()); + *error = string_dup(e.what()); } catch (...) { - return string_dup("unknown error"); + *error = string_dup("unknown error"); } - return nullptr; + break; } - case ConnStateCode::ConnectError: { + case SOCKET_MANAGER_C_API_ConnStateCode::ConnectError: { auto connect_error = states.Data.OnConnectError; auto addr = std::string(connect_error.Addr); auto err = std::string(connect_error.Err); try { conn_cb->on_connect_error(addr, err); + *error = nullptr; } catch (std::runtime_error &e) { - return string_dup(e.what()); + *error = string_dup(e.what()); } catch (...) { - return string_dup("unknown error"); + *error = string_dup("unknown error"); } - return nullptr; + break; } default: { // should never reach here - return nullptr; + *error = nullptr; } } } diff --git a/socket_manager/waker.cc b/socket_manager/waker.cc new file mode 100644 index 0000000..2b4bc8f --- /dev/null +++ b/socket_manager/waker.cc @@ -0,0 +1,34 @@ +#include "socket_manager/common/waker.h" + +namespace socket_manager { + + Waker::Waker() + : waker(SOCKET_MANAGER_C_API_CWaker{nullptr, nullptr}) {} + + Waker::Waker(SOCKET_MANAGER_C_API_CWaker waker) : waker(waker) {} + + Waker::Waker(Waker &&other) noexcept: waker(other.waker) { + other.waker.Data = nullptr; + other.waker.Vtable = nullptr; + } + + Waker &Waker::operator=(Waker &&other) noexcept { + waker = other.waker; + other.waker.Data = nullptr; + other.waker.Vtable = nullptr; + return *this; + } + + void Waker::wake() { + if (waker.Data != nullptr && waker.Vtable != nullptr) { + socket_manager_waker_wake(&waker); + } + } + + Waker::~Waker() { + if (waker.Data != nullptr && waker.Vtable != nullptr) { + socket_manager_waker_free(waker); + } + } + +} // namespace socket_manager diff --git a/src/c_api/async_ffi/mod.rs b/src/c_api/async_ffi/mod.rs new file mode 100644 index 0000000..b465819 --- /dev/null +++ b/src/c_api/async_ffi/mod.rs @@ -0,0 +1,4 @@ +//! Defines modules for bridging +//! c/c++ and rust async code. +pub(crate) mod notifier; +pub(crate) mod waker; diff --git a/src/c_api/async_ffi/notifier.rs b/src/c_api/async_ffi/notifier.rs new file mode 100644 index 0000000..d610afd --- /dev/null +++ b/src/c_api/async_ffi/notifier.rs @@ -0,0 +1,63 @@ +//! A notifier is a c/c++ callback object called by rust +//! to notify c/c++ code that they should resume certain tasks. +use std::ffi::c_void; +use std::task::{RawWaker, RawWakerVTable, Waker}; + +/// The Notifier is constructed by the c/c++ code, +/// and passed to the rust code. +/// +/// # Task Resume +/// When `wake` callback is called by rust, the c/c++ task +/// should resume its execution. +/// +/// # Lifetime Management. +/// The Notifier has `clone` and `release` callbacks. +/// Say a Notifier start with ref_count = 1, +/// and when `clone` is called, increment its ref_count, +/// and when `release` is called, decrement its ref_count. +/// +/// The notifier can be released when its ref_count falls back to 1. +/// +/// The c/c++ code must carefully manage the lifetime of the waker. +/// to ensure that the waker is not dropped before the rust code +/// is done with it. +#[repr(C)] +#[derive(Copy, Clone)] +pub struct Notifier { + pub(crate) this: *mut c_void, +} + +#[link(name = "socket_manager")] +extern "C" { + /// Waker for the try_send method. + pub(crate) fn socket_manager_extern_notifier_wake(this: Notifier); +} + +impl Notifier { + #[inline] + pub(crate) unsafe fn to_waker(self) -> Waker { + const MSG_SENDER_WAKER_VTABLE: RawWakerVTable = make_vtable(); + + const fn make_vtable() -> RawWakerVTable { + RawWakerVTable::new( + |dat| RawWaker::new(dat, &MSG_SENDER_WAKER_VTABLE), + |dat| unsafe { + let this = Notifier { + this: dat as *mut c_void, + }; + socket_manager_extern_notifier_wake(this); + }, + |dat| unsafe { + let this = Notifier { + this: dat as *mut c_void, + }; + socket_manager_extern_notifier_wake(this); + }, + |_| {}, + ) + } + + let raw_waker = RawWaker::new(self.this as *const (), &MSG_SENDER_WAKER_VTABLE); + Waker::from_raw(raw_waker) + } +} diff --git a/src/c_api/async_ffi/waker.rs b/src/c_api/async_ffi/waker.rs new file mode 100644 index 0000000..3db2f5b --- /dev/null +++ b/src/c_api/async_ffi/waker.rs @@ -0,0 +1,69 @@ +//! A CWaker is a C compatible version of `std::task::Waker`, +//! that is used for c/c++ code to wake rust tasks. +use std::ffi::c_void; +use std::task::{RawWaker, RawWakerVTable, Waker}; + +/// # Safety +/// Do not use this struct directly. +/// Properly wrap it in c++ class. +/// +/// This struct is equivalent to a raw pointer. +/// Manager with care. +/// +/// Note that the CWaker must be properly dropped. +/// Otherwise, the associated task will leak. +#[repr(C)] +pub struct CWaker { + data: *const c_void, + vtable: *const c_void, +} + +/// Call the waker to wake the relevant task of context. +#[no_mangle] +pub unsafe extern "C" fn socket_manager_waker_wake(waker: &CWaker) { + waker.wake_by_ref(); +} + +/// Release the waker. +#[no_mangle] +pub unsafe extern "C" fn socket_manager_waker_free(waker: CWaker) { + drop(waker.into_waker()); +} + +impl CWaker { + /// Take ownership of the waker. + #[inline] + pub(crate) fn from_waker(waker: Waker) -> Self { + let raw_waker = waker.as_raw(); + let c_waker = Self { + data: raw_waker.data() as *const c_void, + vtable: raw_waker.vtable() as *const RawWakerVTable as *const c_void, + }; + // do not drop the waker. + std::mem::forget(waker); + c_waker + } + + /// Do Not restore ownership of the waker. + #[inline] + unsafe fn wake_by_ref(&self) { + let raw_waker = RawWaker::new( + self.data as *const (), + &*(self.vtable as *const RawWakerVTable), + ); + let waker = Waker::from_raw(raw_waker); + waker.wake_by_ref(); + // do not drop the waker. + std::mem::forget(waker); + } + + /// Restore ownership of the waker. + #[inline] + unsafe fn into_waker(self) -> Waker { + let raw_waker = RawWaker::new( + self.data as *const (), + &*(self.vtable as *const RawWakerVTable), + ); + Waker::from_raw(raw_waker) + } +} diff --git a/src/c_api/structs.rs b/src/c_api/conn_events.rs similarity index 90% rename from src/c_api/structs.rs rename to src/c_api/conn_events.rs index cedca29..d9f1c88 100644 --- a/src/c_api/structs.rs +++ b/src/c_api/conn_events.rs @@ -1,4 +1,5 @@ -use crate::c_api::callbacks::OnMsgObj; +use crate::c_api::on_msg::OnMsgObj; +use crate::MsgSender; use libc::size_t; use std::ffi::c_char; @@ -40,7 +41,8 @@ pub union ConnStateData { pub struct OnConnect { pub(crate) local: *const c_char, pub(crate) peer: *const c_char, - pub(crate) conn: *mut CConnection, + pub(crate) send: *mut MsgSender, + pub(crate) conn: *mut Connection, } #[repr(C)] @@ -64,6 +66,6 @@ pub struct OnConnectError { pub(crate) err: *const c_char, } -pub struct CConnection { +pub struct Connection { pub(crate) conn: crate::Conn, } diff --git a/src/c_api/connection.rs b/src/c_api/connection.rs index e02cc3a..a67f7db 100644 --- a/src/c_api/connection.rs +++ b/src/c_api/connection.rs @@ -1,8 +1,7 @@ -use crate::c_api::callbacks::OnMsgObj; -use crate::c_api::structs::CConnection; +use crate::c_api::conn_events::Connection; +use crate::c_api::on_msg::OnMsgObj; use crate::c_api::utils::write_error_c_str; use crate::conn::ConnConfig; -use crate::msg_sender::CMsgSender; use libc::size_t; use std::ffi::{c_char, c_int}; use std::num::NonZeroUsize; @@ -10,7 +9,7 @@ use std::os::raw::c_ulonglong; use std::ptr::null_mut; use std::time::Duration; -/// Start a connection with the given `OnMsgCallback`, and return a pointer to a `CMsgSender`. +/// Start a connection with the given `OnMsgCallback`, and return a pointer to a `MsgSender`. /// /// Only one of `connection_start` or `connection_close` should be called, /// or it will result in runtime error. @@ -35,32 +34,21 @@ use std::time::Duration; /// Set to 0 to disable auto flush. /// * `err` - A pointer to a pointer to a C string allocated by `malloc` on error. /// -/// # Returns -/// A pointer to a `CMsgSender` on success, null on error. -/// /// # Errors +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`, -/// and the returned pointer will be null. #[no_mangle] -pub unsafe extern "C" fn connection_start( - conn: *mut CConnection, +pub unsafe extern "C" fn socket_manager_connection_start( + conn: *mut Connection, on_msg: OnMsgObj, msg_buffer_size: size_t, read_msg_flush_interval: c_ulonglong, write_flush_interval: c_ulonglong, err: *mut *mut c_char, -) -> *mut CMsgSender { +) -> c_int { let conn = &mut (*conn).conn; - let write_flush_interval = if write_flush_interval == 0 { - None - } else { - Some(Duration::from_millis(write_flush_interval)) - }; - let read_msg_flush_interval = if read_msg_flush_interval == 0 { - None - } else { - Some(Duration::from_millis(read_msg_flush_interval)) - }; + let write_flush_interval = Duration::from_millis(write_flush_interval); + let read_msg_flush_interval = Duration::from_millis(read_msg_flush_interval); let msg_buffer_size = NonZeroUsize::new(msg_buffer_size); match conn.start_connection( on_msg, @@ -70,13 +58,13 @@ pub unsafe extern "C" fn connection_start( msg_buffer_size, }, ) { - Ok(sender) => { + Ok(_) => { *err = null_mut(); - Box::into_raw(Box::new(sender)) + 0 } Err(e) => { write_error_c_str(e, err); - null_mut() + 1 } } } @@ -90,10 +78,13 @@ pub unsafe extern "C" fn connection_start( /// Thread safe. /// /// # Errors -/// Returns -1 on error, 0 on success. +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] -pub unsafe extern "C" fn connection_close(conn: *mut CConnection, err: *mut *mut c_char) -> c_int { +pub unsafe extern "C" fn socket_manager_connection_close( + conn: *mut Connection, + err: *mut *mut c_char, +) -> c_int { let conn = &mut (*conn).conn; match conn.close() { Ok(_) => { @@ -102,13 +93,13 @@ pub unsafe extern "C" fn connection_close(conn: *mut CConnection, err: *mut *mut } Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } /// Destructor of `Connection`. #[no_mangle] -pub unsafe extern "C" fn connection_free(conn: *mut CConnection) { +pub unsafe extern "C" fn socket_manager_connection_free(conn: *mut Connection) { drop(Box::from_raw(conn)) } diff --git a/src/c_api/ffi.rs b/src/c_api/ffi.rs deleted file mode 100644 index 14f46b2..0000000 --- a/src/c_api/ffi.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::c_api::callbacks::{OnConnObj, OnMsgObj, WakerObj}; -use crate::c_api::structs::{ConnMsg, ConnStates}; -use std::ffi::c_char; - -#[link(name = "socket_manager")] -extern "C" { - /// Callback function for receiving messages. - pub(crate) fn socket_manager_extern_on_msg(this: OnMsgObj, msg: ConnMsg) -> *mut c_char; - - /// Callback function for connection state changes. - pub(crate) fn socket_manager_extern_on_conn(this: OnConnObj, conn: ConnStates) -> *mut c_char; - - /// Waker for the try_send method. - pub(crate) fn socket_manager_extern_sender_waker_wake(this: WakerObj); - - /// Decrement ref count of the waker. - pub(crate) fn socket_manager_extern_sender_waker_release(this: WakerObj); - - /// Increment ref count of the waker. - pub(crate) fn socket_manager_extern_sender_waker_clone(this: WakerObj); -} diff --git a/src/c_api/mod.rs b/src/c_api/mod.rs index 30b68b4..09d0318 100644 --- a/src/c_api/mod.rs +++ b/src/c_api/mod.rs @@ -1,7 +1,8 @@ -pub(crate) mod callbacks; +pub(crate) mod async_ffi; +mod conn_events; mod connection; -pub(crate) mod ffi; mod msg_sender; +pub(crate) mod on_conn; +pub(crate) mod on_msg; mod socket_manager; -mod structs; mod utils; diff --git a/src/c_api/msg_sender.rs b/src/c_api/msg_sender.rs index 1f93198..31d0714 100644 --- a/src/c_api/msg_sender.rs +++ b/src/c_api/msg_sender.rs @@ -1,11 +1,15 @@ -use crate::c_api::callbacks::WakerObj; +use crate::c_api::async_ffi::notifier::Notifier; use crate::c_api::utils::write_error_c_str; -use crate::msg_sender::CMsgSender; +use crate::msg_sender::MsgSender; use libc::size_t; use std::ffi::{c_char, c_int, c_long}; use std::ptr::null_mut; +use std::task::Poll; -/// Send a message via the given `CMsgSender`. +pub const PENDING: c_long = -1; + +/// Send a message via the given `MsgSender` synchronously. +/// This is a blocking API. /// /// # Thread Safety /// Thread safe. @@ -14,14 +18,14 @@ use std::ptr::null_mut; /// since it might block. /// /// # Errors -/// If the connection is closed, the function will return -1 and set `err` to a pointer +/// If the connection is closed, the function will return 1 and set `err` to a pointer /// with WriteZero error. /// -/// Returns -1 on error, 0 on success. +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] -pub unsafe extern "C" fn msg_sender_send( - sender: *mut CMsgSender, +pub unsafe extern "C" fn socket_manager_msg_sender_send_block( + sender: *mut MsgSender, msg: *const c_char, len: size_t, err: *mut *mut c_char, @@ -35,53 +39,90 @@ pub unsafe extern "C" fn msg_sender_send( } Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } -/// Try to send a message via the given `CMsgSender`. +/// Send a message via the given `MsgSender` . +/// This is a non-blocking API. +/// All sent data is buffered in a chain of ring buffer. +/// This method does not implement back pressure since it +/// caches all received data. +/// Use `send_async` or `send_block` for back pressure. /// /// # Thread Safety /// Thread safe. /// -/// This function is non-blocking, pass the MsgSender class -/// to the waker_obj to receive notification to continue -/// sending the message. +/// This function can be called within the context of the async callbacks. /// -/// # Return -/// If waker is provided, returns the number of bytes sent on success, -/// and 0 on connection closed, -1 on pending. +/// # Errors +/// If the connection is closed, the function will return 1 and set `err` to a pointer +/// with WriteZero error. /// -/// If waker is not provided, returns the number of bytes sent. -/// 0 might indicate the connection is closed, or the message buffer is full. +/// Returns 1 on error, 0 on success. +/// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. +#[no_mangle] +pub unsafe extern "C" fn socket_manager_msg_sender_send_nonblock( + sender: *mut MsgSender, + msg: *const c_char, + len: size_t, + err: *mut *mut c_char, +) -> c_int { + let sender = &mut (*sender); + let msg = std::slice::from_raw_parts(msg as *const u8, len); + match sender.send_nonblock(msg) { + Ok(_) => { + *err = null_mut(); + 0 + } + Err(e) => { + write_error_c_str(e, err); + 1 + } + } +} + +/// Try to send a message via the given `MsgSender` asynchronously. +/// +/// # Thread Safety +/// Thread safe. +/// +/// # Async control flow (IMPORTANT) +/// +/// This function is non-blocking, it returns `PENDING = -1` +/// if the send buffer is full. So the caller should wait +/// by passing a `Notifier` which will be called when the +/// buffer is ready. +/// +/// When the buffer is ready, the function returns number of bytes sent. /// /// # Errors +/// Use `err` pointer to check for error. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] -pub unsafe extern "C" fn msg_sender_try_send( - sender: *mut CMsgSender, +pub unsafe extern "C" fn socket_manager_msg_sender_send_async( + sender: *mut MsgSender, msg: *const c_char, len: size_t, - waker_obj: WakerObj, + notifier: Notifier, err: *mut *mut c_char, ) -> c_long { let sender = &mut (*sender); let msg = std::slice::from_raw_parts(msg as *const u8, len); - let waker_obj = if waker_obj.this.is_null() { - None - } else { - Some(waker_obj) - }; - match sender.try_send(msg, waker_obj) { - Ok(n) => { + // notifier should not be null + assert!(!notifier.this.is_null()); + let waker = notifier.to_waker(); + match sender.send_async(msg, waker) { + Poll::Ready(Ok(n)) => { *err = null_mut(); - n + n as c_long } - Err(e) => { + Poll::Ready(Err(e)) => { write_error_c_str(e, err); - 0 + 0 as c_long } + Poll::Pending => PENDING, } } @@ -91,10 +132,13 @@ pub unsafe extern "C" fn msg_sender_try_send( /// Thread safe. /// /// # Errors -/// Returns -1 on error, 0 on success. +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] -pub unsafe extern "C" fn msg_sender_flush(sender: *mut CMsgSender, err: *mut *mut c_char) -> c_int { +pub unsafe extern "C" fn socket_manager_msg_sender_flush( + sender: *mut MsgSender, + err: *mut *mut c_char, +) -> c_int { let sender = &mut (*sender); match sender.flush() { Ok(_) => { @@ -103,7 +147,7 @@ pub unsafe extern "C" fn msg_sender_flush(sender: *mut CMsgSender, err: *mut *mu } Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } @@ -111,6 +155,6 @@ pub unsafe extern "C" fn msg_sender_flush(sender: *mut CMsgSender, err: *mut *mu /// Destructor of `MsgSender`. /// Drop sender to actively close the connection. #[no_mangle] -pub unsafe extern "C" fn msg_sender_free(sender: *mut CMsgSender) { +pub unsafe extern "C" fn socket_manager_msg_sender_free(sender: *mut MsgSender) { drop(Box::from_raw(sender)) } diff --git a/src/c_api/callbacks.rs b/src/c_api/on_conn.rs similarity index 57% rename from src/c_api/callbacks.rs rename to src/c_api/on_conn.rs index 6eb3d11..52d20df 100644 --- a/src/c_api/callbacks.rs +++ b/src/c_api/on_conn.rs @@ -1,81 +1,11 @@ -use crate::c_api::ffi::{ - socket_manager_extern_on_conn, socket_manager_extern_on_msg, - socket_manager_extern_sender_waker_clone, socket_manager_extern_sender_waker_release, - socket_manager_extern_sender_waker_wake, -}; -use crate::c_api::structs::{ - CConnection, ConnMsg, ConnStateCode, ConnStateData, ConnStates, OnConnect, OnConnectError, +use crate::c_api::conn_events::{ + ConnStateCode, ConnStateData, ConnStates, Connection, OnConnect, OnConnectError, OnConnectionClose, OnListenError, }; +use crate::c_api::on_msg::OnMsgObj; use crate::c_api::utils::parse_c_err_str; use std::ffi::{c_char, c_void, CString}; -use std::task::{RawWaker, RawWakerVTable, Waker}; - -const MSG_SENDER_VTABLE: RawWakerVTable = WakerObj::make_vtable(); - -/// Send the msg sender obj to receive -/// writable notification. -#[repr(C)] -#[derive(Copy, Clone)] -pub struct WakerObj { - pub(crate) this: *mut c_void, -} - -impl WakerObj { - pub(crate) unsafe fn make_waker(&self) -> Waker { - let raw_waker = RawWaker::new(self.this as *const (), &MSG_SENDER_VTABLE); - // Increment the ref count since a new waker is created. - socket_manager_extern_sender_waker_clone(*self); - Waker::from_raw(raw_waker) - } - - const fn make_vtable() -> RawWakerVTable { - RawWakerVTable::new( - |dat| unsafe { - let this = dat as *mut c_void; - let msg_obj = WakerObj { this }; - socket_manager_extern_sender_waker_clone(msg_obj); - RawWaker::new(dat, &MSG_SENDER_VTABLE) - }, - |dat| unsafe { - let this = dat as *mut c_void; - let msg_obj = WakerObj { this }; - socket_manager_extern_sender_waker_wake(msg_obj); - socket_manager_extern_sender_waker_release(msg_obj); - }, - |dat| unsafe { - let this = dat as *mut c_void; - let msg_obj = WakerObj { this }; - socket_manager_extern_sender_waker_wake(msg_obj); - }, - |dat| unsafe { - let this = dat as *mut c_void; - let msg_obj = WakerObj { this }; - socket_manager_extern_sender_waker_release(msg_obj); - }, - ) - } -} - -/// Callback function for receiving messages. -/// -/// `callback_self` is feed to the first argument of the callback. -/// -/// # Error Handling -/// Returns null_ptr on success, otherwise returns a pointer to a malloced -/// C string containing the error message (the c string should be freed by the -/// caller). -/// -/// # Safety -/// The callback pointer must be valid before connection is closed!! -/// -/// # Thread Safety -/// Must be thread safe! -#[repr(C)] -#[derive(Copy, Clone)] -pub struct OnMsgObj { - this: *mut c_void, -} +use std::ptr::null_mut; /// Callback function for connection state changes. /// @@ -98,66 +28,45 @@ pub struct OnConnObj { this: *mut c_void, } -impl OnMsgObj { - pub fn call_inner(&self, conn_msg: crate::Msg<'_>) -> Result<(), String> { - let conn_msg = ConnMsg { - bytes: conn_msg.bytes.as_ptr() as *const c_char, - len: conn_msg.bytes.len(), - }; - unsafe { - let cb_result = socket_manager_extern_on_msg(*self, conn_msg); - if let Err(e) = parse_c_err_str(cb_result) { - tracing::error!("Error thrown in OnMsg callback: {e}"); - Err(e) - } else { - Ok(()) - } - } - } -} - -impl FnMut<(crate::Msg<'_>,)> for OnMsgObj { - extern "rust-call" fn call_mut(&mut self, conn_msg: (crate::Msg<'_>,)) -> Self::Output { - self.call_inner(conn_msg.0) - } -} - -impl FnOnce<(crate::Msg<'_>,)> for OnMsgObj { - type Output = Result<(), String>; - - extern "rust-call" fn call_once(self, conn_msg: (crate::Msg<'_>,)) -> Self::Output { - self.call_inner(conn_msg.0) - } -} - -impl Fn<(crate::Msg<'_>,)> for OnMsgObj { - extern "rust-call" fn call(&self, conn_msg: (crate::Msg<'_>,)) -> Self::Output { - self.call_inner(conn_msg.0) - } +#[link(name = "socket_manager")] +extern "C" { + /// Rust calls this function to send `conn: ConnStates` + /// to the `this: OnConnObj`. If the process has any error, + /// pass error to `err` pointer. + /// Set `err` to null_ptr if there is no error. + pub(crate) fn socket_manager_extern_on_conn( + this: OnConnObj, + conn: ConnStates, + err: *mut *mut c_char, + ); } impl OnConnObj { /// connection callback pub(crate) fn call_inner(&self, conn_states: crate::ConnState) -> Result<(), String> { let on_conn = |conn| unsafe { - let cb_result = socket_manager_extern_on_conn(*self, conn); - parse_c_err_str(cb_result) + let mut err: *mut c_char = null_mut(); + socket_manager_extern_on_conn(*self, conn, &mut err); + parse_c_err_str(err) }; match conn_states { crate::ConnState::OnConnect { local_addr, peer_addr, + send, conn, } => { let local = CString::new(local_addr.to_string()).unwrap(); let peer = CString::new(peer_addr.to_string()).unwrap(); - let conn = Box::into_raw(Box::new(CConnection { conn })); + let send = Box::into_raw(Box::new(send)); + let conn = Box::into_raw(Box::new(Connection { conn })); let conn_msg = ConnStates { code: ConnStateCode::Connect, data: ConnStateData { on_connect: OnConnect { local: local.as_ptr(), peer: peer.as_ptr(), + send, conn, }, }, @@ -258,7 +167,6 @@ impl Fn<(crate::ConnState,)> for OnConnObj { } } -unsafe impl Send for OnMsgObj {} -unsafe impl Sync for OnMsgObj {} unsafe impl Send for OnConnObj {} + unsafe impl Sync for OnConnObj {} diff --git a/src/c_api/on_msg.rs b/src/c_api/on_msg.rs new file mode 100644 index 0000000..e51899e --- /dev/null +++ b/src/c_api/on_msg.rs @@ -0,0 +1,105 @@ +use crate::c_api::async_ffi::waker::CWaker; +use crate::c_api::conn_events::ConnMsg; +use crate::c_api::utils::parse_c_err_str; +use std::ffi::{c_char, c_long, c_void}; +use std::ptr::null_mut; +use std::task::{Poll, Waker}; + +/// Callback function for receiving messages. +/// +/// `callback_self` is feed to the first argument of the callback. +/// +/// # Error Handling +/// Returns null_ptr on success, otherwise returns a pointer to a malloced +/// C string containing the error message (the c string should be freed by the +/// caller). +/// +/// # Safety +/// The callback pointer must be valid before connection is closed!! +/// +/// # Thread Safety +/// Must be thread safe! +#[repr(C)] +#[derive(Copy, Clone)] +pub struct OnMsgObj { + this: *mut c_void, +} + +#[link(name = "socket_manager")] +extern "C" { + /// Rust calls this function to send `msg: ConnMsg` + /// to `OnMsgObj`. If the process has any error, + /// pass error to `err` pointer. + /// Set `err` to null_ptr if there is no error. + /// + /// # Async control flow (IMPORTANT) + /// + /// The caller should return the exact number of bytes written + /// to the runtime if some bytes are written. The runtime + /// will increment the read offset accordingly. + /// + /// If the caller is unable to receive any bytes, + /// it should return `PENDING = -1` to the runtime + /// to interrupt message receiving task. The read offset + /// will not be incremented. + /// + /// When the caller is able to receive bytes again, + /// it should call `waker.wake()` to wake up the runtime. + pub(crate) fn socket_manager_extern_on_msg( + this: OnMsgObj, + msg: ConnMsg, + waker: CWaker, + err: *mut *mut c_char, + ) -> c_long; +} + +impl OnMsgObj { + pub fn call_inner( + &self, + conn_msg: crate::Msg<'_>, + waker: Waker, + ) -> Poll> { + let bytes = conn_msg.bytes.as_ptr() as *const c_char; + let len = conn_msg.bytes.len(); + let conn_msg = ConnMsg { bytes, len }; + // takes the ownership of the waker + let waker = CWaker::from_waker(waker); + unsafe { + let mut err: *mut c_char = null_mut(); + let cb_result = socket_manager_extern_on_msg(*self, conn_msg, waker, &mut err); + if let Err(e) = parse_c_err_str(err) { + tracing::error!("Error thrown in OnMsg callback: {e}"); + Poll::Ready(Err(e)) + } else if cb_result > 0 { + assert!(cb_result <= len as c_long); + Poll::Ready(Ok(cb_result as usize)) + } else { + Poll::Pending + } + } + } +} + +impl FnMut<(crate::Msg<'_>, Waker)> for OnMsgObj { + extern "rust-call" fn call_mut(&mut self, args: (crate::Msg<'_>, Waker)) -> Self::Output { + self.call_inner(args.0, args.1) + } +} + +impl FnOnce<(crate::Msg<'_>, Waker)> for OnMsgObj { + type Output = Poll>; + + extern "rust-call" fn call_once(self, args: (crate::Msg<'_>, Waker)) -> Self::Output { + self.call_inner(args.0, args.1) + } +} + +impl Fn<(crate::Msg<'_>, Waker)> for OnMsgObj { + extern "rust-call" fn call(&self, args: (crate::Msg<'_>, Waker)) -> Self::Output { + self.call_inner(args.0, args.1) + } +} + +unsafe impl Send for OnMsgObj {} + +unsafe impl Sync for OnMsgObj {} diff --git a/src/c_api/socket_manager.rs b/src/c_api/socket_manager.rs index 383d23d..496d314 100644 --- a/src/c_api/socket_manager.rs +++ b/src/c_api/socket_manager.rs @@ -1,6 +1,6 @@ -use crate::c_api::callbacks::OnConnObj; +use crate::c_api::on_conn::OnConnObj; use crate::c_api::utils::{socket_addr, write_error_c_str}; -use crate::CSocketManager; +use crate::SocketManager; use libc::size_t; use std::ffi::{c_char, c_int}; use std::ptr::null_mut; @@ -29,8 +29,8 @@ pub unsafe extern "C" fn socket_manager_init( on_conn: OnConnObj, n_threads: size_t, err: *mut *mut c_char, -) -> *mut CSocketManager { - match CSocketManager::init(on_conn, n_threads) { +) -> *mut SocketManager { + match SocketManager::init(on_conn, n_threads) { Ok(manager) => { *err = null_mut(); Box::into_raw(Box::new(manager)) @@ -48,11 +48,11 @@ pub unsafe extern "C" fn socket_manager_init( /// Thread safe. /// /// # Errors -/// Returns -1 on error, 0 on success. +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] pub unsafe extern "C" fn socket_manager_listen_on_addr( - manager: *mut CSocketManager, + manager: *mut SocketManager, addr: *const c_char, err: *mut *mut c_char, ) -> c_int { @@ -65,12 +65,12 @@ pub unsafe extern "C" fn socket_manager_listen_on_addr( } Err(e) => { write_error_c_str(e, err); - -1 + 1 } }, Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } @@ -81,11 +81,11 @@ pub unsafe extern "C" fn socket_manager_listen_on_addr( /// Thread safe. /// /// # Errors -/// Returns -1 on error, 0 on success. +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] pub unsafe extern "C" fn socket_manager_connect_to_addr( - manager: *mut CSocketManager, + manager: *mut SocketManager, addr: *const c_char, err: *mut *mut c_char, ) -> c_int { @@ -98,12 +98,12 @@ pub unsafe extern "C" fn socket_manager_connect_to_addr( } Err(e) => { write_error_c_str(e, err); - -1 + 1 } }, Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } @@ -114,11 +114,11 @@ pub unsafe extern "C" fn socket_manager_connect_to_addr( /// Thread safe. /// /// # Errors -/// Returns -1 on error, 0 on success. +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] pub unsafe extern "C" fn socket_manager_cancel_listen_on_addr( - manager: *mut CSocketManager, + manager: *mut SocketManager, addr: *const c_char, err: *mut *mut c_char, ) -> c_int { @@ -131,12 +131,12 @@ pub unsafe extern "C" fn socket_manager_cancel_listen_on_addr( } Err(e) => { write_error_c_str(e, err); - -1 + 1 } }, Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } @@ -150,11 +150,11 @@ pub unsafe extern "C" fn socket_manager_cancel_listen_on_addr( /// - `wait`: if true, wait for the background runtime to finish. /// /// # Errors -/// Returns -1 on error, 0 on success. +/// Returns 1 on error, 0 on success. /// On Error, `err` will be set to a pointer to a C string allocated by `malloc`. #[no_mangle] pub unsafe extern "C" fn socket_manager_abort( - manager: *mut CSocketManager, + manager: *mut SocketManager, wait: bool, err: *mut *mut c_char, ) -> c_int { @@ -166,7 +166,7 @@ pub unsafe extern "C" fn socket_manager_abort( } Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } @@ -183,7 +183,7 @@ pub unsafe extern "C" fn socket_manager_abort( /// Join returns error if the runtime panicked. #[no_mangle] pub unsafe extern "C" fn socket_manager_join( - manager: *mut CSocketManager, + manager: *mut SocketManager, err: *mut *mut c_char, ) -> c_int { let manager = &mut *manager; @@ -194,7 +194,7 @@ pub unsafe extern "C" fn socket_manager_join( } Err(e) => { write_error_c_str(e, err); - -1 + 1 } } } @@ -202,6 +202,6 @@ pub unsafe extern "C" fn socket_manager_join( /// Calling this function will abort all background runtime and join on them, /// and free the `SocketManager`. #[no_mangle] -pub unsafe extern "C" fn socket_manager_free(manager: *mut CSocketManager) { +pub unsafe extern "C" fn socket_manager_free(manager: *mut SocketManager) { drop(Box::from_raw(manager)) } diff --git a/src/conn.rs b/src/conn.rs index 317f2ee..653e90e 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -1,7 +1,7 @@ -use crate::msg_sender::CMsgSender; use crate::Msg; use std::num::NonZeroUsize; use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Poll, Waker}; use std::time::Duration; use tokio::sync::oneshot; @@ -13,20 +13,13 @@ pub struct Conn { struct ConnInner { conn_config_setter: oneshot::Sender<(OnMsg, ConnConfig)>, - send: CMsgSender, } impl Conn { - pub(crate) fn new( - conn_config_setter: oneshot::Sender<(OnMsg, ConnConfig)>, - send: CMsgSender, - ) -> Self { + pub(crate) fn new(conn_config_setter: oneshot::Sender<(OnMsg, ConnConfig)>) -> Self { Self { consumed: AtomicBool::new(false), - inner: Some(ConnInner { - conn_config_setter, - send, - }), + inner: Some(ConnInner { conn_config_setter }), } } } @@ -34,18 +27,18 @@ impl Conn { /// Connection configuration #[derive(Copy, Clone)] pub struct ConnConfig { - pub write_flush_interval: Option, - pub read_msg_flush_interval: Option, + /// zero represent no auto flush + pub write_flush_interval: Duration, + /// zero represent no auto flush + pub read_msg_flush_interval: Duration, pub msg_buffer_size: Option, } -impl) -> Result<(), String> + Send + 'static + Clone> Conn { +impl, Waker) -> Poll> + Send + 'static + Clone> + Conn +{ /// This function should be called only once. - pub fn start_connection( - &mut self, - on_msg: OnMsg, - config: ConnConfig, - ) -> std::io::Result { + pub fn start_connection(&mut self, on_msg: OnMsg, config: ConnConfig) -> std::io::Result<()> { self.consumed .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) .map_err(|_| { @@ -64,10 +57,10 @@ impl) -> Result<(), String> + Send + 'static + Clone> Conn std::io::Result<()> { self.consumed .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) diff --git a/src/conn_handle.rs b/src/conn_handle.rs index dad9c24..0d310e4 100644 --- a/src/conn_handle.rs +++ b/src/conn_handle.rs @@ -1,20 +1,19 @@ use crate::conn::{Conn, ConnConfig}; -use crate::msg_sender::{CMsgSender, SendCommand, RING_BUFFER_SIZE}; +use crate::msg_sender::make_sender; use crate::{read, write, ConnState, ConnectionState, Msg}; -use async_ringbuf::AsyncHeapRb; use futures::FutureExt; use std::net::SocketAddr; use std::sync::Arc; +use std::task::{Poll, Waker}; use tokio::net::TcpStream; use tokio::runtime::Handle; -use tokio::sync::mpsc::unbounded_channel; use tokio::sync::oneshot; use tokio::task::JoinHandle; /// This function handles connection from a client. pub(crate) fn handle_connection< OnConn: Fn(ConnState) -> Result<(), String> + Send + 'static + Clone, - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Unpin + 'static, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Unpin + 'static, >( local_addr: SocketAddr, peer_addr: SocketAddr, @@ -23,14 +22,8 @@ pub(crate) fn handle_connection< on_conn: OnConn, connection_state: Arc, ) { - let (send, recv) = unbounded_channel::(); let (conn_config_setter, conn_config) = oneshot::channel::<(OnMsg, ConnConfig)>(); - let (buf_prd, ring_buf) = AsyncHeapRb::new(RING_BUFFER_SIZE).split(); - let send = CMsgSender { - cmd: send, - buf_prd, - handle: handle.clone(), - }; + let (send, recv) = make_sender(handle.clone()); let on_conn_clone = on_conn.clone(); // Call `on_conn` callback, and wait for user to call `start` on connection. @@ -39,7 +32,8 @@ pub(crate) fn handle_connection< on_conn(ConnState::OnConnect { local_addr, peer_addr, - conn: Conn::new(conn_config_setter, send), + send, + conn: Conn::new(conn_config_setter), }) .map_err(|_| ())?; @@ -60,7 +54,7 @@ pub(crate) fn handle_connection< // spawn reader and writer let (stop, stopper) = oneshot::channel::<()>(); let (read, write) = stream.into_split(); - let writer = handle.spawn(write::handle_writer(write, recv, ring_buf, config, stop)); + let writer = handle.spawn(write::handle_writer(write, recv, config, stop)); let reader = handle.spawn(read::handle_reader(read, on_msg, config)); // insert the stopper into connection_state diff --git a/src/lib.rs b/src/lib.rs index 0cc9fa6..5dfca9a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![feature(unboxed_closures)] #![feature(fn_traits)] +#![feature(waker_getters)] #![allow(improper_ctypes)] mod c_api; @@ -14,6 +15,7 @@ use dashmap::DashMap; use std::net::SocketAddr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use std::task::{Poll, Waker}; use std::time::Duration; use tokio::net::{TcpListener, TcpStream}; use tokio::runtime::Handle; @@ -30,7 +32,7 @@ const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); /// The Main Struct of the Library. /// /// This struct is thread safe. -pub struct CSocketManager { +pub struct SocketManager { cmd_send: UnboundedSender, // use has_joined to fence the join_handle, @@ -53,6 +55,7 @@ pub enum ConnState { OnConnect { local_addr: SocketAddr, peer_addr: SocketAddr, + send: MsgSender, conn: Conn, }, /// sent on connection closed @@ -104,7 +107,7 @@ pub struct Msg<'a> { bytes: &'a [u8], } -impl CSocketManager { +impl SocketManager { /// start background threads to run the runtime /// /// # Arguments @@ -115,11 +118,11 @@ impl CSocketManager { /// - \>1: use the specified number of threads. pub fn init< OnConn: Fn(ConnState) -> Result<(), String> + Send + Sync + 'static + Clone, - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Sync + Unpin + 'static + Clone, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Sync + Unpin + 'static + Clone, >( on_conn: OnConn, n_threads: usize, - ) -> std::io::Result { + ) -> std::io::Result { let _ = tracing_subscriber::fmt() .with_env_filter( EnvFilter::builder() @@ -137,7 +140,7 @@ impl CSocketManager { runtime.shutdown_timeout(SHUTDOWN_TIMEOUT); tracing::info!("socket_manager stopped"); })); - Ok(CSocketManager { + Ok(SocketManager { cmd_send, has_joined: AtomicBool::new(false), join_handle, @@ -198,7 +201,7 @@ impl CSocketManager { } } -impl Drop for CSocketManager { +impl Drop for SocketManager { fn drop(&mut self) { let _ = self.abort(true); } @@ -207,7 +210,7 @@ impl Drop for CSocketManager { /// The main loop running in the background. async fn main< OnConn: Fn(ConnState) -> Result<(), String> + Send + Sync + 'static + Clone, - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Sync + Unpin + 'static + Clone, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Sync + Unpin + 'static + Clone, >( mut cmd_recv: UnboundedReceiver, handle: &Handle, @@ -236,9 +239,21 @@ async fn main< } /// This function connects to a port. +/// +/// The design of the function guarantees that either `OnConnect` or `OnConnectError` +/// will be called, but not both. And after `OnConnect` is called, `OnConnectionClose` +/// is guaranteed to be called. +/// +/// The follow diagram shows the possible state transitions: +/// +/// ```text +/// connect_command --> either --> OnConnect --> OnConnectionClose +/// | +/// --> OnConnectError +/// ``` fn connect_to_addr< OnConn: Fn(ConnState) -> Result<(), String> + Send + 'static + Clone, - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Unpin + 'static, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Unpin + 'static, >( handle: &Handle, addr: SocketAddr, @@ -281,7 +296,7 @@ fn connect_to_addr< /// This function listens on a port. fn listen_on_addr< OnConn: Fn(ConnState) -> Result<(), String> + Send + Sync + 'static + Clone, - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Sync + Unpin + 'static + Clone, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Sync + Unpin + 'static + Clone, >( handle: &Handle, addr: SocketAddr, @@ -324,7 +339,7 @@ fn listen_on_addr< async fn accept_connections< OnConn: Fn(ConnState) -> Result<(), String> + Send + 'static + Clone, - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Unpin + 'static + Clone, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Unpin + 'static + Clone, >( addr: SocketAddr, listener: TcpListener, diff --git a/src/msg_sender.rs b/src/msg_sender.rs index f9aef48..f4be8f7 100644 --- a/src/msg_sender.rs +++ b/src/msg_sender.rs @@ -1,23 +1,45 @@ -use crate::c_api::callbacks::WakerObj; -use async_ringbuf::AsyncHeapProducer; +use async_ringbuf::{AsyncHeapConsumer, AsyncHeapProducer, AsyncHeapRb}; use std::future::Future; use std::pin::pin; use std::task::Poll::Ready; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, Waker}; use tokio::runtime::Handle; -use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; -pub const RING_BUFFER_SIZE: usize = 256 * 1024; // 256KB +/// 256KB ring buffer. +pub const RING_BUFFER_SIZE: usize = 256 * 1024; /// Sender Commands other than bytes. pub(crate) enum SendCommand { Flush, } +pub(crate) fn make_sender(handle: Handle) -> (MsgSender, MsgRcv) { + let (cmd, cmd_recv) = unbounded_channel::(); + let (rings_prd, rings) = unbounded_channel::>(); + let (ring_buf, ring) = AsyncHeapRb::::new(RING_BUFFER_SIZE).split(); + rings_prd.send(ring).unwrap(); + ( + MsgSender { + cmd, + ring_buf, + rings_prd, + handle, + }, + MsgRcv { cmd_recv, rings }, + ) +} + +pub(crate) struct MsgRcv { + pub(crate) cmd_recv: UnboundedReceiver, + pub(crate) rings: UnboundedReceiver>, +} + /// Drop the sender to close the connection. -pub struct CMsgSender { +pub struct MsgSender { pub(crate) cmd: UnboundedSender, - pub(crate) buf_prd: AsyncHeapProducer, + pub(crate) ring_buf: AsyncHeapProducer, + pub(crate) rings_prd: UnboundedSender>, pub(crate) handle: Handle, } @@ -46,29 +68,38 @@ fn burst_write( } } -impl CMsgSender { - /// The blocking API of sending bytes. +impl MsgSender { + /// The blocking API for sending bytes. /// Do not use this method in the callback (i.e. async context), /// as it might block. pub fn send_block(&mut self, bytes: &[u8]) -> std::io::Result<()> { + if bytes.is_empty() { + return Ok(()); + } + if self.ring_buf.is_closed() { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "connection closed", + )); + } let mut offset = 0usize; // attempt to write the entire message without blocking - if let BurstWriteState::Finished = burst_write(&mut offset, &mut self.buf_prd, bytes) { + if let BurstWriteState::Finished = burst_write(&mut offset, &mut self.ring_buf, bytes) { return Ok(()); } // unfinished, enter into future self.handle.clone().block_on(async { loop { - self.buf_prd.wait_free(RING_BUFFER_SIZE / 4).await; + self.ring_buf.wait_free(1).await; // check if closed - if self.buf_prd.is_closed() { + if self.ring_buf.is_closed() { break Err(std::io::Error::new( std::io::ErrorKind::Other, "connection closed", )); } if let BurstWriteState::Finished = - burst_write(&mut offset, &mut self.buf_prd, bytes) + burst_write(&mut offset, &mut self.ring_buf, bytes) { return Ok(()); } @@ -77,46 +108,65 @@ impl CMsgSender { Ok(()) } - /// Try sending bytes. + /// The non-blocking API for sending bytes. /// - /// Returning -1 to indicate pending. - pub fn try_send(&mut self, bytes: &[u8], waker_obj: Option) -> std::io::Result { + /// This API does not implement back pressure. + /// It caches all received bytes in memory + /// (efficiently using a chain of ring buffers). + pub fn send_nonblock(&mut self, bytes: &[u8]) -> std::io::Result<()> { if bytes.is_empty() { - return Ok(0); + return Ok(()); } - if self.buf_prd.is_closed() { + if self.ring_buf.is_closed() { return Err(std::io::Error::new( std::io::ErrorKind::WriteZero, "connection closed", )); } - let n = self.buf_prd.as_mut_base().push_slice(bytes); - if n > 0 { - return Ok(n as i64); - } - // n = 0, not closed - if let Some(waker_obj) = waker_obj { - // some waker, wait on the waker - let waker = unsafe { waker_obj.make_waker() }; - match pin!(self.buf_prd.wait_free(RING_BUFFER_SIZE / 4)) - .poll(&mut Context::from_waker(&waker)) - { - Ready(_) => { - // might be ready on closed - if self.buf_prd.is_closed() { - return Err(std::io::Error::new( - std::io::ErrorKind::WriteZero, - "connection closed", - )); - } - let n = self.buf_prd.as_mut_base().push_slice(bytes); - Ok(n as i64) - } - Poll::Pending => Ok(-1), + let mut offset = 0usize; + // loop until Finished writing the entire message. + loop { + if let BurstWriteState::Finished = burst_write(&mut offset, &mut self.ring_buf, bytes) { + break; } + // allocate new ring buffer if unable to write the entire message. + let remaining = bytes.len() - offset; + let new_buf_size = RING_BUFFER_SIZE.max(remaining); + let (ring_buf, ring) = AsyncHeapRb::::new(new_buf_size).split(); + self.rings_prd.send(ring).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::WriteZero, + format!("connection closed: {e}"), + ) + })?; + // set new ring_buf + self.ring_buf = ring_buf; + } + Ok(()) + } + + /// Try sending bytes (the async api). + /// + /// Unless the buffer is empty, it shouldn't return 0. + pub fn send_async(&mut self, bytes: &[u8], waker: Waker) -> Poll> { + if bytes.is_empty() { + return Ready(Ok(0)); + } + if self.ring_buf.is_closed() { + return Ready(Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "connection closed", + ))); + } + let mut offset = 0usize; + // attempt to write the entire message without blocking + burst_write(&mut offset, &mut self.ring_buf, bytes); + if offset > 0 { + Ready(Ok(offset)) } else { - // no waker, return 0 - Ok(0) + // buffer full nothing written, enter into future + let _ = pin!(self.ring_buf.wait_free(1)).poll(&mut Context::from_waker(&waker)); + Poll::Pending } } diff --git a/src/read.rs b/src/read.rs index ae55e22..70a0985 100644 --- a/src/read.rs +++ b/src/read.rs @@ -2,18 +2,18 @@ use crate::conn::ConnConfig; use crate::Msg; use std::pin::Pin; use std::task::Poll::Ready; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll, Waker}; use std::time::Duration; use tokio::io::{AsyncBufReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}; use tokio::net::tcp::OwnedReadHalf; use tokio::time::MissedTickBehavior; -pub const MIN_MSG_BUFFER_SIZE: usize = 1024 * 8; // 8KB -pub const MAX_MSG_BUFFER_SIZE: usize = 1024 * 1024 * 8; // 8MB +pub const MIN_MSG_BUFFER_SIZE: usize = 1024 * 8; +pub const MAX_MSG_BUFFER_SIZE: usize = 1024 * 1024 * 8; /// Receive bytes ReadHalf of TcpStream and call `on_msg` callback. pub(crate) async fn handle_reader< - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Unpin + 'static, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Unpin + 'static, >( read: OwnedReadHalf, on_msg: OnMsg, @@ -24,15 +24,11 @@ pub(crate) async fn handle_reader< .get() .min(MAX_MSG_BUFFER_SIZE) .max(MIN_MSG_BUFFER_SIZE); - match config.read_msg_flush_interval { - None => handle_reader_no_auto_flush(read, on_msg, msg_buf_size).await, - Some(duration) => { - if duration.is_zero() { - handle_reader_no_auto_flush(read, on_msg, msg_buf_size).await - } else { - handle_reader_auto_flush(read, on_msg, duration, msg_buf_size).await - } - } + let duration = config.read_msg_flush_interval; + if duration.is_zero() { + handle_reader_no_auto_flush(read, on_msg, msg_buf_size).await + } else { + handle_reader_auto_flush(read, on_msg, duration, msg_buf_size).await } } else { handle_reader_no_buf(read, on_msg).await @@ -41,7 +37,7 @@ pub(crate) async fn handle_reader< /// Has write buffer with auto flush. async fn handle_reader_auto_flush< - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Unpin + 'static, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Unpin + 'static, >( read: OwnedReadHalf, on_msg: OnMsg, @@ -88,7 +84,7 @@ async fn handle_reader_auto_flush< /// Has write buffer but no auto flush (small messages might get stuck). /// Call `OnMsg` on batch. This is not recommended. async fn handle_reader_no_auto_flush< - OnMsg: Fn(Msg<'_>) -> Result<(), String> + Send + Unpin + 'static, + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Unpin + 'static, >( read: OwnedReadHalf, on_msg: OnMsg, @@ -105,8 +101,8 @@ async fn handle_reader_no_auto_flush< if n == 0 { break; } - tracing::trace!("received {n} bytes"); msg_writer.write_all(bytes).await?; + tracing::trace!("received {n} bytes"); buf_reader.consume(n); } @@ -116,12 +112,15 @@ async fn handle_reader_no_auto_flush< } /// Has no write buffer, received is sent immediately. -async fn handle_reader_no_buf) -> Result<(), String> + Send + Unpin + 'static>( +async fn handle_reader_no_buf< + OnMsg: Fn(Msg<'_>, Waker) -> Poll> + Send + Unpin + 'static, +>( read: OwnedReadHalf, on_msg: OnMsg, ) -> std::io::Result<()> { let recv_buffer_size = socket2::SockRef::from(read.as_ref()).recv_buffer_size()?; tracing::trace!("recv buffer size: {}", recv_buffer_size); + let mut on_msg = OnMsgWrite { on_msg }; let mut buf_reader = BufReader::with_capacity(recv_buffer_size, read); loop { let bytes = buf_reader.fill_buf().await?; @@ -129,10 +128,8 @@ async fn handle_reader_no_buf) -> Result<(), String> + Send + if n == 0 { break; } + on_msg.write_all(bytes).await?; tracing::trace!("received {n} bytes"); - if let Err(e) = on_msg(Msg { bytes }) { - return Err(std::io::Error::new(std::io::ErrorKind::Other, e)); - } buf_reader.consume(n); } Ok(()) @@ -143,18 +140,22 @@ struct OnMsgWrite { on_msg: OnMsg, } -impl) -> Result<(), String> + Send + 'static> AsyncWrite for OnMsgWrite { +impl, Waker) -> Poll> + Send + 'static> AsyncWrite + for OnMsgWrite +{ fn poll_write( self: Pin<&mut Self>, - _: &mut Context<'_>, + cx: &mut Context<'_>, bytes: &[u8], ) -> Poll> { let on_msg = &self.on_msg; - let result = match on_msg(Msg { bytes }) { - Ok(_) => Ok(bytes.len()), - Err(e) => Err(std::io::Error::new(std::io::ErrorKind::Other, e)), - }; - Ready(result) + let result = ready!(on_msg(Msg { bytes }, cx.waker().clone())); + Ready(result.map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("error writing message: {e}"), + ) + })) } fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { diff --git a/src/write.rs b/src/write.rs index 8e162b7..3a63357 100644 --- a/src/write.rs +++ b/src/write.rs @@ -1,37 +1,30 @@ use crate::conn::ConnConfig; -use crate::msg_sender::{SendCommand, RING_BUFFER_SIZE}; +use crate::msg_sender::MsgRcv; use async_ringbuf::AsyncHeapConsumer; use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::net::tcp::OwnedWriteHalf; -use tokio::sync::mpsc::UnboundedReceiver; use tokio::sync::oneshot; use tokio::time::MissedTickBehavior; /// Receive bytes from recv and write to WriteHalf of TcpStream. pub(crate) async fn handle_writer( write: OwnedWriteHalf, - recv: UnboundedReceiver, - ring_buf: AsyncHeapConsumer, + recv: MsgRcv, config: ConnConfig, stop: oneshot::Sender<()>, ) -> std::io::Result<()> { - match config.write_flush_interval { - None => handle_writer_no_auto_flush(write, recv, ring_buf, stop).await, - Some(duration) => { - if duration.is_zero() { - handle_writer_no_auto_flush(write, recv, ring_buf, stop).await - } else { - handle_writer_auto_flush(write, recv, ring_buf, duration, stop).await - } - } + let duration = config.write_flush_interval; + if duration.is_zero() { + handle_writer_no_auto_flush(write, recv, stop).await + } else { + handle_writer_auto_flush(write, recv, duration, stop).await } } async fn handle_writer_auto_flush( mut write: OwnedWriteHalf, - mut recv: UnboundedReceiver, - mut ring_buf: AsyncHeapConsumer, + mut recv: MsgRcv, duration: Duration, mut stop: oneshot::Sender<()>, ) -> std::io::Result<()> { @@ -41,123 +34,153 @@ async fn handle_writer_auto_flush( let mut flush_tick = tokio::time::interval(duration); flush_tick.set_missed_tick_behavior(MissedTickBehavior::Skip); - 'closed: loop { + 'close: loop { + // obtain a ring buffer + let ring = tokio::select! { + biased; + ring = recv.rings.recv() => ring, + _ = stop.closed() => break 'close, + }; + let mut ring = match ring { + Some(ring) => ring, + None => break 'close, + }; + let chunk_size = send_buf_size.min(ring.capacity()); let mut has_data = true; - // start from burst mode - 'burst: loop { - // burst mode loop - if write_all_from_ring_buf(&mut ring_buf, &mut write).await? == 0 { - // exist burst mode loop when there is no data - break 'burst; - } - } - // when burst mode got no data, switch to waker mode - 'waker: loop { + 'ring: loop { tokio::select! { biased; - Some(_) = recv.recv() => { - // on flush, read all data from ring buffer and write to socket. - write_all_from_ring_buf(&mut ring_buf, &mut write).await?; - write.flush().await?; - // disable ticked flush when there is no data. + // buf threshold + _ = ring.wait(chunk_size) => { + if ring.is_closed() { + break 'ring; + } + flush(&mut ring, &mut write, chunk_size).await?; has_data = false; } - _ = ring_buf.wait(RING_BUFFER_SIZE / 4) => { - if ring_buf.is_closed() { - break 'closed; + // flush + cmd = recv.cmd_recv.recv() => { + // always flush, including if sender is dropped + flush(&mut ring, &mut write, chunk_size).await?; + if cmd.is_none() { + break 'close; } - // got a bunch of data, switch to burst mode - break 'waker; + has_data = false; } - _ = ring_buf.wait(1), if !has_data => { - if ring_buf.is_closed() { - break 'closed; + _ = ring.wait(1), if !has_data => { + if ring.is_closed() { + break 'ring; } - // got small amount of data, enable ticking flush, + // got data, no writing, enable ticking has_data = true; } + // tick flush _ = flush_tick.tick(), if has_data => { - // flush everything. - write_all_from_ring_buf(&mut ring_buf, &mut write).await?; - write.flush().await?; - // disable ticked flush when there is no data. + flush(&mut ring, &mut write, chunk_size).await?; has_data = false; } - _ = stop.closed() => { - break 'closed; - } - else => {} + _ = stop.closed() => break 'close, } } + // always clear the old ring_buf before reading the next + flush(&mut ring, &mut write, chunk_size).await?; } - write_all_from_ring_buf(&mut ring_buf, &mut write).await?; - write.flush().await?; + tracing::debug!("connection stopped"); write.shutdown().await?; - tracing::debug!("connection stopped (socket manager dropped)"); Ok(()) } async fn handle_writer_no_auto_flush( mut write: OwnedWriteHalf, - mut recv: UnboundedReceiver, - mut ring_buf: AsyncHeapConsumer, + mut recv: MsgRcv, mut stop: oneshot::Sender<()>, ) -> std::io::Result<()> { let send_buf_size = socket2::SockRef::from(write.as_ref()).send_buffer_size()?; tracing::trace!("send buffer size: {}", send_buf_size); - 'closed: loop { - // start from burst mode - 'burst: loop { - // burst mode loop - if write_all_from_ring_buf(&mut ring_buf, &mut write).await? == 0 { - // exist burst mode loop when there is no data - break 'burst; - } - } - // when burst mode got no data, switch to waker mode - 'waker: loop { + + 'close: loop { + // obtain a ring buffer + let ring = tokio::select! { + biased; + ring = recv.rings.recv() => ring, + _ = stop.closed() => break 'close, + }; + let mut ring = match ring { + Some(ring) => ring, + None => break 'close, + }; + let chunk_size = send_buf_size.min(ring.capacity()); + 'ring: loop { tokio::select! { biased; - Some(_) = recv.recv() => { - // on flush, read all data from ring buffer and write to socket. - write_all_from_ring_buf(&mut ring_buf, &mut write).await?; - write.flush().await?; - } - _ = ring_buf.wait(RING_BUFFER_SIZE / 4) => { - if ring_buf.is_closed() { - break 'closed; + // buf threshold + _ = ring.wait(chunk_size) => { + if ring.is_closed() { + break 'ring; } - break 'waker; + flush(&mut ring, &mut write, chunk_size).await?; } - _ = stop.closed() => { - break 'closed; + // flush + cmd = recv.cmd_recv.recv() => { + // always flush, including if sender is dropped + flush(&mut ring, &mut write, chunk_size).await?; + if cmd.is_none() { + break 'close; + } } - else => {} + _ = stop.closed() => break 'close, } } + // always clear the old ring_buf before reading the next + flush(&mut ring, &mut write, chunk_size).await?; } - // flush and close - write_all_from_ring_buf(&mut ring_buf, &mut write).await?; - write.flush().await?; + tracing::debug!("connection stopped"); write.shutdown().await?; - tracing::debug!("connection stopped (socket manager dropped)"); Ok(()) } /// directly write from ring buffer to bufwriter. -async fn write_all_from_ring_buf( +/// until the ring buffer is empty. +async fn flush( ring_buf: &mut AsyncHeapConsumer, write: &mut OwnedWriteHalf, + chunk_size: usize, ) -> std::io::Result { + let mut n = 0; + loop { + let written = write_chunk(ring_buf, write, chunk_size).await?; + if written == 0 { + break; + } + n += written; + } + Ok(n) +} + +#[inline] +async fn write_chunk( + ring_buf: &mut AsyncHeapConsumer, + write: &mut OwnedWriteHalf, + chunk_size: usize, +) -> std::io::Result { + debug_assert!(chunk_size > 0); let (left, right) = ring_buf.as_base().as_slices(); - let n = left.len() + right.len(); - if !left.is_empty() { - write.write_all(left).await?; + + // precompute all lengths to reduce cpu branching + let left_written = left.len().min(chunk_size); + let remaining = chunk_size - left_written; + let right_written = right.len().min(remaining); + let total = left_written + right_written; + + // execute write + if left_written > 0 { + write.write_all(&left[..left_written]).await?; } - if !right.is_empty() { - write.write_all(right).await?; + if right_written > 0 { + write.write_all(&right[..right_written]).await?; } - unsafe { ring_buf.as_mut_base().advance(n) }; - tracing::trace!("write {} bytes", n); - Ok(n) + + // update ring_buf + unsafe { ring_buf.as_mut_base().advance(total) }; + Ok(total) } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b0b2f2c..57ef01a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,17 +7,20 @@ find_package(Threads REQUIRED) # all files named test*.cpp are included set(test_files) file(GLOB test_file_paths test*.cpp) -foreach(test_file_path ${test_file_paths}) +foreach (test_file_path ${test_file_paths}) get_filename_component(TestFile ${test_file_path} NAME) list(APPEND test_files ${TestFile}) -endforeach() +endforeach () # add test sources create_test_sourcelist(Tests CommonCxxTests.cxx ${test_files}) # build test driver add_executable(CommonCxxTests ${Tests}) -target_link_libraries(CommonCxxTests PUBLIC socket_manager ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(CommonCxxTests + PUBLIC + socket_manager + ${CMAKE_THREAD_LIBS_INIT}) # enable lto for tests set_property(TARGET CommonCxxTests PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) @@ -29,4 +32,4 @@ foreach (test ${TestsToRun}) get_filename_component(TName ${test} NAME_WE) add_test(NAME ${TName} COMMAND CommonCxxTests ${TName}) set_tests_properties(${TName} PROPERTIES TIMEOUT 5) -endforeach() +endforeach () diff --git a/tests/concurrentqueue b/tests/concurrentqueue new file mode 160000 index 0000000..6dd38b8 --- /dev/null +++ b/tests/concurrentqueue @@ -0,0 +1 @@ +Subproject commit 6dd38b8a1dbaa7863aa907045f32308a56a6ff5d diff --git a/tests/lightweightsemaphore.h b/tests/lightweightsemaphore.h deleted file mode 100644 index f347373..0000000 --- a/tests/lightweightsemaphore.h +++ /dev/null @@ -1,427 +0,0 @@ -// Provides an efficient implementation of a semaphore (LightweightSemaphore). -// This is an extension of Jeff Preshing's sempahore implementation (licensed -// under the terms of its separate zlib license) that has been adapted and -// extended by Cameron Desrochers. - -#pragma once - -#include // For std::size_t -#include -#include // For std::make_signed - -#if defined(_WIN32) -// Avoid including windows.h in a header; we only need a handful of -// items, so we'll redeclare them here (this is relatively safe since -// the API generally has to remain stable between Windows versions). -// I know this is an ugly hack but it still beats polluting the global -// namespace with thousands of generic names or adding a .cpp for nothing. -extern "C" { - struct _SECURITY_ATTRIBUTES; - __declspec(dllimport) void* __stdcall CreateSemaphoreW(_SECURITY_ATTRIBUTES* lpSemaphoreAttributes, long lInitialCount, long lMaximumCount, const wchar_t* lpName); - __declspec(dllimport) int __stdcall CloseHandle(void* hObject); - __declspec(dllimport) unsigned long __stdcall WaitForSingleObject(void* hHandle, unsigned long dwMilliseconds); - __declspec(dllimport) int __stdcall ReleaseSemaphore(void* hSemaphore, long lReleaseCount, long* lpPreviousCount); -} -#elif defined(__MACH__) -#include -#elif defined(__MVS__) -#include -#elif defined(__unix__) -#include - -#if defined(__GLIBC_PREREQ) && defined(_GNU_SOURCE) -#if __GLIBC_PREREQ(2,30) -#define MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC -#endif -#endif -#endif - -namespace moodycamel -{ - namespace details - { - -// Code in the mpmc_sema namespace below is an adaptation of Jeff Preshing's -// portable + lightweight semaphore implementations, originally from -// https://github.com/preshing/cpp11-on-multicore/blob/master/common/sema.h -// LICENSE: -// Copyright (c) 2015 Jeff Preshing -// -// This software is provided 'as-is', without any express or implied -// warranty. In no event will the authors be held liable for any damages -// arising from the use of this software. -// -// Permission is granted to anyone to use this software for any purpose, -// including commercial applications, and to alter it and redistribute it -// freely, subject to the following restrictions: -// -// 1. The origin of this software must not be misrepresented; you must not -// claim that you wrote the original software. If you use this software -// in a product, an acknowledgement in the product documentation would be -// appreciated but is not required. -// 2. Altered source versions must be plainly marked as such, and must not be -// misrepresented as being the original software. -// 3. This notice may not be removed or altered from any source distribution. -#if defined(_WIN32) - class Semaphore -{ -private: - void* m_hSema; - - Semaphore(const Semaphore& other) = delete; - Semaphore& operator=(const Semaphore& other) = delete; - -public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - const long maxLong = 0x7fffffff; - m_hSema = CreateSemaphoreW(nullptr, initialCount, maxLong, nullptr); - assert(m_hSema); - } - - ~Semaphore() - { - CloseHandle(m_hSema); - } - - bool wait() - { - const unsigned long infinite = 0xffffffff; - return WaitForSingleObject(m_hSema, infinite) == 0; - } - - bool try_wait() - { - return WaitForSingleObject(m_hSema, 0) == 0; - } - - bool timed_wait(std::uint64_t usecs) - { - return WaitForSingleObject(m_hSema, (unsigned long)(usecs / 1000)) == 0; - } - - void signal(int count = 1) - { - while (!ReleaseSemaphore(m_hSema, count, nullptr)); - } -}; -#elif defined(__MACH__) -//--------------------------------------------------------- -// Semaphore (Apple iOS and OSX) -// Can't use POSIX semaphores due to http://lists.apple.com/archives/darwin-kernel/2009/Apr/msg00010.html -//--------------------------------------------------------- - class Semaphore - { - private: - semaphore_t m_sema; - - Semaphore(const Semaphore& other) = delete; - Semaphore& operator=(const Semaphore& other) = delete; - - public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - kern_return_t rc = semaphore_create(mach_task_self(), &m_sema, SYNC_POLICY_FIFO, initialCount); - assert(rc == KERN_SUCCESS); - (void)rc; - } - - ~Semaphore() - { - semaphore_destroy(mach_task_self(), m_sema); - } - - bool wait() - { - return semaphore_wait(m_sema) == KERN_SUCCESS; - } - - bool try_wait() - { - return timed_wait(0); - } - - bool timed_wait(std::uint64_t timeout_usecs) - { - mach_timespec_t ts; - ts.tv_sec = static_cast(timeout_usecs / 1000000); - ts.tv_nsec = static_cast((timeout_usecs % 1000000) * 1000); - - // added in OSX 10.10: https://developer.apple.com/library/prerelease/mac/documentation/General/Reference/APIDiffsMacOSX10_10SeedDiff/modules/Darwin.html - kern_return_t rc = semaphore_timedwait(m_sema, ts); - return rc == KERN_SUCCESS; - } - - void signal() - { - while (semaphore_signal(m_sema) != KERN_SUCCESS); - } - - void signal(int count) - { - while (count-- > 0) - { - while (semaphore_signal(m_sema) != KERN_SUCCESS); - } - } - }; -#elif defined(__unix__) || defined(__MVS__) - //--------------------------------------------------------- -// Semaphore (POSIX, Linux, zOS) -//--------------------------------------------------------- -class Semaphore -{ -private: - sem_t m_sema; - - Semaphore(const Semaphore& other) = delete; - Semaphore& operator=(const Semaphore& other) = delete; - -public: - Semaphore(int initialCount = 0) - { - assert(initialCount >= 0); - int rc = sem_init(&m_sema, 0, static_cast(initialCount)); - assert(rc == 0); - (void)rc; - } - - ~Semaphore() - { - sem_destroy(&m_sema); - } - - bool wait() - { - // http://stackoverflow.com/questions/2013181/gdb-causes-sem-wait-to-fail-with-eintr-error - int rc; - do { - rc = sem_wait(&m_sema); - } while (rc == -1 && errno == EINTR); - return rc == 0; - } - - bool try_wait() - { - int rc; - do { - rc = sem_trywait(&m_sema); - } while (rc == -1 && errno == EINTR); - return rc == 0; - } - - bool timed_wait(std::uint64_t usecs) - { - struct timespec ts; - const int usecs_in_1_sec = 1000000; - const int nsecs_in_1_sec = 1000000000; -#ifdef MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC - clock_gettime(CLOCK_MONOTONIC, &ts); -#else - clock_gettime(CLOCK_REALTIME, &ts); -#endif - ts.tv_sec += (time_t)(usecs / usecs_in_1_sec); - ts.tv_nsec += (long)(usecs % usecs_in_1_sec) * 1000; - // sem_timedwait bombs if you have more than 1e9 in tv_nsec - // so we have to clean things up before passing it in - if (ts.tv_nsec >= nsecs_in_1_sec) { - ts.tv_nsec -= nsecs_in_1_sec; - ++ts.tv_sec; - } - - int rc; - do { -#ifdef MOODYCAMEL_LIGHTWEIGHTSEMAPHORE_MONOTONIC - rc = sem_clockwait(&m_sema, CLOCK_MONOTONIC, &ts); -#else - rc = sem_timedwait(&m_sema, &ts); -#endif - } while (rc == -1 && errno == EINTR); - return rc == 0; - } - - void signal() - { - while (sem_post(&m_sema) == -1); - } - - void signal(int count) - { - while (count-- > 0) - { - while (sem_post(&m_sema) == -1); - } - } -}; -#else -#error Unsupported platform! (No semaphore wrapper available) -#endif - - } // end namespace details - - -//--------------------------------------------------------- -// LightweightSemaphore -//--------------------------------------------------------- - class LightweightSemaphore - { - public: - typedef std::make_signed::type ssize_t; - - private: - std::atomic m_count; - details::Semaphore m_sema; - int m_maxSpins; - - bool waitWithPartialSpinning(std::int64_t timeout_usecs = -1) - { - ssize_t oldCount; - int spin = m_maxSpins; - while (--spin >= 0) - { - oldCount = m_count.load(std::memory_order_relaxed); - if ((oldCount > 0) && m_count.compare_exchange_strong(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) - return true; - std::atomic_signal_fence(std::memory_order_acquire); // Prevent the compiler from collapsing the loop. - } - oldCount = m_count.fetch_sub(1, std::memory_order_acquire); - if (oldCount > 0) - return true; - if (timeout_usecs < 0) - { - if (m_sema.wait()) - return true; - } - if (timeout_usecs > 0 && m_sema.timed_wait((std::uint64_t)timeout_usecs)) - return true; - // At this point, we've timed out waiting for the semaphore, but the - // count is still decremented indicating we may still be waiting on - // it. So we have to re-adjust the count, but only if the semaphore - // wasn't signaled enough times for us too since then. If it was, we - // need to release the semaphore too. - while (true) - { - oldCount = m_count.load(std::memory_order_acquire); - if (oldCount >= 0 && m_sema.try_wait()) - return true; - if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) - return false; - } - } - - ssize_t waitManyWithPartialSpinning(ssize_t max, std::int64_t timeout_usecs = -1) - { - assert(max > 0); - ssize_t oldCount; - int spin = m_maxSpins; - while (--spin >= 0) - { - oldCount = m_count.load(std::memory_order_relaxed); - if (oldCount > 0) - { - ssize_t newCount = oldCount > max ? oldCount - max : 0; - if (m_count.compare_exchange_strong(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) - return oldCount - newCount; - } - std::atomic_signal_fence(std::memory_order_acquire); - } - oldCount = m_count.fetch_sub(1, std::memory_order_acquire); - if (oldCount <= 0) - { - if ((timeout_usecs == 0) || (timeout_usecs < 0 && !m_sema.wait()) || (timeout_usecs > 0 && !m_sema.timed_wait((std::uint64_t)timeout_usecs))) - { - while (true) - { - oldCount = m_count.load(std::memory_order_acquire); - if (oldCount >= 0 && m_sema.try_wait()) - break; - if (oldCount < 0 && m_count.compare_exchange_strong(oldCount, oldCount + 1, std::memory_order_relaxed, std::memory_order_relaxed)) - return 0; - } - } - } - if (max > 1) - return 1 + tryWaitMany(max - 1); - return 1; - } - - public: - LightweightSemaphore(ssize_t initialCount = 0, int maxSpins = 10000) : m_count(initialCount), m_maxSpins(maxSpins) - { - assert(initialCount >= 0); - assert(maxSpins >= 0); - } - - bool tryWait() - { - ssize_t oldCount = m_count.load(std::memory_order_relaxed); - while (oldCount > 0) - { - if (m_count.compare_exchange_weak(oldCount, oldCount - 1, std::memory_order_acquire, std::memory_order_relaxed)) - return true; - } - return false; - } - - bool wait() - { - return tryWait() || waitWithPartialSpinning(); - } - - bool wait(std::int64_t timeout_usecs) - { - return tryWait() || waitWithPartialSpinning(timeout_usecs); - } - - // Acquires between 0 and (greedily) max, inclusive - ssize_t tryWaitMany(ssize_t max) - { - assert(max >= 0); - ssize_t oldCount = m_count.load(std::memory_order_relaxed); - while (oldCount > 0) - { - ssize_t newCount = oldCount > max ? oldCount - max : 0; - if (m_count.compare_exchange_weak(oldCount, newCount, std::memory_order_acquire, std::memory_order_relaxed)) - return oldCount - newCount; - } - return 0; - } - - // Acquires at least one, and (greedily) at most max - ssize_t waitMany(ssize_t max, std::int64_t timeout_usecs) - { - assert(max >= 0); - ssize_t result = tryWaitMany(max); - if (result == 0 && max > 0) - result = waitManyWithPartialSpinning(max, timeout_usecs); - return result; - } - - ssize_t waitMany(ssize_t max) - { - ssize_t result = waitMany(max, -1); - assert(result > 0); - return result; - } - - void signal(ssize_t count = 1) - { - assert(count >= 0); - ssize_t oldCount = m_count.fetch_add(count, std::memory_order_release); - ssize_t toRelease = -oldCount < count ? -oldCount : count; - if (toRelease > 0) - { - m_sema.signal((int)toRelease); - } - } - - std::size_t availableApprox() const - { - ssize_t count = m_count.load(std::memory_order_relaxed); - return count > 0 ? static_cast(count) : 0; - } - }; - -} // end namespace moodycamel diff --git a/tests/test_auto_flush.cpp b/tests/test_auto_flush.cpp index 57624c9..d2d9f42 100644 --- a/tests/test_auto_flush.cpp +++ b/tests/test_auto_flush.cpp @@ -1,4 +1,5 @@ #undef NDEBUG + #include "test_utils.h" #include #include @@ -10,8 +11,8 @@ class ReceiverHelloWorld : public DoNothingReceiver { std::atomic_bool &received) : mutex(mutex), cond(cond), received(received) {} - void on_message(const std::shared_ptr &data) override { - if (*data == "hello world") { + void on_message(std::string_view data) override { + if (data == "hello world") { received.store(true); std::unique_lock u_lock(mutex); cond.notify_all(); @@ -26,10 +27,11 @@ class ReceiverHelloWorld : public DoNothingReceiver { class HelloWorldManager : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto do_nothing = std::make_unique(mutex, cond, received); - sender = conn->start(std::move(do_nothing)); - sender->send("hello world"); + conn->start(std::move(do_nothing)); + this->sender = send; + sender->send_block("hello world"); } std::mutex mutex; @@ -42,11 +44,12 @@ class HelloWorldManager : public DoNothingConnCallback { class SendHelloWorldDoNotClose : public DoNothingConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto do_nothing = std::make_unique(); - sender = conn->start(std::move(do_nothing)); + conn->start(std::move(do_nothing)); + this->sender = sender; std::thread t([this] { - sender->send("hello world"); + this->sender->send_block("hello world"); }); t.detach(); } diff --git a/tests/test_callback_throw_error.cpp b/tests/test_callback_throw_error.cpp index 7006584..08367b1 100644 --- a/tests/test_callback_throw_error.cpp +++ b/tests/test_callback_throw_error.cpp @@ -1,4 +1,5 @@ #undef NDEBUG + #include "test_utils.h" #include #include @@ -8,29 +9,30 @@ using namespace socket_manager; class OnConnectErrorBeforeStartCallback : public DoNothingConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { throw std::runtime_error("throw some error before calling start"); } }; class OnConnectErrorAfterStartCallback : public DoNothingConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { conn->start(std::make_unique()); throw std::runtime_error("throw some error after calling start"); } }; class OnMsgErrorReceiver : public MsgReceiver { - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { throw std::runtime_error("throw some error on receiving message"); } }; class OnMsgErrorCallback : public ConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { - sender = conn->start(std::make_unique()); + std::shared_ptr conn, std::shared_ptr send) override { + conn->start(std::make_unique()); + this->sender = send; sender.use_count(); } @@ -40,20 +42,21 @@ class OnMsgErrorCallback : public ConnCallback { void on_connect_error(const std::string &addr, const std::string &err) override {} -private: std::shared_ptr sender; }; class StoreAllEventsConnHelloCallback : public StoreAllEventsConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { std::unique_lock lock(mutex); auto conn_id = local_addr + "->" + peer_addr; events.emplace_back(CONNECTED, conn_id); auto msg_storer = std::make_unique(conn_id, mutex, cond, buffer); - auto sender = conn->start(std::move(msg_storer)); + conn->start(std::move(msg_storer)); std::thread t1([sender]() { - sender->send("hello"); + try { + sender->send_block("hello"); + } catch (std::runtime_error &e) { /* ignore */ } }); t1.detach(); senders.emplace(conn_id, sender); diff --git a/tests/test_error_call_after_abort.cpp b/tests/test_error_call_after_abort.cpp index fb29ee9..10a9af9 100644 --- a/tests/test_error_call_after_abort.cpp +++ b/tests/test_error_call_after_abort.cpp @@ -15,7 +15,7 @@ int test_error_call_after_abort(int argc, char **argv) { try { nothing.connect_to_addr("127.0.0.1:40103"); // should not reach here - return -1; + return 1; } catch (std::runtime_error &e) { std::cout << "connect_to_addr after abort-join throw error: " << e.what() << std::endl; } diff --git a/tests/test_error_send_after_closed.cpp b/tests/test_error_send_after_closed.cpp index b3bcd3e..9057e1a 100644 --- a/tests/test_error_send_after_closed.cpp +++ b/tests/test_error_send_after_closed.cpp @@ -60,6 +60,6 @@ int test_error_send_after_closed(int argc, char **argv) { } // should not reach here - return -1; + return 1; } diff --git a/tests/test_error_twice_start.cpp b/tests/test_error_twice_start.cpp index 7c6c228..88fa0f4 100644 --- a/tests/test_error_twice_start.cpp +++ b/tests/test_error_twice_start.cpp @@ -1,11 +1,12 @@ #undef NDEBUG + #include "test_utils.h" #include #include class DoNothingMsgReceiver : public MsgReceiver { public: - void on_message(const std::shared_ptr &data) override {} + void on_message(std::string_view data) override {} }; class TwiceStartCallback : public ConnCallback { @@ -14,7 +15,7 @@ class TwiceStartCallback : public ConnCallback { TwiceStartCallback() : error_thrown(0) {}; void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto receiver = std::make_unique(); auto receiver2 = std::make_unique(); conn->start(std::move(receiver)); diff --git a/tests/test_find_package/CMakeLists.txt b/tests/test_find_package/CMakeLists.txt index 56679fc..6590361 100644 --- a/tests/test_find_package/CMakeLists.txt +++ b/tests/test_find_package/CMakeLists.txt @@ -4,9 +4,9 @@ set(CMAKE_TOOLCHAIN_FILE ${CMAKE_SOURCE_DIR}/toolchain.cmake) project(helloworld_server) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) add_executable(helloworld_server helloworld_server.cpp) set_property(TARGET helloworld_server PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) -find_package(socket_manager 0.1.0 REQUIRED) +find_package(socket_manager 0.3.0 REQUIRED) target_link_libraries(helloworld_server PUBLIC socket_manager) diff --git a/tests/test_find_package/helloworld_server.cpp b/tests/test_find_package/helloworld_server.cpp index eb0d849..ac4cf2a 100644 --- a/tests/test_find_package/helloworld_server.cpp +++ b/tests/test_find_package/helloworld_server.cpp @@ -11,11 +11,11 @@ class HelloWorldReceiver : public socket_manager::MsgReceiver { std::unordered_map> &senders) : conn_id(std::move(conn_id)), mutex(mutex), senders(senders) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { try { std::unique_lock my_lock(mutex); auto sender = senders.at(conn_id); - sender->send("HTTP/1.1 200 OK\r\nContent-Length: 12\r\nConnection: close\r\n\r\nHello, world"); + sender->send_block("HTTP/1.1 200 OK\r\nContent-Length: 12\r\nConnection: close\r\n\r\nHello, world"); senders.erase(conn_id); } catch (const std::out_of_range &e) { std::cerr << "Exception at " << e.what() << std::endl; @@ -33,9 +33,10 @@ class MyCallback : public socket_manager::ConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, + std::shared_ptr sender) override { auto id = local_addr + "->" + peer_addr; - auto sender = conn->start(std::make_unique(id, mutex, senders)); + conn->start(std::make_unique(id, mutex, senders)); { std::unique_lock my_lock(mutex); senders[id] = sender; diff --git a/tests/test_manual_flush.cpp b/tests/test_manual_flush.cpp index eaf3525..933e044 100644 --- a/tests/test_manual_flush.cpp +++ b/tests/test_manual_flush.cpp @@ -13,8 +13,8 @@ class FinalReceiver : public MsgReceiver { private: - void on_message(const std::shared_ptr &data) override { - assert(*data == "hello world"); + void on_message(std::string_view data) override { + assert(data == "hello world"); std::unique_lock lk(mutex); has_received = true; cond.notify_one(); @@ -32,10 +32,10 @@ class EchoReceiver : public MsgReceiver { : has_received(hasReceived), _data(data), mutex(mutex), cond(cond) {} private: - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { std::unique_lock lk(mutex); has_received = true; - _data.append(*data); + _data.append(data); cond.notify_one(); std::cout << "echo received" << std::endl; } @@ -48,12 +48,12 @@ class EchoReceiver : public MsgReceiver { class HelloCallback : public ConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(has_received, mutex, cond); // disable write auto flush - auto sender = conn->start(std::move(rcv), DEFAULT_MSG_BUF_SIZE, 1, 0); + conn->start(std::move(rcv), nullptr, DEFAULT_MSG_BUF_SIZE, 1, 0); std::thread t([sender] { - sender->send("hello world"); + sender->send_block("hello world"); sender->flush(); }); t.detach(); @@ -76,14 +76,14 @@ class HelloCallback : public ConnCallback { class EchoCallback : public ConnCallback { void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(has_received, _data, mutex, cond); // disable write auto flush - auto sender = conn->start(std::move(rcv), DEFAULT_MSG_BUF_SIZE, 1, 0); + conn->start(std::move(rcv), nullptr, DEFAULT_MSG_BUF_SIZE, 1, 0); std::thread t([sender, this]() { std::unique_lock lk(mutex); cond.wait(lk, [this]() { return has_received; }); - sender->send(_data); + sender->send_block(_data); std::cout << "echo received and sent back" << std::endl; }); t.detach(); diff --git a/tests/test_transfer_data_large.cpp b/tests/test_transfer_data_large.cpp index 25da4cb..15eb96c 100644 --- a/tests/test_transfer_data_large.cpp +++ b/tests/test_transfer_data_large.cpp @@ -1,24 +1,26 @@ #undef NDEBUG + #include "test_utils.h" #include #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; class SendLargeDataConnCallback : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + conn->start(std::move(rcv)); std::thread t([sender]() { // send 1000MB data std::string data; - for (int i = 0; i < 1024 * 1024; i++) { + data.reserve(1024 * 1000); + for (int i = 0; i < 100 * 1024; i++) { data.append("helloworld"); } - for (int i = 0; i < 100; ++i) { - sender->send(data); + for (int i = 0; i < 1024; ++i) { + sender->send_block(data); } // close connection after sender finished. }); @@ -26,20 +28,20 @@ class SendLargeDataConnCallback : public DoNothingConnCallback { } }; -class StoreAllData : public MsgReceiver { +class StoreAllDataLarge : public MsgReceiver { public: - explicit StoreAllData(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataLarge(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int)data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -47,10 +49,11 @@ class StoreAllDataNotifyOnCloseCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { - auto rcv = std::make_unique(add_data, count); + std::shared_ptr conn, std::shared_ptr send) override { + auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -67,7 +70,7 @@ class StoreAllDataNotifyOnCloseCallback : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_large_async.cpp b/tests/test_transfer_data_large_async.cpp index 134a99f..8c10045 100644 --- a/tests/test_transfer_data_large_async.cpp +++ b/tests/test_transfer_data_large_async.cpp @@ -1,59 +1,49 @@ #undef NDEBUG #include "test_utils.h" -#include "lightweightsemaphore.h" -#include +#include +#include "concurrentqueue/concurrentqueue.h" +#include "concurrentqueue/lightweightsemaphore.h" #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; -class CondWaker : public Waker { +class CondWaker : public Notifier { public: - explicit CondWaker(const std::shared_ptr &sem) - : sem(sem), waker_ref_count(0) {} + explicit CondWaker(const std::shared_ptr &sem) : sem(sem) {} void wake() override { - if (this->waker_ref_count.load(std::memory_order_acquire) > 0) { - sem->signal(); - } - } - - void clone() override { - this->waker_ref_count.fetch_add(1, std::memory_order_acq_rel); - } - - void release() override { - this->waker_ref_count.fetch_sub(1, std::memory_order_acq_rel); + sem->signal(); } private: std::shared_ptr sem; - std::atomic_size_t waker_ref_count; }; class SendLargeDataConnCallbackAsync : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + auto sem = std::make_shared(); + auto waker = std::make_shared(sem); + conn->start(std::move(rcv), waker); + + std::string data; + data.reserve(1024 * 1000); + for (int i = 0; i < 100 * 1024; i++) { + data.append("helloworld"); + } - std::thread t([sender]() { + std::thread t([=]() { // send 1000MB data int progress = 0; size_t offset = 0; - auto sem = std::make_shared(); - auto waker = std::make_shared(sem); - - std::string data; - for (int i = 0; i < 1024 * 1024; i++) { - data.append("helloworld"); - } - - while (progress < 100) { - auto sent = sender->try_send(data, offset, waker); + std::string_view data_view(data); + while (progress < 1024) { + auto sent = sender->send_async(data_view.substr(offset)); if (sent < 0) { - sem->wait(); + sem->wait(1000); } else { offset += sent; } @@ -70,18 +60,18 @@ class SendLargeDataConnCallbackAsync : public DoNothingConnCallback { class StoreAllDataAsync : public MsgReceiver { public: - explicit StoreAllDataAsync(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataAsync(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int) data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -89,10 +79,11 @@ class StoreAllDataNotifyOnCloseCallbackAsync : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -109,7 +100,7 @@ class StoreAllDataNotifyOnCloseCallbackAsync : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_large_manual_flush.cpp b/tests/test_transfer_data_large_manual_flush.cpp index 7dc106f..b58f781 100644 --- a/tests/test_transfer_data_large_manual_flush.cpp +++ b/tests/test_transfer_data_large_manual_flush.cpp @@ -1,24 +1,26 @@ #undef NDEBUG + #include "test_utils.h" #include #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; class SendLargeManualDataConnCallback : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + conn->start(std::move(rcv)); std::thread t([sender]() { // send 1000MB data std::string data; - for (int i = 0; i < 1024 * 1024; i++) { + data.reserve(1024 * 1000); + for (int i = 0; i < 100 * 1024; i++) { data.append("helloworld"); } - for (int i = 0; i < 100; ++i) { - sender->send(data); + for (int i = 0; i < 1024; ++i) { + sender->send_block(data); } // close connection after sender finished. }); @@ -26,20 +28,20 @@ class SendLargeManualDataConnCallback : public DoNothingConnCallback { } }; -class StoreAllData : public MsgReceiver { +class StoreAllDataLargeManual : public MsgReceiver { public: - explicit StoreAllData(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataLargeManual(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int)data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -47,10 +49,11 @@ class StoreAllDataNotifyOnCloseCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { - auto rcv = std::make_unique(add_data, count); + std::shared_ptr conn, std::shared_ptr send) override { + auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -67,7 +70,7 @@ class StoreAllDataNotifyOnCloseCallback : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_small_busy.cpp b/tests/test_transfer_data_large_nonblock.cpp similarity index 58% rename from tests/test_transfer_data_small_busy.cpp rename to tests/test_transfer_data_large_nonblock.cpp index da2dbc5..6168443 100644 --- a/tests/test_transfer_data_small_busy.cpp +++ b/tests/test_transfer_data_large_nonblock.cpp @@ -4,68 +4,56 @@ #include #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; - -const std::string DATA = "helloworld" - "helloworld" - "helloworld" - "helloworld" - "helloworld" - "helloworld" - "helloworld" - "helloworld" - "helloworld" - "helloworld"; - -class SendLargeDataConnCallbackBusy : public DoNothingConnCallback { +const size_t MSG_BUF_SIZE = 256 * 1024; + +class SendLargeNonBlockDataConnCallback : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); - + conn->start(std::move(rcv)); std::thread t([sender]() { // send 1000MB data - int progress = 0; - size_t offset = 0; - while (progress < 1024 * 1024 * 10) { - offset += sender->try_send(DATA, offset); - if (offset == DATA.size()) { - offset = 0; - progress += 1; - } + std::string data; + data.reserve(1024 * 1000); + for (int i = 0; i < 100 * 1024; i++) { + data.append("helloworld"); } + for (int i = 0; i < 1024; ++i) { + sender->send_nonblock(data); + } + // close connection after sender finished. }); - t.detach(); } }; -class StoreAllDataBusy : public MsgReceiver { +class StoreAllDataLargeNonBlock : public MsgReceiver { public: - explicit StoreAllDataBusy(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataLargeNonBlock(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int) data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; -class StoreAllDataNotifyOnCloseCallbackBusy : public ConnCallback { +class StoreAllDataNotifyOnCloseCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { - auto rcv = std::make_unique(add_data, count); + std::shared_ptr conn, std::shared_ptr send) override { + auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -82,16 +70,16 @@ class StoreAllDataNotifyOnCloseCallbackBusy : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; -int test_transfer_data_small_busy(int argc, char **argv) { +int test_transfer_data_large_nonblock(int argc, char **argv) { const std::string addr = "127.0.0.1:40013"; - auto send_cb = std::make_shared(); - auto store_cb = std::make_shared(); + auto send_cb = std::make_shared(); + auto store_cb = std::make_shared(); SocketManager send(send_cb); SocketManager store(store_cb); diff --git a/tests/test_transfer_data_mid.cpp b/tests/test_transfer_data_mid.cpp index b4a22bf..7f18e98 100644 --- a/tests/test_transfer_data_mid.cpp +++ b/tests/test_transfer_data_mid.cpp @@ -1,16 +1,17 @@ #undef NDEBUG + #include "test_utils.h" #include #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; class SendMidDataConnCallback : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + conn->start(std::move(rcv)); std::thread t([sender]() { // send 1000MB data std::string data; @@ -18,7 +19,7 @@ class SendMidDataConnCallback : public DoNothingConnCallback { data.append("helloworld"); } for (int i = 0; i < 10 * 1024; ++i) { - sender->send(data); + sender->send_block(data); } // close connection after sender finished. }); @@ -28,18 +29,18 @@ class SendMidDataConnCallback : public DoNothingConnCallback { class StoreAllDataMid : public MsgReceiver { public: - explicit StoreAllDataMid(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataMid(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int)data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -47,10 +48,11 @@ class StoreAllDataMidNotifyOnCloseCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -67,7 +69,7 @@ class StoreAllDataMidNotifyOnCloseCallback : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_mid_async.cpp b/tests/test_transfer_data_mid_async.cpp index 67556b9..af30ea3 100644 --- a/tests/test_transfer_data_mid_async.cpp +++ b/tests/test_transfer_data_mid_async.cpp @@ -1,59 +1,48 @@ #undef NDEBUG #include "test_utils.h" -#include "lightweightsemaphore.h" -#include +#include +#include "concurrentqueue/concurrentqueue.h" +#include "concurrentqueue/lightweightsemaphore.h" #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; -class CondWaker : public Waker { +class CondWaker : public Notifier { public: - explicit CondWaker(const std::shared_ptr &sem) - : sem(sem), waker_ref_count(0) {} + explicit CondWaker(const std::shared_ptr &sem) : sem(sem) {} void wake() override { - if (this->waker_ref_count.load(std::memory_order_acquire) > 0) { - sem->signal(); - } - } - - void clone() override { - this->waker_ref_count.fetch_add(1, std::memory_order_acq_rel); - } - - void release() override { - this->waker_ref_count.fetch_sub(1, std::memory_order_acq_rel); + sem->signal(); } private: std::shared_ptr sem; - std::atomic_size_t waker_ref_count; }; class SendLargeDataConnCallbackAsyncMid : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + auto sem = std::make_shared(); + auto waker = std::make_shared(sem); + conn->start(std::move(rcv), waker); - std::thread t([sender]() { + std::string data; + for (int i = 0; i < 10 * 1024; i++) { + data.append("helloworld"); + } + + std::thread t([=]() { // send 1000MB data int progress = 0; size_t offset = 0; - auto sem = std::make_shared(); - auto waker = std::make_shared(sem); - - std::string data; - for (int i = 0; i < 10 * 1024; i++) { - data.append("helloworld"); - } - + std::string_view data_view(data); while (progress < 10 * 1024) { - auto sent = sender->try_send(data, offset, waker); + auto sent = sender->send_async(data_view.substr(offset)); if (sent < 0) { - sem->wait(); + sem->wait(1000); } else { offset += sent; } @@ -70,18 +59,18 @@ class SendLargeDataConnCallbackAsyncMid : public DoNothingConnCallback { class StoreAllDataAsyncMid : public MsgReceiver { public: - explicit StoreAllDataAsyncMid(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataAsyncMid(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int) data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -89,10 +78,11 @@ class StoreAllDataNotifyOnCloseCallbackAsyncMid : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -109,7 +99,7 @@ class StoreAllDataNotifyOnCloseCallbackAsyncMid : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_mid_manual_flush.cpp b/tests/test_transfer_data_mid_manual_flush.cpp index a454742..e52ea7e 100644 --- a/tests/test_transfer_data_mid_manual_flush.cpp +++ b/tests/test_transfer_data_mid_manual_flush.cpp @@ -1,16 +1,17 @@ #undef NDEBUG + #include "test_utils.h" #include #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; class SendMidManualDataConnCallback : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + conn->start(std::move(rcv)); std::thread t([sender]() { // send 1000MB data std::string data; @@ -18,7 +19,7 @@ class SendMidManualDataConnCallback : public DoNothingConnCallback { data.append("helloworld"); } for (int i = 0; i < 10 * 1024; ++i) { - sender->send(data); + sender->send_block(data); } // close connection after sender finished. }); @@ -28,18 +29,18 @@ class SendMidManualDataConnCallback : public DoNothingConnCallback { class StoreAllDataMidManual : public MsgReceiver { public: - explicit StoreAllDataMidManual(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataMidManual(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int)data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -47,10 +48,11 @@ class StoreAllDataMidManualNotifyOnCloseCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -67,7 +69,7 @@ class StoreAllDataMidManualNotifyOnCloseCallback : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_mid_nonblock.cpp b/tests/test_transfer_data_mid_nonblock.cpp new file mode 100644 index 0000000..fd62eb3 --- /dev/null +++ b/tests/test_transfer_data_mid_nonblock.cpp @@ -0,0 +1,107 @@ +#undef NDEBUG + +#include "test_utils.h" +#include +#include + +const size_t MSG_BUF_SIZE = 256 * 1024; + +class SendMidNonBlockDataConnCallback : public DoNothingConnCallback { +public: + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, std::shared_ptr sender) override { + auto rcv = std::make_unique(); + conn->start(std::move(rcv)); + std::thread t([sender]() { + // send 1000MB data + std::string data; + for (int i = 0; i < 10 * 1024; i++) { + data.append("helloworld"); + } + for (int i = 0; i < 10 * 1024; ++i) { + sender->send_nonblock(data); + } + // close connection after sender finished. + }); + t.detach(); + } +}; + +class StoreAllDataMidNonBlock : public MsgReceiver { +public: + explicit StoreAllDataMidNonBlock(size_t &buffer, int &count) : buffer(buffer), count(count) {} + + void on_message(std::string_view data) override { + if (count % 100 == 0) { + std::cout << "received " << count << " messages " + << ",size = " << buffer << std::endl; + } + buffer += data.length(); + count += 1; + } + + size_t &buffer; + int &count; +}; + +class StoreAllDataMidNonBlockNotifyOnCloseCallback : public ConnCallback { +public: + + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, std::shared_ptr send) override { + auto rcv = std::make_unique(add_data, count); + // store sender so connection is not dropped. + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); + } + + void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { + std::unique_lock lk(mutex); + has_closed.store(true); + std::cout << "on_connection_close" << std::endl; + cond.notify_all(); + } + + void on_listen_error(const std::string &addr, const std::string &err) override {} + + void on_connect_error(const std::string &addr, const std::string &err) override {} + + std::mutex mutex; + std::condition_variable cond; + std::atomic_bool has_closed{false}; + size_t add_data{0}; + int count{0}; + std::shared_ptr sender; +}; + +int test_transfer_data_mid_nonblock(int argc, char **argv) { + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + while (true) { + if (store_cb->has_closed.load()) { + auto avg_size = store_cb->add_data / store_cb->count; + std::cout << "received " << store_cb->count << " messages ," + << "total size = " << store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" + << std::endl; + assert(store_cb->add_data == 1024 * 1024 * 1000); + return 0; + } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + } + } +} diff --git a/tests/test_transfer_data_small.cpp b/tests/test_transfer_data_small.cpp index 73a4bfb..18332d9 100644 --- a/tests/test_transfer_data_small.cpp +++ b/tests/test_transfer_data_small.cpp @@ -1,16 +1,17 @@ #undef NDEBUG + #include "test_utils.h" #include #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; class SendSmallDataConnCallback : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + conn->start(std::move(rcv)); std::thread t([sender]() { // send 1000MB data std::string data; @@ -18,7 +19,7 @@ class SendSmallDataConnCallback : public DoNothingConnCallback { data.append("helloworld"); } for (int i = 0; i < 10 * 1024 * 1024; ++i) { - sender->send(data); + sender->send_block(data); } // close connection after sender finished. }); @@ -28,18 +29,18 @@ class SendSmallDataConnCallback : public DoNothingConnCallback { class StoreAllDataSmall : public MsgReceiver { public: - explicit StoreAllDataSmall(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataSmall(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int)data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -47,10 +48,11 @@ class StoreAllDataSmallNotifyOnCloseCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -67,7 +69,7 @@ class StoreAllDataSmallNotifyOnCloseCallback : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_small_async.cpp b/tests/test_transfer_data_small_async.cpp index a290c73..bdd0625 100644 --- a/tests/test_transfer_data_small_async.cpp +++ b/tests/test_transfer_data_small_async.cpp @@ -1,59 +1,48 @@ #undef NDEBUG #include "test_utils.h" -#include "lightweightsemaphore.h" -#include +#include +#include "concurrentqueue/concurrentqueue.h" +#include "concurrentqueue/lightweightsemaphore.h" #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; -class CondWaker : public Waker { +class CondWaker : public Notifier { public: - explicit CondWaker(const std::shared_ptr &sem) - : sem(sem), waker_ref_count(0) {} + explicit CondWaker(const std::shared_ptr &sem) : sem(sem) {} void wake() override { - if (this->waker_ref_count.load(std::memory_order_acquire) > 0) { - sem->signal(); - } - } - - void clone() override { - this->waker_ref_count.fetch_add(1, std::memory_order_acq_rel); - } - - void release() override { - this->waker_ref_count.fetch_sub(1, std::memory_order_acq_rel); + sem->signal(); } private: std::shared_ptr sem; - std::atomic_size_t waker_ref_count; }; class SendLargeDataConnCallbackAsyncSmall : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + auto sem = std::make_shared(); + auto waker = std::make_shared(sem); + conn->start(std::move(rcv), waker); - std::thread t([sender]() { + std::string data; + for (int i = 0; i < 10; i++) { + data.append("helloworld"); + } + + std::thread t([=]() { // send 1000MB data int progress = 0; size_t offset = 0; - auto sem = std::make_shared(); - auto waker = std::make_shared(sem); - - std::string data; - for (int i = 0; i < 10; i++) { - data.append("helloworld"); - } - + std::string_view data_view(data); while (progress < 10 * 1024 * 1024) { - auto sent = sender->try_send(data, offset, waker); + auto sent = sender->send_async(data_view.substr(offset)); if (sent < 0) { - sem->wait(); + sem->wait(1000); } else { offset += sent; } @@ -70,18 +59,18 @@ class SendLargeDataConnCallbackAsyncSmall : public DoNothingConnCallback { class StoreAllDataAsyncSmall : public MsgReceiver { public: - explicit StoreAllDataAsyncSmall(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataAsyncSmall(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int) data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -89,10 +78,11 @@ class StoreAllDataNotifyOnCloseCallbackAsyncSmall : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -109,7 +99,7 @@ class StoreAllDataNotifyOnCloseCallbackAsyncSmall : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_small_manual_flush.cpp b/tests/test_transfer_data_small_manual_flush.cpp index 381381d..07cb4c1 100644 --- a/tests/test_transfer_data_small_manual_flush.cpp +++ b/tests/test_transfer_data_small_manual_flush.cpp @@ -1,16 +1,17 @@ #undef NDEBUG + #include "test_utils.h" #include #include -const size_t MSG_BUF_SIZE = 4 * 1024 * 1024; +const size_t MSG_BUF_SIZE = 256 * 1024; class SendSmallManualDataConnCallback : public DoNothingConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { auto rcv = std::make_unique(); - auto sender = conn->start(std::move(rcv)); + conn->start(std::move(rcv)); std::thread t([sender]() { // send 1000MB data std::string data; @@ -18,7 +19,7 @@ class SendSmallManualDataConnCallback : public DoNothingConnCallback { data.append("helloworld"); } for (int i = 0; i < 10 * 1024 * 1024; ++i) { - sender->send(data); + sender->send_block(data); } // close connection after sender finished. }); @@ -28,18 +29,18 @@ class SendSmallManualDataConnCallback : public DoNothingConnCallback { class StoreAllDataSmallManual : public MsgReceiver { public: - explicit StoreAllDataSmallManual(int &buffer, int &count) : buffer(buffer), count(count) {} + explicit StoreAllDataSmallManual(size_t &buffer, int &count) : buffer(buffer), count(count) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { if (count % 100 == 0) { std::cout << "received " << count << " messages " << ",size = " << buffer << std::endl; } - buffer += (int)data->size(); + buffer += data.length(); count += 1; } - int &buffer; + size_t &buffer; int &count; }; @@ -47,10 +48,11 @@ class StoreAllDataSmallManualNotifyOnCloseCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr send) override { auto rcv = std::make_unique(add_data, count); // store sender so connection is not dropped. - sender = conn->start(std::move(rcv), MSG_BUF_SIZE); + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); } void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { @@ -67,7 +69,7 @@ class StoreAllDataSmallManualNotifyOnCloseCallback : public ConnCallback { std::mutex mutex; std::condition_variable cond; std::atomic_bool has_closed{false}; - int add_data{0}; + size_t add_data{0}; int count{0}; std::shared_ptr sender; }; diff --git a/tests/test_transfer_data_small_nonblock.cpp b/tests/test_transfer_data_small_nonblock.cpp new file mode 100644 index 0000000..b677f72 --- /dev/null +++ b/tests/test_transfer_data_small_nonblock.cpp @@ -0,0 +1,107 @@ +#undef NDEBUG + +#include "test_utils.h" +#include +#include + +const size_t MSG_BUF_SIZE = 256 * 1024; + +class SendSmallNonBlockDataConnCallback : public DoNothingConnCallback { +public: + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, std::shared_ptr sender) override { + auto rcv = std::make_unique(); + conn->start(std::move(rcv)); + std::thread t([sender]() { + // send 1000MB data + std::string data; + for (int i = 0; i < 10; i++) { + data.append("helloworld"); + } + for (int i = 0; i < 10 * 1024 * 1024; ++i) { + sender->send_nonblock(data); + } + // close connection after sender finished. + }); + t.detach(); + } +}; + +class StoreAllDataSmallNonBlock : public MsgReceiver { +public: + explicit StoreAllDataSmallNonBlock(size_t &buffer, int &count) : buffer(buffer), count(count) {} + + void on_message(std::string_view data) override { + if (count % 100 == 0) { + std::cout << "received " << count << " messages " + << ",size = " << buffer << std::endl; + } + buffer += data.length(); + count += 1; + } + + size_t &buffer; + int &count; +}; + +class StoreAllDataSmallNonBlockNotifyOnCloseCallback : public ConnCallback { +public: + + void on_connect(const std::string &local_addr, const std::string &peer_addr, + std::shared_ptr conn, std::shared_ptr send) override { + auto rcv = std::make_unique(add_data, count); + // store sender so connection is not dropped. + this->sender = send; + conn->start(std::move(rcv), nullptr, MSG_BUF_SIZE); + } + + void on_connection_close(const std::string &local_addr, const std::string &peer_addr) override { + std::unique_lock lk(mutex); + has_closed.store(true); + std::cout << "on_connection_close" << std::endl; + cond.notify_all(); + } + + void on_listen_error(const std::string &addr, const std::string &err) override {} + + void on_connect_error(const std::string &addr, const std::string &err) override {} + + std::mutex mutex; + std::condition_variable cond; + std::atomic_bool has_closed{false}; + size_t add_data{0}; + int count{0}; + std::shared_ptr sender; +}; + +int test_transfer_data_small_nonblock(int argc, char **argv) { + const std::string addr = "127.0.0.1:40013"; + + auto send_cb = std::make_shared(); + auto store_cb = std::make_shared(); + SocketManager send(send_cb); + SocketManager store(store_cb); + + send.listen_on_addr(addr); + + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + store.connect_to_addr(addr); + + // Wait for the connection to close + while (true) { + if (store_cb->has_closed.load()) { + auto avg_size = store_cb->add_data / store_cb->count; + std::cout << "received " << store_cb->count << " messages ," + << "total size = " << store_cb->add_data << " bytes, " + << "average size = " << avg_size << " bytes" + << std::endl; + assert(store_cb->add_data == 1024 * 1024 * 1000); + return 0; + } + { + std::unique_lock u_lock(store_cb->mutex); + store_cb->cond.wait_for(u_lock, std::chrono::milliseconds(10)); + } + } +} diff --git a/tests/test_utils.h b/tests/test_utils.h index c7018aa..3d11366 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -30,7 +30,7 @@ enum EventType { /// class DoNothingReceiver : public MsgReceiver { - void on_message(const std::shared_ptr &data) override {} + void on_message(std::string_view data) override {} }; class MsgStoreReceiver : public MsgReceiver { @@ -41,9 +41,9 @@ class MsgStoreReceiver : public MsgReceiver { std::vector>> &buffer) : conn_id(std::move(conn_id)), mutex(mutex), cond(cond), buffer(buffer) {} - void on_message(const std::shared_ptr &data) override { + void on_message(std::string_view data) override { std::unique_lock lock(mutex); - buffer.emplace_back(conn_id, data); + buffer.emplace_back(conn_id, std::make_shared(data)); cond.notify_all(); } @@ -62,7 +62,7 @@ class MsgStoreReceiver : public MsgReceiver { class DoNothingConnCallback : public ConnCallback { public: void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { conn->close(); } @@ -80,7 +80,7 @@ class BitFlagCallback : public ConnCallback { : mutex(mutex), cond(cond), sig(sig), buffer(buffer) {} void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { set_sig(CONNECTED); auto conn_id = local_addr + "->" + peer_addr; auto msg_storer = std::make_unique(conn_id, mutex, cond, buffer); @@ -125,12 +125,12 @@ class StoreAllEventsConnCallback : public ConnCallback { : connected_count(0), clean_sender_on_close(clean_sender_on_close) {} void on_connect(const std::string &local_addr, const std::string &peer_addr, - const std::shared_ptr &conn) override { + std::shared_ptr conn, std::shared_ptr sender) override { std::unique_lock lock(mutex); auto conn_id = local_addr + "->" + peer_addr; events.emplace_back(CONNECTED, conn_id); auto msg_storer = std::make_unique(conn_id, mutex, cond, buffer); - auto sender = conn->start(std::move(msg_storer)); + conn->start(std::move(msg_storer)); senders.emplace(conn_id, sender); connected_count.fetch_add(1, std::memory_order_seq_cst); cond.notify_all(); @@ -163,7 +163,7 @@ class StoreAllEventsConnCallback : public ConnCallback { std::unique_lock lock(mutex); try { auto sender = senders.at(conn_id); - sender->send(data); + sender->send_block(data); sender->flush(); } catch (std::out_of_range &e) { std::cout << "connection " << conn_id << " not found during send" << std::endl;