Skip to content

Commit

Permalink
Cancelling on going request when shutting down the streamer (#18)
Browse files Browse the repository at this point in the history
* stop threadpool

* stop thread pool

* stop s3 layer

* lint

* CR

* fix stop of s3 client
  • Loading branch information
noa-neria authored Nov 24, 2024
1 parent 36027bc commit bf57cc6
Show file tree
Hide file tree
Showing 32 changed files with 790 additions and 63 deletions.
1 change: 1 addition & 0 deletions cpp/common/responder/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ runai_cc_test(
srcs = ["responder_test.cc"],
deps = [":responder",
"//utils/threadpool",
"//utils/thread",
"//utils/random",
],
)
27 changes: 25 additions & 2 deletions cpp/common/responder/responder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace runai::llm::streamer::common
Responder::Responder(unsigned running) :
_running(running),
_ready(0),
_stopped(false),
_total_bytesize(0),
_start_time(std::chrono::steady_clock::now())
{
Expand All @@ -25,16 +26,22 @@ Responder::~Responder()
// return -1 if there are no running requests
Response Responder::pop()
{
if (finished())
if (_stopped || finished())
{
LOG(DEBUG) << "responder does not expect any more responses";
LOG(DEBUG) << (_stopped ? "responder does not expect any more responses" : "responder stopped");
return ResponseCode::FinishedError;
}

_ready.wait();

const auto guard = std::unique_lock<std::mutex>(_mutex);

if (_stopped)
{
LOG(DEBUG) << "responder stopped";
return ResponseCode::FinishedError;
}

ASSERT(!_responses.empty()) << "responder is empty after notification. Current running " << _running;

auto response = _responses.front();
Expand All @@ -49,6 +56,12 @@ void Responder::push(Response && response)
{
const auto guard = std::unique_lock<std::mutex>(_mutex);

if (_stopped)
{
// ignore responses
return;
}

_successful = _successful && response.ret == common::ResponseCode::Success;

if (_running)
Expand Down Expand Up @@ -91,6 +104,16 @@ void Responder::cancel()
_canceled = true;
}

void Responder::stop()
{
{
const auto guard = std::unique_lock<std::mutex>(_mutex);
_stopped = true;
}
// wake up blocking waiting threads
_ready.post();
}

size_t Responder::bytes_per_second() const
{
const auto time_ = std::chrono::steady_clock::now();
Expand Down
2 changes: 2 additions & 0 deletions cpp/common/responder/responder.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct Responder
void push(Response && response, size_t bytesize);

void cancel();
void stop();

bool finished() const;

Expand All @@ -55,6 +56,7 @@ struct Responder
mutable std::mutex _mutex;

bool _canceled = false;
std::atomic<bool> _stopped;

std::atomic<size_t> _total_bytesize;
std::chrono::time_point<std::chrono::steady_clock> _start_time;
Expand Down
51 changes: 48 additions & 3 deletions cpp/common/responder/responder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "utils/random/random.h"
#include "utils/threadpool/threadpool.h"
#include "utils/thread/thread.h"
#include "utils/semaphore/semaphore.h"

namespace runai::llm::streamer::common
Expand Down Expand Up @@ -67,7 +68,7 @@ TEST(Pop, Wait)
auto responder = Responder(size);

// create threadpool to push
auto pool = utils::ThreadPool<unsigned>([&](unsigned i)
auto pool = utils::ThreadPool<unsigned>([&](unsigned i, std::atomic<bool> &)
{
responder.push(i);
}, size);
Expand Down Expand Up @@ -111,7 +112,7 @@ TEST(Pop, Error)
auto responder = Responder(size);

// create threadpool to push
auto pool = utils::ThreadPool<int>([&](int i)
auto pool = utils::ThreadPool<int>([&](int i, std::atomic<bool> &)
{
auto r = Response(rc);
responder.push(std::move(r));
Expand Down Expand Up @@ -150,7 +151,7 @@ TEST(Pop, Unexpected_Responses)

std::atomic<unsigned> completed = 0;
auto finished = utils::Semaphore(0);
auto pool = utils::ThreadPool<unsigned>([&](unsigned i)
auto pool = utils::ThreadPool<unsigned>([&](unsigned i, std::atomic<bool> &)
{
responder.push(i);
completed++;
Expand Down Expand Up @@ -199,4 +200,48 @@ TEST(Pop, Unexpected_Responses)
}
}

TEST(Stop, Sanity)
{
auto size = utils::random::number(1, 100);
auto responder = Responder(size);

// create a thread to wait
auto waiting = utils::Thread([&]()
{
for (unsigned i = 0; i < size; ++i)
{
auto r = responder.pop();
if (r == ResponseCode::FinishedError)
{
break;
}
EXPECT_EQ(r.ret, ResponseCode::Success);
}
});

// create threadpool to push
auto pool = utils::ThreadPool<unsigned>([&](unsigned i, std::atomic<bool> &)
{
responder.push(i);
}, size);

for (unsigned i = 0; i < size; ++i)
{
unsigned value = i;
usleep(utils::random::number(100));
pool.push(std::move(value));
}

// stop the responder
usleep(utils::random::number(100 * 1000));
responder.stop();

auto times = utils::random::number(1, 10);
for (unsigned i = 0; i < times; ++i)
{
auto r = responder.pop();
EXPECT_EQ(r.ret, ResponseCode::FinishedError);
}
}

}; // namespace runai::llm::streamer::common
13 changes: 13 additions & 0 deletions cpp/common/s3_wrapper/s3_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ void S3ClientWrapper::shutdown()
}
}

void S3ClientWrapper::stop()
{
try
{
std::shared_ptr<utils::Dylib> s3_dylib(open_s3());
static auto __stop_s3_clients = s3_dylib->dlsym<void(*)()>("runai_stop_s3_clients");
__stop_s3_clients();
}
catch(...)
{
}
}

S3ClientWrapper::S3ClientWrapper(const StorageUri & uri) :
_s3_dylib(open_s3()),
_s3_client(create_client(uri))
Expand Down
5 changes: 5 additions & 0 deletions cpp/common/s3_wrapper/s3_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ struct S3ClientWrapper
ResponseCode async_read(std::vector<Range>& ranges, size_t chunk_bytesize, char * buffer);
Response async_read_response();

// stop - stops the responder of each S3 client, in order to notify callers which sent a request and are waiting for a response
// required for stopping the threadpool workers, which are bloking on the client responder
static void stop();

// destroy S3 all clients
static void shutdown();

static constexpr size_t min_chunk_bytesize = 5 * 1024 * 1024;
Expand Down
17 changes: 14 additions & 3 deletions cpp/s3/client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace runai::llm::streamer::impl::s3
{

S3Client::S3Client(const common::s3::StorageUri & uri) :
_stop(false),
_bucket_name(uri.bucket.c_str(), uri.bucket.size()),
_path(uri.path.c_str(), uri.path.size())
{
Expand Down Expand Up @@ -84,13 +85,14 @@ common::ResponseCode S3Client::async_read(unsigned num_ranges, common::Range * r
}

_responder = std::make_shared<common::Responder>(num_ranges);

Aws::S3Crt::Model::GetObjectRequest request;
request.SetBucket(_bucket_name);
request.SetKey(_path);

char * buffer_ = buffer;
common::Range * ranges_ = ranges;
for (unsigned ir = 0; ir < num_ranges; ++ir)
for (unsigned ir = 0; ir < num_ranges && !_stop; ++ir)
{
const auto & range_ = *ranges_;

Expand All @@ -106,7 +108,7 @@ common::ResponseCode S3Client::async_read(unsigned num_ranges, common::Range * r

size_t total_ = range_.size;
size_t offset_ = range_.start;
for (unsigned i = 0; i < size; ++i)
for (unsigned i = 0; i < size && !_stop; ++i)
{
size_t bytesize_ = (i == size - 1 ? total_ : chunk_bytesize);

Expand Down Expand Up @@ -164,7 +166,7 @@ common::ResponseCode S3Client::async_read(unsigned num_ranges, common::Range * r
ranges_++;
}

return common::ResponseCode::Success;
return _stop ? common::ResponseCode::FinishedError : common::ResponseCode::Success;
}

std::string S3Client::bucket() const
Expand All @@ -177,4 +179,13 @@ void S3Client::path(const std::string & path)
_path = Aws::String(path.c_str(), path.size());
}

void S3Client::stop()
{
_stop = true;
if (_responder != nullptr)
{
_responder->stop();
}
}

}; // namespace runai::llm::streamer::impl::s3
7 changes: 7 additions & 0 deletions cpp/s3/client/client.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <atomic>
#include <memory>
#include <string>

Expand All @@ -23,11 +24,17 @@ struct S3Client

common::Response async_read_response();

// Stop sending requests to the object store
// Requests that were already sent cannot be cancelled, since the Aws S3CrtClient does not support aborting requests
// The S3CrtClient d'tor will wait for response of all teh sent requests, which can take a while
void stop();

std::string bucket() const;

void path(const std::string & path);

private:
std::atomic<bool> _stop;
ClientConfiguration _client_config;
std::unique_ptr<Aws::S3Crt::S3CrtClient> _client;
const Aws::String _bucket_name;
Expand Down
17 changes: 17 additions & 0 deletions cpp/s3/client_mgr/client_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ struct ClientMgr

static void clear();

// stop all clients
static void stop();

// for testing:
static unsigned size();
static unsigned unused();
Expand Down Expand Up @@ -172,6 +175,20 @@ void ClientMgr<T>::clear()
mgr._current_bucket.clear();
}

template <typename T>
void ClientMgr<T>::stop()
{
LOG(DEBUG) << "Stopping all S3 clients";
auto & mgr = get();

const auto guard = std::unique_lock<std::mutex>(mgr._mutex);

for (auto & pair : mgr._clients)
{
pair.second->stop();
}
}

using S3ClientMgr = ClientMgr<S3Client>;

}; //namespace runai::llm::streamer::impl::s3
12 changes: 12 additions & 0 deletions cpp/s3/s3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ void runai_release_s3_clients()
}
}

void runai_stop_s3_clients()
{
try
{
S3ClientMgr::stop();
}
catch(const std::exception & e)
{
LOG(ERROR) << "Failed to stop all S3 clients";
}
}

common::ResponseCode runai_async_read_s3_client(void * client, unsigned num_ranges, common::Range * ranges, size_t chunk_bytesize, char * buffer)
{
try
Expand Down
3 changes: 3 additions & 0 deletions cpp/s3/s3.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ extern "C" void runai_remove_s3_client(void * client);
extern "C" common::ResponseCode runai_async_read_s3_client(void * client, unsigned num_ranges, common::Range * ranges, size_t chunk_bytesize, char * buffer);
// wait for asynchronous read response
extern "C" common::ResponseCode runai_async_response_s3_client(void * client, unsigned * index /* output parameter */);
// stop clients
// Stops the responder of each client, in order to notify callers which sent a request and are waiting for a response
extern "C" void runai_stop_s3_clients();
// release clients
extern "C" void runai_release_s3_clients();

Expand Down
1 change: 1 addition & 0 deletions cpp/s3/s3.ldscript
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
runai_remove_s3_client;
runai_async_read_s3_client;
runai_async_response_s3_client;
runai_stop_s3_clients;
runai_release_s3_clients;
local: *;
};
Loading

0 comments on commit bf57cc6

Please sign in to comment.