diff --git a/src/libutil/compression.cc b/src/libutil/compression.cc index d00303dce..d17401f27 100644 --- a/src/libutil/compression.cc +++ b/src/libutil/compression.cc @@ -38,8 +38,10 @@ struct ArchiveDecompressionSource : Source { std::unique_ptr archive = 0; Source & src; - ArchiveDecompressionSource(Source & src) + std::optional compressionMethod; + ArchiveDecompressionSource(Source & src, std::optional compressionMethod = std::nullopt) : src(src) + , compressionMethod(std::move(compressionMethod)) { } ~ArchiveDecompressionSource() override {} @@ -47,7 +49,7 @@ struct ArchiveDecompressionSource : Source { struct archive_entry * ae; if (!archive) { - archive = std::make_unique(src, true); + archive = std::make_unique(src, /*raw*/ true, compressionMethod); this->archive->check(archive_read_next_header(this->archive->archive, &ae), "failed to read header (%s)"); if (archive_filter_count(this->archive->archive) < 2) { throw CompressionError("input compression not recognized"); @@ -218,8 +220,8 @@ std::unique_ptr makeDecompressionSink(const std::string & method, Si else if (method == "br") return std::make_unique(nextSink); else - return sourceToSink([&](Source & source) { - auto decompressionSource = std::make_unique(source); + return sourceToSink([method, &nextSink](Source & source) { + auto decompressionSource = std::make_unique(source, method); decompressionSource->drainInto(nextSink); }); } diff --git a/src/libutil/tarfile.cc b/src/libutil/tarfile.cc index 4a11195c4..6bb2bd2f3 100644 --- a/src/libutil/tarfile.cc +++ b/src/libutil/tarfile.cc @@ -1,18 +1,21 @@ #include #include +#include "finally.hh" #include "serialise.hh" #include "tarfile.hh" #include "file-system.hh" namespace nix { -static int callback_open(struct archive *, void * self) +namespace { + +int callback_open(struct archive *, void * self) { return ARCHIVE_OK; } -static ssize_t callback_read(struct archive * archive, void * _self, const void ** buffer) +ssize_t callback_read(struct archive * archive, void * _self, const void ** buffer) { auto self = (TarArchive *) _self; *buffer = self->buffer.data(); @@ -27,33 +30,61 @@ static ssize_t callback_read(struct archive * archive, void * _self, const void } } -static int callback_close(struct archive *, void * self) +int callback_close(struct archive *, void * self) { return ARCHIVE_OK; } -void TarArchive::check(int err, const std::string & reason) +void checkLibArchive(archive * archive, int err, const std::string & reason) { if (err == ARCHIVE_EOF) throw EndOfFile("reached end of archive"); else if (err != ARCHIVE_OK) - throw Error(reason, archive_error_string(this->archive)); + throw Error(reason, archive_error_string(archive)); } -TarArchive::TarArchive(Source & source, bool raw) - : buffer(65536) +constexpr auto defaultBufferSize = std::size_t{65536}; +} + +void TarArchive::check(int err, const std::string & reason) { - this->archive = archive_read_new(); - this->source = &source; + checkLibArchive(archive, err, reason); +} + +/// @brief Get filter_code from its name. +/// +/// libarchive does not provide a convenience function like archive_write_add_filter_by_name but for reading. +/// Instead it's necessary to use this kludge to convert method -> code and +/// then use archive_read_support_filter_by_code. Arguably this is better than +/// hand-rolling the equivalent function that is better implemented in libarchive. +int getArchiveFilterCodeByName(const std::string & method) +{ + auto * ar = archive_write_new(); + auto cleanup = Finally{[&ar]() { checkLibArchive(ar, archive_write_close(ar), "failed to close archive: %s"); }}; + auto err = archive_write_add_filter_by_name(ar, method.c_str()); + checkLibArchive(ar, err, "failed to get libarchive filter by name: %s"); + auto code = archive_filter_code(ar, 0); + return code; +} + +TarArchive::TarArchive(Source & source, bool raw, std::optional compression_method) + : archive{archive_read_new()} + , source{&source} + , buffer(defaultBufferSize) +{ + if (!compression_method) { + archive_read_support_filter_all(archive); + } else { + archive_read_support_filter_by_code(archive, getArchiveFilterCodeByName(*compression_method)); + } if (!raw) { - archive_read_support_filter_all(archive); archive_read_support_format_all(archive); } else { - archive_read_support_filter_all(archive); archive_read_support_format_raw(archive); archive_read_support_format_empty(archive); } + archive_read_set_option(archive, NULL, "mac-ext", NULL); check( archive_read_open(archive, (void *) this, callback_open, callback_read, callback_close), @@ -61,9 +92,9 @@ TarArchive::TarArchive(Source & source, bool raw) } TarArchive::TarArchive(const Path & path) + : archive{archive_read_new()} + , buffer(defaultBufferSize) { - this->archive = archive_read_new(); - archive_read_support_filter_all(archive); archive_read_support_format_all(archive); archive_read_set_option(archive, NULL, "mac-ext", NULL); diff --git a/src/libutil/tarfile.hh b/src/libutil/tarfile.hh index 200bb8f8f..705d211e4 100644 --- a/src/libutil/tarfile.hh +++ b/src/libutil/tarfile.hh @@ -15,18 +15,28 @@ struct TarArchive void check(int err, const std::string & reason = "failed to extract archive (%s)"); - TarArchive(Source & source, bool raw = false); + explicit TarArchive(const Path & path); - TarArchive(const Path & path); + /// @brief Create a generic archive from source. + /// @param source - Input byte stream. + /// @param raw - Whether to enable raw file support. For more info look in docs: + /// https://manpages.debian.org/stretch/libarchive-dev/archive_read_format.3.en.html + /// @param compression_method - Primary compression method to use. std::nullopt means 'all'. + TarArchive(Source & source, bool raw = false, std::optional compression_method = std::nullopt); - /// disable copy constructor + /// Disable copy constructor. Explicitly default move assignment/constructor. TarArchive(const TarArchive &) = delete; + TarArchive & operator=(const TarArchive &) = delete; + TarArchive(TarArchive &&) = default; + TarArchive & operator=(TarArchive &&) = default; void close(); ~TarArchive(); }; +int getArchiveFilterCodeByName(const std::string & method); + void unpackTarfile(Source & source, const Path & destDir); void unpackTarfile(const Path & tarFile, const Path & destDir);