From 3dcaf9fbae1db427f4cfea573d898f2202197067 Mon Sep 17 00:00:00 2001 From: Artur Mukhamadiev Date: Tue, 21 Apr 2026 15:29:33 +0300 Subject: [PATCH] [fix] resolved several issues found by AI --- include/cloud_point_rpc/rpc_server.hpp | 12 ++-- include/cloud_point_rpc/tcp_server.hpp | 76 ++++++++++++++++---------- src/rpc_coder.cpp | 24 ++++---- src/rpc_server.cpp | 64 ++++++++++++++++++---- src/server_api.cpp | 71 ++++++++++++++++-------- src/test_api.cpp | 20 ++++++- tests/test_c_api.cpp | 59 +++++++++++--------- 7 files changed, 221 insertions(+), 105 deletions(-) diff --git a/include/cloud_point_rpc/rpc_server.hpp b/include/cloud_point_rpc/rpc_server.hpp index d5dc23b..542897c 100644 --- a/include/cloud_point_rpc/rpc_server.hpp +++ b/include/cloud_point_rpc/rpc_server.hpp @@ -1,15 +1,16 @@ #pragma once +#include "export.h" #include #include #include #include #include -#include "export.h" +#include extern "C" { struct rpc_string { - rpc_string(const char* data, uint64_t size) : s(data,size) {} + rpc_string(const char *data, uint64_t size) : s(data, size) {} rpc_string() = default; std::string s; @@ -20,8 +21,9 @@ namespace score { class CRPC_EXPORT RpcServer { public: - using Handler = std::function; - using callback_t = rpc_string*(*)(rpc_string*); + using Handler = std::function( + const nlohmann::json &)>; + using callback_t = rpc_string *(*)(rpc_string *); void register_method(const std::string &name, Handler handler); void register_method(const std::string &name, callback_t handler); @@ -32,4 +34,4 @@ class CRPC_EXPORT RpcServer { std::map handlers_; }; -} // namespace cloud_point_rpc +} // namespace score diff --git a/include/cloud_point_rpc/tcp_server.hpp b/include/cloud_point_rpc/tcp_server.hpp index 0cf73fd..1a227fb 100644 --- a/include/cloud_point_rpc/tcp_server.hpp +++ b/include/cloud_point_rpc/tcp_server.hpp @@ -1,15 +1,15 @@ #pragma once +#include "export.h" #include #include #include #include #include -#include -#include -#include "export.h" #include #include +#include +#include namespace score { class CRPC_EXPORT TcpServer { @@ -22,7 +22,8 @@ class CRPC_EXPORT TcpServer { ~TcpServer() { stop(); - for (auto &thread : client_threads | std::views::keys) { + std::lock_guard lock(cliThrMtx_); + for (auto &thread : client_threads_ | std::views::keys) { thread.join(); } } @@ -43,27 +44,38 @@ class CRPC_EXPORT TcpServer { accept_thread_ = std::jthread([this]() { LOG(INFO) << "Accept thread started"; while (running_) { - std::ranges::remove_if(client_threads.begin(), client_threads.end(), [](auto& client_info) { - bool result = false; - if (client_info.second.wait_for(0ms) == std::future_status::ready) { - client_info.first.join(); - result = true; - } - return result; - }); + { + std::lock_guard lock(cliThrMtx_); + client_threads_.remove_if([](auto &client_info) { + bool result = false; + if (client_info.second.wait_for(0ms) == + std::future_status::ready) { + client_info.first.join(); + result = true; + } + return result; + }); + } try { auto socket = std::make_shared( io_context_); - acceptor_.accept(*socket); - + { + std::lock_guard lock(acceptorMtx_); + acceptor_.accept(*socket); + } LOG(INFO) << "New connection from " << socket->remote_endpoint().address().to_string(); auto done = std::make_shared>(); - client_threads.push_back(std::make_pair(std::jthread([this, socket, done]() { - handle_client(socket); - done->set_value(true); - }),done->get_future())); + { + std::lock_guard lock(cliThrMtx_); + client_threads_.push_back(std::make_pair( + std::jthread([this, socket, done]() { + handle_client(socket); + done->set_value(true); + }), + done->get_future())); + } } catch (const std::system_error &e) { LOG(INFO) << "Accept exception: " << e.what(); if (running_) { @@ -84,15 +96,9 @@ class CRPC_EXPORT TcpServer { return; LOG(INFO) << "Stopping server..."; running_ = false; - // Closing acceptor unblocks accept() call usually, but sometimes we - // need to prod it - asio::error_code ec; - std::ignore = acceptor_.close(ec); - if (ec.value()) { - LOG(ERROR) << std::format( - "acceptor closed with a value returned = {}", ec.value()); - } - // Ensure accept unblocks by connecting a dummy socket + // Ensure accept unblocks by connecting a dummy socket FIRST, + // while the acceptor is still open. This avoids a race where close() + // removes the listen endpoint before the connect completes. try { asio::ip::tcp::endpoint endpoint(asio::ip::make_address(ip_), port_); @@ -102,6 +108,16 @@ class CRPC_EXPORT TcpServer { } catch (...) { // Ignore } + // Now close the acceptor to unblock any pending accept() + asio::error_code ec; + { + std::lock_guard lock(acceptorMtx_); + std::ignore = acceptor_.close(ec); + } + if (ec.value()) { + LOG(ERROR) << std::format( + "acceptor closed with a value returned = {}", ec.value()); + } LOG(INFO) << "Acceptor closed"; } @@ -140,8 +156,10 @@ class CRPC_EXPORT TcpServer { asio::ip::tcp::acceptor acceptor_; std::atomic running_; - std::list>> client_threads; + std::list>> client_threads_; + std::mutex cliThrMtx_; + std::mutex acceptorMtx_; std::jthread accept_thread_; }; -} // namespace cloud_point_rpc +} // namespace score diff --git a/src/rpc_coder.cpp b/src/rpc_coder.cpp index c1f8ece..8e132d2 100644 --- a/src/rpc_coder.cpp +++ b/src/rpc_coder.cpp @@ -7,7 +7,7 @@ #include namespace score { -Base64RPCCoder::Base64RPCCoder() = default; +Base64RPCCoder::Base64RPCCoder() = default; Base64RPCCoder::~Base64RPCCoder() = default; /** @@ -15,13 +15,16 @@ Base64RPCCoder::~Base64RPCCoder() = default; * @param encoded ASCII complained base64 encoded string * @return vector of raw bytes << allocated on encoded.size() / 4 * 3 + 1 size */ -std::vector Base64RPCCoder::decode(const std::string& encoded) { +std::vector Base64RPCCoder::decode(const std::string &encoded) { + if (encoded.length() > (std::numeric_limits::max() / 3) * 4) + throw std::length_error("Base64 input too large"); DLOG(INFO) << "Base64RPCCoder::decode"; std::vector result((encoded.length() >> 2) * 3 + 1); size_t result_len = 0; - base64_decode(encoded.data(), encoded.size(), - result.data(), &result_len, 0); + base64_decode(encoded.data(), encoded.size(), result.data(), &result_len, + 0); DLOG(INFO) << "result_len: " << result_len; + result.resize(result_len); return result; } /** @@ -29,14 +32,15 @@ std::vector Base64RPCCoder::decode(const std::string& encoded) { * @param data raw byte stream * @return encoded base64 string */ -std::string Base64RPCCoder::encode(const std::vector& data) { +std::string Base64RPCCoder::encode(const std::vector &data) { + if (data.size() > (std::numeric_limits::max() / 4) * 3) + throw std::length_error("raw input is too large"); DLOG(INFO) << "Base64RPCCoder::encode"; size_t result_len = 0; - std::string result(data.size() / 3 * 4 + 1, 0); - base64_encode(data.data(), data.size(), - result.data(), &result_len, 0 - ); + std::string result((data.size() + 2) / 3 * 4, 0); + base64_encode(data.data(), data.size(), result.data(), &result_len, 0); DLOG(INFO) << "result_len: " << result_len; + result.resize(result_len); return result; } -} +} // namespace score diff --git a/src/rpc_server.cpp b/src/rpc_server.cpp index 08d3012..4aa9b39 100644 --- a/src/rpc_server.cpp +++ b/src/rpc_server.cpp @@ -1,5 +1,7 @@ #include "cloud_point_rpc/rpc_server.hpp" +#include "server_api.h" #include +#include using json = nlohmann::json; namespace score { @@ -12,21 +14,58 @@ json create_error(int code, const std::string &message, {"id", id}}; } -json create_success(const json &result, const json &id) { - return {{"jsonrpc", "2.0"}, {"result", result}, {"id", id}}; -} +struct CreateSuccess { + json obj; + json id; + + void operator()(const json &result) { + obj = {{"jsonrpc", "2.0"}, {"result", result}, {"id", id}}; + } + void operator()(const std::string &result) { + obj = {{"jsonrpc", "2.0"}, {"result", result}, {"id", id}}; + } +}; + } // namespace +template struct Deleter { + void operator()(T *element) { + (void)element; + LOG(ERROR) << "Called default deleter"; + } +}; + +template <> struct Deleter { + void operator()(rpc_string *element) { + if (element) { + crpc_str_destroy(element); + } + } +}; +using rpcStringPtr = std::unique_ptr>; + void RpcServer::register_method(const std::string &name, Handler handler) { handlers_[name] = std::move(handler); } -void RpcServer::register_method(const std::string& name, callback_t handler) { - handlers_[name] = [handler](const nlohmann::json& j) -> nlohmann::json { +void RpcServer::register_method(const std::string &name, callback_t handler) { + handlers_[name] = [handler](const nlohmann::json &j) + -> std::variant { rpc_string tmp; - tmp.s = j.dump(); - rpc_string* res = handler(&tmp); - return {res->s}; + tmp.s = j.dump(); + auto res = rpcStringPtr(handler(&tmp)); + if (!res) { + LOG(ERROR) << "Method is invalid"; + return {}; + } + std::variant ret; + try { + ret = json::parse(res->s); + } catch (std::exception &e) { + DLOG(INFO) << "return value is not a json"; + ret = res->s; + } + return ret; }; } @@ -59,11 +98,14 @@ std::string RpcServer::process(const std::string &request_str) { } try { - json result = it->second(params); - return create_success(result, id).dump(); + auto result = it->second(params); + CreateSuccess visitor; + visitor.id = id; + std::visit(visitor, result); + return visitor.obj.dump(); } catch (const std::exception &e) { return create_error(-32000, e.what(), id).dump(); // Server error } } -} // namespace cloud_point_rpc +} // namespace score diff --git a/src/server_api.cpp b/src/server_api.cpp index 84a7d56..ff43d52 100644 --- a/src/server_api.cpp +++ b/src/server_api.cpp @@ -1,53 +1,75 @@ +#include "server_api.h" #include "cloud_point_rpc/config.hpp" #include "cloud_point_rpc/rpc_server.hpp" #include "cloud_point_rpc/tcp_server.hpp" -#include -#include "server_api.h" #include +#include +#include #include #include -#include -static std::list> gc; +static std::list> gc; +std::mutex gc_mtx; +std::mutex server_mtx; score::RpcServer rpc_server; std::unique_ptr server = nullptr; extern "C" { -const char* crpc_str_get_data(const rpc_string* that) { +const char *crpc_str_get_data(const rpc_string *that) { + if (!that) { + LOG(ERROR) << "Tried to get data on nullptr"; + return nullptr; + } return that->s.c_str(); } -uint64_t crpc_str_get_size(const rpc_string* that){ +uint64_t crpc_str_get_size(const rpc_string *that) { + if (!that) { + LOG(ERROR) << "Tried to get size on nullptr"; + return 0; + } return that->s.size(); } -rpc_string* crpc_str_create(const char* data, uint64_t size){ +rpc_string *crpc_str_create(const char *data, uint64_t size) { + if (!data) { + LOG(ERROR) << "Tried to create with nullptr data"; + return nullptr; + } + std::lock_guard lock(gc_mtx); gc.push_back(std::make_unique(data, size)); return gc.back().get(); } -void crpc_str_destroy(rpc_string* that){ +void crpc_str_destroy(rpc_string *that) { + if (!that) { + LOG(ERROR) << "Tried to destroy on nullptr"; + return; + } + std::lock_guard lock(gc_mtx); auto it = std::ranges::find(gc, that, &std::unique_ptr::get); - if(it != gc.end()) + if (it != gc.end()) gc.erase(it); } - -void crpc_init(const char* config_path) { - google::InitGoogleLogging("CloudPointRPC"); - if(config_path == nullptr) { - LOG(INFO) << "config_path was not provided"; +void crpc_init(const char *config_path) { + if (!google::IsGoogleLoggingInitialized()) + google::InitGoogleLogging("CloudPointRPC"); + if (config_path == nullptr) { + LOG(ERROR) << "config_path was not provided"; + return; } try { auto config = score::ConfigLoader::load(config_path); LOG(INFO) << "Loaded config from " << config_path; - server = std::make_unique(config.server.ip, config.server.port, - [&](const std::string &request) { - return rpc_server.process( - request); - }); + server = std::make_unique( + config.server.ip, config.server.port, + [&](const std::string &request) { + std::lock_guard lock(server_mtx); + return rpc_server.process(request); + }); server->start(); } catch (const std::exception &e) { LOG(ERROR) << "Fatal error: " << e.what(); @@ -55,14 +77,17 @@ void crpc_init(const char* config_path) { } void crpc_deinit() { - if(server) - server->join(); server.reset(); + std::lock_guard lock(gc_mtx); gc.clear(); } -void crpc_add_method(callback_t cb, rpc_string* name) { +void crpc_add_method(callback_t cb, rpc_string *name) { + if (!name || !cb) { + LOG(ERROR) << "Invalid arguments (nullptr)"; + return; + } + std::lock_guard lock(server_mtx); rpc_server.register_method(name->s, cb); } } - diff --git a/src/test_api.cpp b/src/test_api.cpp index dc8cd98..08163d0 100644 --- a/src/test_api.cpp +++ b/src/test_api.cpp @@ -65,6 +65,10 @@ class TestThread { } } void add_method(const callback_t cb, rpc_string *name) { + if(!name || !name->s.size()) { + LOG(ERROR) << "Tried to add method with invalid name"; + return; + } LOG(INFO) << "Trying to add method: " << name->s; std::lock_guard lock(mtx); if (methods.contains(name->s)) { @@ -76,6 +80,10 @@ class TestThread { } int remove_method(const rpc_string *name) { + if(!name || !name->s.size()) { + LOG(ERROR) << "Tried to remove method with invalid name"; + return -1; + } LOG(INFO) << "Trying to remove method: " << name->s; std::lock_guard lock(mtx); int result = 0; @@ -90,6 +98,10 @@ class TestThread { } void call(const rpc_string *name) { + if (!name) { + LOG(ERROR) << "Called with nullptr name"; + return; + } std::lock_guard lock(mtx); LOG(INFO) << server.process(name->s); } @@ -166,7 +178,13 @@ int crpc_test_remove_method(rpc_string *name) { return test.remove_method(name); } -void crpc_test_schedule_call(rpc_string *name) { test.add_queue_call(name->s); } +void crpc_test_schedule_call(rpc_string *name) { + if (!name) { + LOG(ERROR) << "Called with name nullptr"; + return; + } + test.add_queue_call(name->s); +} void crpc_test_auto_call(uint32_t state) { test.auto_call(static_cast(state)); diff --git a/tests/test_c_api.cpp b/tests/test_c_api.cpp index 106d59a..1b97d98 100644 --- a/tests/test_c_api.cpp +++ b/tests/test_c_api.cpp @@ -35,7 +35,7 @@ TEST_F(TestCApi, Base) { task.set_value(installed); } DLOG(INFO) << "Go out"; - return crpc_str_create("res", sizeof("res")); + return crpc_str_create("res", sizeof("res") - 1); }, &name); @@ -53,24 +53,26 @@ TEST_F(TestCApi, AddedMultiple) { std::array, rpc_string>, N> called; // The Bridge: A static pointer local to this test function - static std::array, N>* bridge; + static std::array, N> *bridge; bridge = &tasks; for (int i = 0; i < N; i++) { - called[i].first = tasks[i].get_future(); - std::string n = "test" + std::to_string(i); + called[i].first = tasks[i].get_future(); + std::string n = "test" + std::to_string(i); called[i].second = rpc_string{n.c_str(), n.size()}; } auto register_idx = [&]() { - crpc_test_add_method(+[](rpc_string*) -> rpc_string* { - static bool installed = false; - if (!installed) { - installed = true; - (*bridge)[I].set_value(true); - } - return crpc_str_create("res", sizeof("res")); - }, &called[I].second); + crpc_test_add_method( + +[](rpc_string *) -> rpc_string * { + static bool installed = false; + if (!installed) { + installed = true; + (*bridge)[I].set_value(true); + } + return crpc_str_create("res", sizeof("res")); + }, + &called[I].second); }; register_idx.template operator()<0>(); @@ -111,24 +113,27 @@ TEST_F(TestCApi, ScheduleCall) { std::array, rpc_string>, N> called; // The Bridge: A static pointer local to this test function - static std::array, N>* bridge; + static std::array, N> *bridge; bridge = &tasks; LOG(INFO) << "Started Schedule Call"; for (int i = 0; i < N; i++) { - called[i].first = tasks[i].get_future(); - std::string n = "test" + std::to_string(i); + called[i].first = tasks[i].get_future(); + std::string n = "test" + std::to_string(i); called[i].second = rpc_string{n.c_str(), n.size()}; } auto register_idx = [&]() { - crpc_test_add_method(+[](rpc_string*) -> rpc_string* { - static bool installed = false; - if (!installed) { - installed = true; - (*bridge)[I].set_value(true); - } - return crpc_str_create("res", sizeof("res")); - }, &called[I].second); + crpc_test_add_method( + +[](rpc_string *) -> rpc_string * { + static bool installed = false; + if (!installed) { + installed = true; + (*bridge)[I].set_value(true); + } + std::string_view res = "res"; + return crpc_str_create(res.data(), res.size()); + }, + &called[I].second); }; auto test_idx = [&]() { using namespace std::chrono_literals; @@ -163,7 +168,9 @@ TEST_F(TestCApi, String) { name.s = "test"; EXPECT_EQ(name.s.c_str(), crpc_str_get_data(&name)); EXPECT_EQ(name.s.size(), crpc_str_get_size(&name)); - - auto creation = crpc_str_create("test 2222", sizeof("test 2222")); + std::string_view testString = "test 2222"; + auto creation = + crpc_str_create(testString.data(), testString.size()); + EXPECT_EQ(std::string_view(crpc_str_get_data(creation)), testString); EXPECT_NO_THROW(crpc_str_destroy(creation)); -} \ No newline at end of file +}