From bb6c7d416925942158425d0b9579254af94f27d4 Mon Sep 17 00:00:00 2001 From: Tristan Ross Date: Thu, 13 Mar 2025 21:51:41 -0700 Subject: [PATCH] Fix S3 not interruptting --- src/libstore/s3-binary-cache-store.cc | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/libstore/s3-binary-cache-store.cc b/src/libstore/s3-binary-cache-store.cc index cfa713b00..0cf61d85d 100644 --- a/src/libstore/s3-binary-cache-store.cc +++ b/src/libstore/s3-binary-cache-store.cc @@ -17,8 +17,12 @@ #include #include #include +#include +#include +#include #include #include +#include #include #include #include @@ -56,6 +60,49 @@ R && checkAws(std::string_view s, Aws::Utils::Outcome && outcome) return outcome.GetResultWithOwnership(); } +class AwsHttpClient : public Aws::Http::HttpClient +{ +public: + AwsHttpClient(const Aws::Client::ClientConfiguration& clientConfig) : HttpClient() {} + + std::shared_ptr MakeRequest(const std::shared_ptr& request, + Aws::Utils::RateLimits::RateLimiterInterface* readLimiter = nullptr, + Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter = nullptr) const override { + Aws::Http::URI uri = request->GetUri(); + Aws::String url = uri.GetURIString(); + + debug("Making request %s", url); + + std::shared_ptr response = std::make_shared(request); + + if (writeLimiter != nullptr) { + writeLimiter->ApplyAndPayForCost(request->GetSize()); + } + + FileTransferRequest ftr = FileTransferRequest(url); + getFileTransfer()->download(ftr); + return response; + } +}; + +class AwsHttpClientFactory : public Aws::Http::HttpClientFactory +{ +public: + std::shared_ptr CreateHttpClient(const Aws::Client::ClientConfiguration& clientConfiguration) const override { + return std::make_shared(clientConfiguration); + } + + std::shared_ptr CreateHttpRequest(const Aws::String& uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory& streamFactory) const override { + return CreateHttpRequest(Aws::Http::URI(uri), method, streamFactory); + } + + std::shared_ptr CreateHttpRequest(const Aws::Http::URI& uri, Aws::Http::HttpMethod method, const Aws::IOStreamFactory& streamFactory) const override { + auto request = std::make_shared(uri, method); + request->SetResponseStreamFactory(streamFactory); + return request; + } +}; + class AwsLogger : public Aws::Utils::Logging::FormattedLogSystem { using Aws::Utils::Logging::FormattedLogSystem::FormattedLogSystem; @@ -80,6 +127,10 @@ static void initAWS() shared.cc), so don't let aws-sdk-cpp override it. */ options.cryptoOptions.initAndCleanupOpenSSL = false; + options.httpOptions.httpClientFactory_create_fn = []() { + return std::make_shared(); + }; + if (verbosity >= lvlDebug) { options.loggingOptions.logLevel = verbosity == lvlDebug