diff --git a/src/libstore/daemon.cc b/src/libstore/daemon.cc index 47d6d5541..2add9ae12 100644 --- a/src/libstore/daemon.cc +++ b/src/libstore/daemon.cc @@ -1,6 +1,7 @@ #include "daemon.hh" #include "signals.hh" #include "worker-protocol.hh" +#include "worker-protocol-connection.hh" #include "worker-protocol-impl.hh" #include "build-result.hh" #include "store-api.hh" @@ -1026,11 +1027,9 @@ void processConnection( #endif /* Exchange the greeting. */ - unsigned int magic = readInt(from); - if (magic != WORKER_MAGIC_1) throw Error("protocol mismatch"); - to << WORKER_MAGIC_2 << PROTOCOL_VERSION; - to.flush(); - WorkerProto::Version clientVersion = readInt(from); + WorkerProto::Version clientVersion = + WorkerProto::BasicServerConnection::handshake( + to, from, PROTOCOL_VERSION); if (clientVersion < 0x10a) throw Error("the Nix client version is too old"); @@ -1048,29 +1047,20 @@ void processConnection( printMsgUsing(prevLogger, lvlDebug, "%d operations", opCount); }); - if (GET_PROTOCOL_MINOR(clientVersion) >= 14 && readInt(from)) { - // Obsolete CPU affinity. - readInt(from); - } + WorkerProto::BasicServerConnection conn { + .to = to, + .from = from, + .clientVersion = clientVersion, + }; - if (GET_PROTOCOL_MINOR(clientVersion) >= 11) - readInt(from); // obsolete reserveSpace - - if (GET_PROTOCOL_MINOR(clientVersion) >= 33) - to << nixVersion; - - if (GET_PROTOCOL_MINOR(clientVersion) >= 35) { + conn.postHandshake(*store, { + .daemonNixVersion = nixVersion, // We and the underlying store both need to trust the client for // it to be trusted. - auto temp = trusted + .remoteTrustsUs = trusted ? store->isTrustedClient() - : std::optional { NotTrusted }; - WorkerProto::WriteConn wconn { - .to = to, - .version = clientVersion, - }; - WorkerProto::write(*store, wconn, temp); - } + : std::optional { NotTrusted }, + }); /* Send startup error messages to the client. */ tunnelLogger->startWork(); diff --git a/src/libstore/legacy-ssh-store.cc b/src/libstore/legacy-ssh-store.cc index c75d50ade..146049922 100644 --- a/src/libstore/legacy-ssh-store.cc +++ b/src/libstore/legacy-ssh-store.cc @@ -4,6 +4,7 @@ #include "pool.hh" #include "remote-store.hh" #include "serve-protocol.hh" +#include "serve-protocol-connection.hh" #include "serve-protocol-impl.hh" #include "build-result.hh" #include "store-api.hh" @@ -86,7 +87,7 @@ ref LegacySSHStore::openConnection() conn->sshConn->in.close(); { NullSink nullSink; - conn->from.drainInto(nullSink); + tee.drainInto(nullSink); } throw Error("'nix-store --serve' protocol mismatch from '%s', got '%s'", host, chomp(saved.s)); @@ -168,41 +169,38 @@ void LegacySSHStore::addToStore(const ValidPathInfo & info, Source & source, } conn->to.flush(); + if (readInt(conn->from) != 1) + throw Error("failed to add path '%s' to remote host '%s'", printStorePath(info.path), host); + } else { - conn->to - << ServeProto::Command::ImportPaths - << 1; - try { - copyNAR(source, conn->to); - } catch (...) { - conn->good = false; - throw; - } - conn->to - << exportMagic - << printStorePath(info.path); - ServeProto::write(*this, *conn, info.references); - conn->to - << (info.deriver ? printStorePath(*info.deriver) : "") - << 0 - << 0; - conn->to.flush(); + conn->importPaths(*this, [&](Sink & sink) { + try { + copyNAR(source, sink); + } catch (...) { + conn->good = false; + throw; + } + sink + << exportMagic + << printStorePath(info.path); + ServeProto::write(*this, *conn, info.references); + sink + << (info.deriver ? printStorePath(*info.deriver) : "") + << 0 + << 0; + }); } - - if (readInt(conn->from) != 1) - throw Error("failed to add path '%s' to remote host '%s'", printStorePath(info.path), host); } void LegacySSHStore::narFromPath(const StorePath & path, Sink & sink) { auto conn(connections->get()); - - conn->to << ServeProto::Command::DumpStorePath << printStorePath(path); - conn->to.flush(); - copyNAR(conn->from, sink); + conn->narFromPath(*this, path, [&](auto & source) { + copyNAR(source, sink); + }); } @@ -226,7 +224,7 @@ BuildResult LegacySSHStore::buildDerivation(const StorePath & drvPath, const Bas conn->putBuildDerivationRequest(*this, drvPath, drv, buildSettings()); - return ServeProto::Serialise::read(*this, *conn); + return conn->getBuildDerivationResponse(*this); } diff --git a/src/libstore/remote-store-connection.hh b/src/libstore/remote-store-connection.hh index 44328b06b..405120ee9 100644 --- a/src/libstore/remote-store-connection.hh +++ b/src/libstore/remote-store-connection.hh @@ -3,6 +3,7 @@ #include "remote-store.hh" #include "worker-protocol.hh" +#include "worker-protocol-connection.hh" #include "pool.hh" namespace nix { @@ -14,90 +15,13 @@ namespace nix { * Contains `Source` and `Sink` for actual communication, along with * other information learned when negotiating the connection. */ -struct RemoteStore::Connection +struct RemoteStore::Connection : WorkerProto::BasicClientConnection, + WorkerProto::ClientHandshakeInfo { - /** - * Send with this. - */ - FdSink to; - - /** - * Receive with this. - */ - FdSource from; - - /** - * Worker protocol version used for the connection. - * - * Despite its name, I think it is actually the maximum version both - * sides support. (If the maximum doesn't exist, we would fail to - * establish a connection and produce a value of this type.) - */ - WorkerProto::Version daemonVersion; - - /** - * Whether the remote side trusts us or not. - * - * 3 values: "yes", "no", or `std::nullopt` for "unknown". - * - * Note that the "remote side" might not be just the end daemon, but - * also an intermediary forwarder that can make its own trusting - * decisions. This would be the intersection of all their trust - * decisions, since it takes only one link in the chain to start - * denying operations. - */ - std::optional remoteTrustsUs; - - /** - * The version of the Nix daemon that is processing our requests. - * - * Do note, it may or may not communicating with another daemon, - * rather than being an "end" `LocalStore` or similar. - */ - std::optional daemonNixVersion; - /** * Time this connection was established. */ std::chrono::time_point startTime; - - /** - * Coercion to `WorkerProto::ReadConn`. This makes it easy to use the - * factored out worker protocol searlizers with a - * `RemoteStore::Connection`. - * - * The worker protocol connection types are unidirectional, unlike - * this type. - */ - operator WorkerProto::ReadConn () - { - return WorkerProto::ReadConn { - .from = from, - .version = daemonVersion, - }; - } - - /** - * Coercion to `WorkerProto::WriteConn`. This makes it easy to use the - * factored out worker protocol searlizers with a - * `RemoteStore::Connection`. - * - * The worker protocol connection types are unidirectional, unlike - * this type. - */ - operator WorkerProto::WriteConn () - { - return WorkerProto::WriteConn { - .to = to, - .version = daemonVersion, - }; - } - - virtual ~Connection(); - - virtual void closeWrite() = 0; - - std::exception_ptr processStderr(Sink * sink = 0, Source * source = 0, bool flush = true); }; /** diff --git a/src/libstore/remote-store.cc b/src/libstore/remote-store.cc index 09196481b..d6efc14f9 100644 --- a/src/libstore/remote-store.cc +++ b/src/libstore/remote-store.cc @@ -69,50 +69,26 @@ void RemoteStore::initConnection(Connection & conn) /* Send the magic greeting, check for the reply. */ try { conn.from.endOfFileError = "Nix daemon disconnected unexpectedly (maybe it crashed?)"; - conn.to << WORKER_MAGIC_1; - conn.to.flush(); + StringSink saved; + TeeSource tee(conn.from, saved); try { - TeeSource tee(conn.from, saved); - unsigned int magic = readInt(tee); - if (magic != WORKER_MAGIC_2) - throw Error("protocol mismatch"); + conn.daemonVersion = WorkerProto::BasicClientConnection::handshake( + conn.to, tee, PROTOCOL_VERSION); } catch (SerialisationError & e) { /* In case the other side is waiting for our input, close it. */ conn.closeWrite(); - auto msg = conn.from.drain(); - throw Error("protocol mismatch, got '%s'", chomp(saved.s + msg)); + { + NullSink nullSink; + tee.drainInto(nullSink); + } + throw Error("protocol mismatch, got '%s'", chomp(saved.s)); } - conn.from >> conn.daemonVersion; - if (GET_PROTOCOL_MAJOR(conn.daemonVersion) != GET_PROTOCOL_MAJOR(PROTOCOL_VERSION)) - throw Error("Nix daemon protocol version not supported"); - if (GET_PROTOCOL_MINOR(conn.daemonVersion) < 10) - throw Error("the Nix daemon version is too old"); - conn.to << PROTOCOL_VERSION; + static_cast(conn) = conn.postHandshake(*this); - if (GET_PROTOCOL_MINOR(conn.daemonVersion) >= 14) { - // Obsolete CPU affinity. - conn.to << 0; - } - - if (GET_PROTOCOL_MINOR(conn.daemonVersion) >= 11) - conn.to << false; // obsolete reserveSpace - - if (GET_PROTOCOL_MINOR(conn.daemonVersion) >= 33) { - conn.to.flush(); - conn.daemonNixVersion = readString(conn.from); - } - - if (GET_PROTOCOL_MINOR(conn.daemonVersion) >= 35) { - conn.remoteTrustsUs = WorkerProto::Serialise>::read(*this, conn); - } else { - // We don't know the answer; protocol to old. - conn.remoteTrustsUs = std::nullopt; - } - - auto ex = conn.processStderr(); + auto ex = conn.processStderrReturn(); if (ex) std::rethrow_exception(ex); } catch (Error & e) { @@ -158,7 +134,7 @@ void RemoteStore::setOptions(Connection & conn) conn.to << i.first << i.second.value; } - auto ex = conn.processStderr(); + auto ex = conn.processStderrReturn(); if (ex) std::rethrow_exception(ex); } @@ -173,28 +149,7 @@ RemoteStore::ConnectionHandle::~ConnectionHandle() void RemoteStore::ConnectionHandle::processStderr(Sink * sink, Source * source, bool flush) { - auto ex = handle->processStderr(sink, source, flush); - if (ex) { - daemonException = true; - try { - std::rethrow_exception(ex); - } catch (const Error & e) { - // Nix versions before #4628 did not have an adequate behavior for reporting that the derivation format was upgraded. - // To avoid having to add compatibility logic in many places, we expect to catch almost all occurrences of the - // old incomprehensible error here, so that we can explain to users what's going on when their daemon is - // older than #4628 (2023). - if (experimentalFeatureSettings.isEnabled(Xp::DynamicDerivations) && - GET_PROTOCOL_MINOR(handle->daemonVersion) <= 35) - { - auto m = e.msg(); - if (m.find("parsing derivation") != std::string::npos && - m.find("expected string") != std::string::npos && - m.find("Derive([") != std::string::npos) - throw Error("%s, this might be because the daemon is too old to understand dependencies on dynamic derivations. Check to see if the raw derivation is in the form '%s'", std::move(m), "DrvWithVersion(..)"); - } - throw; - } - } + handle->processStderr(&daemonException, sink, source, flush); } @@ -226,13 +181,7 @@ StorePathSet RemoteStore::queryValidPaths(const StorePathSet & paths, Substitute if (isValidPath(i)) res.insert(i); return res; } else { - conn->to << WorkerProto::Op::QueryValidPaths; - WorkerProto::write(*this, *conn, paths); - if (GET_PROTOCOL_MINOR(conn->daemonVersion) >= 27) { - conn->to << maybeSubstitute; - } - conn.processStderr(); - return WorkerProto::Serialise::read(*this, *conn); + return conn->queryValidPaths(*this, &conn.daemonException, paths, maybeSubstitute); } } @@ -322,22 +271,10 @@ void RemoteStore::queryPathInfoUncached(const StorePath & path, std::shared_ptr info; { auto conn(getConnection()); - conn->to << WorkerProto::Op::QueryPathInfo << printStorePath(path); - try { - conn.processStderr(); - } catch (Error & e) { - // Ugly backwards compatibility hack. - if (e.msg().find("is not valid") != std::string::npos) - throw InvalidPath(std::move(e.info())); - throw; - } - if (GET_PROTOCOL_MINOR(conn->daemonVersion) >= 17) { - bool valid; conn->from >> valid; - if (!valid) throw InvalidPath("path '%s' is not valid", printStorePath(path)); - } info = std::make_shared( StorePath{path}, - WorkerProto::Serialise::read(*this, *conn)); + conn->queryPathInfo(*this, &conn.daemonException, path)); + } callback(std::move(info)); } catch (...) { callback.rethrow(); } @@ -542,8 +479,6 @@ void RemoteStore::addToStore(const ValidPathInfo & info, Source & source, auto conn(getConnection()); if (GET_PROTOCOL_MINOR(conn->daemonVersion) < 18) { - conn->to << WorkerProto::Op::ImportPaths; - auto source2 = sinkToSource([&](Sink & sink) { sink << 1 // == path follows ; @@ -558,11 +493,7 @@ void RemoteStore::addToStore(const ValidPathInfo & info, Source & source, << 0 // == no path follows ; }); - - conn.processStderr(0, source2.get()); - - auto importedPaths = WorkerProto::Serialise::read(*this, *conn); - assert(importedPaths.size() <= 1); + conn->importPaths(*this, &conn.daemonException, *source2); } else { @@ -807,9 +738,7 @@ BuildResult RemoteStore::buildDerivation(const StorePath & drvPath, const BasicD BuildMode buildMode) { auto conn(getConnection()); - conn->to << WorkerProto::Op::BuildDerivation << printStorePath(drvPath); - writeDerivation(conn->to, *this, drv); - conn->to << buildMode; + conn->putBuildDerivationRequest(*this, &conn.daemonException, drvPath, drv, buildMode); conn.processStderr(); return WorkerProto::Serialise::read(*this, *conn); } @@ -827,9 +756,7 @@ void RemoteStore::ensurePath(const StorePath & path) void RemoteStore::addTempRoot(const StorePath & path) { auto conn(getConnection()); - conn->to << WorkerProto::Op::AddTempRoot << printStorePath(path); - conn.processStderr(); - readInt(conn->from); + conn->addTempRoot(*this, &conn.daemonException, path); } @@ -969,22 +896,12 @@ void RemoteStore::flushBadConnections() connections->flushBad(); } - -RemoteStore::Connection::~Connection() -{ - try { - to.flush(); - } catch (...) { - ignoreException(); - } -} - void RemoteStore::narFromPath(const StorePath & path, Sink & sink) { - auto conn(connections->get()); - conn->to << WorkerProto::Op::NarFromPath << printStorePath(path); - conn->processStderr(); - copyNAR(conn->from, sink); + auto conn(getConnection()); + conn->narFromPath(*this, &conn.daemonException, path, [&](Source & source) { + copyNAR(conn->from, sink); + }); } ref RemoteStore::getFSAccessor(bool requireValidPath) @@ -992,91 +909,6 @@ ref RemoteStore::getFSAccessor(bool requireValidPath) return make_ref(ref(shared_from_this())); } -static Logger::Fields readFields(Source & from) -{ - Logger::Fields fields; - size_t size = readInt(from); - for (size_t n = 0; n < size; n++) { - auto type = (decltype(Logger::Field::type)) readInt(from); - if (type == Logger::Field::tInt) - fields.push_back(readNum(from)); - else if (type == Logger::Field::tString) - fields.push_back(readString(from)); - else - throw Error("got unsupported field type %x from Nix daemon", (int) type); - } - return fields; -} - - -std::exception_ptr RemoteStore::Connection::processStderr(Sink * sink, Source * source, bool flush) -{ - if (flush) - to.flush(); - - while (true) { - - auto msg = readNum(from); - - if (msg == STDERR_WRITE) { - auto s = readString(from); - if (!sink) throw Error("no sink"); - (*sink)(s); - } - - else if (msg == STDERR_READ) { - if (!source) throw Error("no source"); - size_t len = readNum(from); - auto buf = std::make_unique(len); - writeString({(const char *) buf.get(), source->read(buf.get(), len)}, to); - to.flush(); - } - - else if (msg == STDERR_ERROR) { - if (GET_PROTOCOL_MINOR(daemonVersion) >= 26) { - return std::make_exception_ptr(readError(from)); - } else { - auto error = readString(from); - unsigned int status = readInt(from); - return std::make_exception_ptr(Error(status, error)); - } - } - - else if (msg == STDERR_NEXT) - printError(chomp(readString(from))); - - else if (msg == STDERR_START_ACTIVITY) { - auto act = readNum(from); - auto lvl = (Verbosity) readInt(from); - auto type = (ActivityType) readInt(from); - auto s = readString(from); - auto fields = readFields(from); - auto parent = readNum(from); - logger->startActivity(act, lvl, type, s, fields, parent); - } - - else if (msg == STDERR_STOP_ACTIVITY) { - auto act = readNum(from); - logger->stopActivity(act); - } - - else if (msg == STDERR_RESULT) { - auto act = readNum(from); - auto type = (ResultType) readInt(from); - auto fields = readFields(from); - logger->result(act, type, fields); - } - - else if (msg == STDERR_LAST) - break; - - else - throw Error("got unknown message type %x from Nix daemon", msg); - } - - return nullptr; -} - void RemoteStore::ConnectionHandle::withFramedSink(std::function fun) { (*this)->to.flush(); diff --git a/src/libstore/serve-protocol-impl.cc b/src/libstore/serve-protocol-connection.cc similarity index 51% rename from src/libstore/serve-protocol-impl.cc rename to src/libstore/serve-protocol-connection.cc index 5e1d3f1e8..07379999b 100644 --- a/src/libstore/serve-protocol-impl.cc +++ b/src/libstore/serve-protocol-connection.cc @@ -1,3 +1,4 @@ +#include "serve-protocol-connection.hh" #include "serve-protocol-impl.hh" #include "build-result.hh" #include "derivations.hh" @@ -5,10 +6,7 @@ namespace nix { ServeProto::Version ServeProto::BasicClientConnection::handshake( - BufferedSink & to, - Source & from, - ServeProto::Version localVersion, - std::string_view host) + BufferedSink & to, Source & from, ServeProto::Version localVersion, std::string_view host) { to << SERVE_MAGIC_1 << localVersion; to.flush(); @@ -22,39 +20,30 @@ ServeProto::Version ServeProto::BasicClientConnection::handshake( return std::min(remoteVersion, localVersion); } -ServeProto::Version ServeProto::BasicServerConnection::handshake( - BufferedSink & to, - Source & from, - ServeProto::Version localVersion) +ServeProto::Version +ServeProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, ServeProto::Version localVersion) { unsigned int magic = readInt(from); - if (magic != SERVE_MAGIC_1) throw Error("protocol mismatch"); + if (magic != SERVE_MAGIC_1) + throw Error("protocol mismatch"); to << SERVE_MAGIC_2 << localVersion; to.flush(); auto remoteVersion = readInt(from); return std::min(remoteVersion, localVersion); } - StorePathSet ServeProto::BasicClientConnection::queryValidPaths( - const Store & store, - bool lock, const StorePathSet & paths, - SubstituteFlag maybeSubstitute) + const StoreDirConfig & store, bool lock, const StorePathSet & paths, SubstituteFlag maybeSubstitute) { - to - << ServeProto::Command::QueryValidPaths - << lock - << maybeSubstitute; + to << ServeProto::Command::QueryValidPaths << lock << maybeSubstitute; write(store, *this, paths); to.flush(); return Serialise::read(store, *this); } - -std::map ServeProto::BasicClientConnection::queryPathInfos( - const Store & store, - const StorePathSet & paths) +std::map +ServeProto::BasicClientConnection::queryPathInfos(const StoreDirConfig & store, const StorePathSet & paths) { std::map infos; @@ -64,7 +53,8 @@ std::map ServeProto::BasicClientConnection::que while (true) { auto storePathS = readString(from); - if (storePathS == "") break; + if (storePathS == "") + break; auto storePath = store.parseStorePath(storePathS); assert(paths.count(storePath) == 1); @@ -75,15 +65,13 @@ std::map ServeProto::BasicClientConnection::que return infos; } - void ServeProto::BasicClientConnection::putBuildDerivationRequest( - const Store & store, - const StorePath & drvPath, const BasicDerivation & drv, + const StoreDirConfig & store, + const StorePath & drvPath, + const BasicDerivation & drv, const ServeProto::BuildOptions & options) { - to - << ServeProto::Command::BuildDerivation - << store.printStorePath(drvPath); + to << ServeProto::Command::BuildDerivation << store.printStorePath(drvPath); writeDerivation(to, store, drv); ServeProto::write(store, *this, options); @@ -91,4 +79,28 @@ void ServeProto::BasicClientConnection::putBuildDerivationRequest( to.flush(); } +BuildResult ServeProto::BasicClientConnection::getBuildDerivationResponse(const StoreDirConfig & store) +{ + return ServeProto::Serialise::read(store, *this); +} + +void ServeProto::BasicClientConnection::narFromPath( + const StoreDirConfig & store, const StorePath & path, std::function fun) +{ + to << ServeProto::Command::DumpStorePath << store.printStorePath(path); + to.flush(); + + fun(from); +} + +void ServeProto::BasicClientConnection::importPaths(const StoreDirConfig & store, std::function fun) +{ + to << ServeProto::Command::ImportPaths; + fun(to); + to.flush(); + + if (readInt(from) != 1) + throw Error("remote machine failed to import closure"); +} + } diff --git a/src/libstore/serve-protocol-connection.hh b/src/libstore/serve-protocol-connection.hh new file mode 100644 index 000000000..73bf71443 --- /dev/null +++ b/src/libstore/serve-protocol-connection.hh @@ -0,0 +1,108 @@ +#pragma once +///@file + +#include "serve-protocol.hh" +#include "store-api.hh" + +namespace nix { + +struct ServeProto::BasicClientConnection +{ + FdSink to; + FdSource from; + ServeProto::Version remoteVersion; + + /** + * Establishes connection, negotiating version. + * + * @return the version provided by the other side of the + * connection. + * + * @param to Taken by reference to allow for various error handling + * mechanisms. + * + * @param from Taken by reference to allow for various error + * handling mechanisms. + * + * @param localVersion Our version which is sent over + * + * @param host Just used to add context to thrown exceptions. + */ + static ServeProto::Version + handshake(BufferedSink & to, Source & from, ServeProto::Version localVersion, std::string_view host); + + /** + * Coercion to `ServeProto::ReadConn`. This makes it easy to use the + * factored out serve protocol serializers with a + * `LegacySSHStore::Connection`. + * + * The serve protocol connection types are unidirectional, unlike + * this type. + */ + operator ServeProto::ReadConn() + { + return ServeProto::ReadConn{ + .from = from, + .version = remoteVersion, + }; + } + + /** + * Coercion to `ServeProto::WriteConn`. This makes it easy to use the + * factored out serve protocol serializers with a + * `LegacySSHStore::Connection`. + * + * The serve protocol connection types are unidirectional, unlike + * this type. + */ + operator ServeProto::WriteConn() + { + return ServeProto::WriteConn{ + .to = to, + .version = remoteVersion, + }; + } + + StorePathSet queryValidPaths( + const StoreDirConfig & remoteStore, bool lock, const StorePathSet & paths, SubstituteFlag maybeSubstitute); + + std::map queryPathInfos(const StoreDirConfig & store, const StorePathSet & paths); + ; + + void putBuildDerivationRequest( + const StoreDirConfig & store, + const StorePath & drvPath, + const BasicDerivation & drv, + const ServeProto::BuildOptions & options); + + /** + * Get the response, must be paired with + * `putBuildDerivationRequest`. + */ + BuildResult getBuildDerivationResponse(const StoreDirConfig & store); + + void narFromPath(const StoreDirConfig & store, const StorePath & path, std::function fun); + + void importPaths(const StoreDirConfig & store, std::function fun); +}; + +struct ServeProto::BasicServerConnection +{ + /** + * Establishes connection, negotiating version. + * + * @return the version provided by the other side of the + * connection. + * + * @param to Taken by reference to allow for various error handling + * mechanisms. + * + * @param from Taken by reference to allow for various error + * handling mechanisms. + * + * @param localVersion Our version which is sent over + */ + static ServeProto::Version handshake(BufferedSink & to, Source & from, ServeProto::Version localVersion); +}; + +} diff --git a/src/libstore/serve-protocol-impl.hh b/src/libstore/serve-protocol-impl.hh index 1d9c79ef0..67bc5dc6e 100644 --- a/src/libstore/serve-protocol-impl.hh +++ b/src/libstore/serve-protocol-impl.hh @@ -57,105 +57,4 @@ struct ServeProto::Serialise /* protocol-specific templates */ -struct ServeProto::BasicClientConnection -{ - FdSink to; - FdSource from; - ServeProto::Version remoteVersion; - - /** - * Establishes connection, negotiating version. - * - * @return the version provided by the other side of the - * connection. - * - * @param to Taken by reference to allow for various error handling - * mechanisms. - * - * @param from Taken by reference to allow for various error - * handling mechanisms. - * - * @param localVersion Our version which is sent over - * - * @param host Just used to add context to thrown exceptions. - */ - static ServeProto::Version handshake( - BufferedSink & to, - Source & from, - ServeProto::Version localVersion, - std::string_view host); - - /** - * Coercion to `ServeProto::ReadConn`. This makes it easy to use the - * factored out serve protocol serializers with a - * `LegacySSHStore::Connection`. - * - * The serve protocol connection types are unidirectional, unlike - * this type. - */ - operator ServeProto::ReadConn () - { - return ServeProto::ReadConn { - .from = from, - .version = remoteVersion, - }; - } - - /** - * Coercion to `ServeProto::WriteConn`. This makes it easy to use the - * factored out serve protocol serializers with a - * `LegacySSHStore::Connection`. - * - * The serve protocol connection types are unidirectional, unlike - * this type. - */ - operator ServeProto::WriteConn () - { - return ServeProto::WriteConn { - .to = to, - .version = remoteVersion, - }; - } - - StorePathSet queryValidPaths( - const Store & remoteStore, - bool lock, const StorePathSet & paths, - SubstituteFlag maybeSubstitute); - - std::map queryPathInfos( - const Store & store, - const StorePathSet & paths); - - /** - * Just the request half, because Hydra may do other things between - * issuing the request and reading the `BuildResult` response. - */ - void putBuildDerivationRequest( - const Store & store, - const StorePath & drvPath, const BasicDerivation & drv, - const ServeProto::BuildOptions & options); -}; - -struct ServeProto::BasicServerConnection -{ - /** - * Establishes connection, negotiating version. - * - * @return the version provided by the other side of the - * connection. - * - * @param to Taken by reference to allow for various error handling - * mechanisms. - * - * @param from Taken by reference to allow for various error - * handling mechanisms. - * - * @param localVersion Our version which is sent over - */ - static ServeProto::Version handshake( - BufferedSink & to, - Source & from, - ServeProto::Version localVersion); -}; - } diff --git a/src/libstore/worker-protocol-connection.cc b/src/libstore/worker-protocol-connection.cc new file mode 100644 index 000000000..072bae8da --- /dev/null +++ b/src/libstore/worker-protocol-connection.cc @@ -0,0 +1,280 @@ +#include "worker-protocol-connection.hh" +#include "worker-protocol-impl.hh" +#include "build-result.hh" +#include "derivations.hh" + +namespace nix { + +WorkerProto::BasicClientConnection::~BasicClientConnection() +{ + try { + to.flush(); + } catch (...) { + ignoreException(); + } +} + +static Logger::Fields readFields(Source & from) +{ + Logger::Fields fields; + size_t size = readInt(from); + for (size_t n = 0; n < size; n++) { + auto type = (decltype(Logger::Field::type)) readInt(from); + if (type == Logger::Field::tInt) + fields.push_back(readNum(from)); + else if (type == Logger::Field::tString) + fields.push_back(readString(from)); + else + throw Error("got unsupported field type %x from Nix daemon", (int) type); + } + return fields; +} + +std::exception_ptr WorkerProto::BasicClientConnection::processStderrReturn(Sink * sink, Source * source, bool flush) +{ + if (flush) + to.flush(); + + std::exception_ptr ex; + + while (true) { + + auto msg = readNum(from); + + if (msg == STDERR_WRITE) { + auto s = readString(from); + if (!sink) + throw Error("no sink"); + (*sink)(s); + } + + else if (msg == STDERR_READ) { + if (!source) + throw Error("no source"); + size_t len = readNum(from); + auto buf = std::make_unique(len); + writeString({(const char *) buf.get(), source->read(buf.get(), len)}, to); + to.flush(); + } + + else if (msg == STDERR_ERROR) { + if (GET_PROTOCOL_MINOR(daemonVersion) >= 26) { + ex = std::make_exception_ptr(readError(from)); + } else { + auto error = readString(from); + unsigned int status = readInt(from); + ex = std::make_exception_ptr(Error(status, error)); + } + break; + } + + else if (msg == STDERR_NEXT) + printError(chomp(readString(from))); + + else if (msg == STDERR_START_ACTIVITY) { + auto act = readNum(from); + auto lvl = (Verbosity) readInt(from); + auto type = (ActivityType) readInt(from); + auto s = readString(from); + auto fields = readFields(from); + auto parent = readNum(from); + logger->startActivity(act, lvl, type, s, fields, parent); + } + + else if (msg == STDERR_STOP_ACTIVITY) { + auto act = readNum(from); + logger->stopActivity(act); + } + + else if (msg == STDERR_RESULT) { + auto act = readNum(from); + auto type = (ResultType) readInt(from); + auto fields = readFields(from); + logger->result(act, type, fields); + } + + else if (msg == STDERR_LAST) + break; + + else + throw Error("got unknown message type %x from Nix daemon", msg); + } + + if (!ex) { + return ex; + } else { + try { + std::rethrow_exception(ex); + } catch (const Error & e) { + // Nix versions before #4628 did not have an adequate + // behavior for reporting that the derivation format was + // upgraded. To avoid having to add compatibility logic in + // many places, we expect to catch almost all occurrences of + // the old incomprehensible error here, so that we can + // explain to users what's going on when their daemon is + // older than #4628 (2023). + if (experimentalFeatureSettings.isEnabled(Xp::DynamicDerivations) + && GET_PROTOCOL_MINOR(daemonVersion) <= 35) { + auto m = e.msg(); + if (m.find("parsing derivation") != std::string::npos && m.find("expected string") != std::string::npos + && m.find("Derive([") != std::string::npos) + return std::make_exception_ptr(Error( + "%s, this might be because the daemon is too old to understand dependencies on dynamic derivations. Check to see if the raw derivation is in the form '%s'", + std::move(m), + "Drv WithVersion(..)")); + } + return std::current_exception(); + } + } +} + +void WorkerProto::BasicClientConnection::processStderr(bool * daemonException, Sink * sink, Source * source, bool flush) +{ + auto ex = processStderrReturn(sink, source, flush); + if (ex) { + *daemonException = true; + std::rethrow_exception(ex); + } +} + +WorkerProto::Version +WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion) +{ + to << WORKER_MAGIC_1 << localVersion; + to.flush(); + + unsigned int magic = readInt(from); + if (magic != WORKER_MAGIC_2) + throw Error("nix-daemon protocol mismatch from"); + auto daemonVersion = readInt(from); + + if (GET_PROTOCOL_MAJOR(daemonVersion) != GET_PROTOCOL_MAJOR(PROTOCOL_VERSION)) + throw Error("Nix daemon protocol version not supported"); + if (GET_PROTOCOL_MINOR(daemonVersion) < 10) + throw Error("the Nix daemon version is too old"); + to << localVersion; + + return std::min(daemonVersion, localVersion); +} + +WorkerProto::Version +WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion) +{ + unsigned int magic = readInt(from); + if (magic != WORKER_MAGIC_1) + throw Error("protocol mismatch"); + to << WORKER_MAGIC_2 << localVersion; + to.flush(); + auto clientVersion = readInt(from); + return std::min(clientVersion, localVersion); +} + +WorkerProto::ClientHandshakeInfo WorkerProto::BasicClientConnection::postHandshake(const StoreDirConfig & store) +{ + WorkerProto::ClientHandshakeInfo res; + + if (GET_PROTOCOL_MINOR(daemonVersion) >= 14) { + // Obsolete CPU affinity. + to << 0; + } + + if (GET_PROTOCOL_MINOR(daemonVersion) >= 11) + to << false; // obsolete reserveSpace + + if (GET_PROTOCOL_MINOR(daemonVersion) >= 33) + to.flush(); + + return WorkerProto::Serialise::read(store, *this); +} + +void WorkerProto::BasicServerConnection::postHandshake(const StoreDirConfig & store, const ClientHandshakeInfo & info) +{ + if (GET_PROTOCOL_MINOR(clientVersion) >= 14 && readInt(from)) { + // Obsolete CPU affinity. + readInt(from); + } + + if (GET_PROTOCOL_MINOR(clientVersion) >= 11) + readInt(from); // obsolete reserveSpace + + WorkerProto::write(store, *this, info); +} + +UnkeyedValidPathInfo WorkerProto::BasicClientConnection::queryPathInfo( + const StoreDirConfig & store, bool * daemonException, const StorePath & path) +{ + to << WorkerProto::Op::QueryPathInfo << store.printStorePath(path); + try { + processStderr(daemonException); + } catch (Error & e) { + // Ugly backwards compatibility hack. + if (e.msg().find("is not valid") != std::string::npos) + throw InvalidPath(std::move(e.info())); + throw; + } + if (GET_PROTOCOL_MINOR(daemonVersion) >= 17) { + bool valid; + from >> valid; + if (!valid) + throw InvalidPath("path '%s' is not valid", store.printStorePath(path)); + } + return WorkerProto::Serialise::read(store, *this); +} + +StorePathSet WorkerProto::BasicClientConnection::queryValidPaths( + const StoreDirConfig & store, bool * daemonException, const StorePathSet & paths, SubstituteFlag maybeSubstitute) +{ + assert(GET_PROTOCOL_MINOR(daemonVersion) >= 12); + to << WorkerProto::Op::QueryValidPaths; + WorkerProto::write(store, *this, paths); + if (GET_PROTOCOL_MINOR(daemonVersion) >= 27) { + to << maybeSubstitute; + } + processStderr(daemonException); + return WorkerProto::Serialise::read(store, *this); +} + +void WorkerProto::BasicClientConnection::addTempRoot( + const StoreDirConfig & store, bool * daemonException, const StorePath & path) +{ + to << WorkerProto::Op::AddTempRoot << store.printStorePath(path); + processStderr(daemonException); + readInt(from); +} + +void WorkerProto::BasicClientConnection::putBuildDerivationRequest( + const StoreDirConfig & store, + bool * daemonException, + const StorePath & drvPath, + const BasicDerivation & drv, + BuildMode buildMode) +{ + to << WorkerProto::Op::BuildDerivation << store.printStorePath(drvPath); + writeDerivation(to, store, drv); + to << buildMode; +} + +BuildResult +WorkerProto::BasicClientConnection::getBuildDerivationResponse(const StoreDirConfig & store, bool * daemonException) +{ + return WorkerProto::Serialise::read(store, *this); +} + +void WorkerProto::BasicClientConnection::narFromPath( + const StoreDirConfig & store, bool * daemonException, const StorePath & path, std::function fun) +{ + to << WorkerProto::Op::NarFromPath << store.printStorePath(path); + processStderr(daemonException); + + fun(from); +} + +void WorkerProto::BasicClientConnection::importPaths( + const StoreDirConfig & store, bool * daemonException, Source & source) +{ + to << WorkerProto::Op::ImportPaths; + processStderr(daemonException, 0, &source); + auto importedPaths = WorkerProto::Serialise::read(store, *this); + assert(importedPaths.size() <= importedPaths.size()); +} +} diff --git a/src/libstore/worker-protocol-connection.hh b/src/libstore/worker-protocol-connection.hh new file mode 100644 index 000000000..9dd723fd0 --- /dev/null +++ b/src/libstore/worker-protocol-connection.hh @@ -0,0 +1,187 @@ +#pragma once +///@file + +#include "worker-protocol.hh" +#include "store-api.hh" + +namespace nix { + +struct WorkerProto::BasicClientConnection +{ + /** + * Send with this. + */ + FdSink to; + + /** + * Receive with this. + */ + FdSource from; + + /** + * Worker protocol version used for the connection. + * + * Despite its name, it is actually the maximum version both + * sides support. (If the maximum doesn't exist, we would fail to + * establish a connection and produce a value of this type.) + */ + WorkerProto::Version daemonVersion; + + /** + * Flush to direction + */ + virtual ~BasicClientConnection(); + + virtual void closeWrite() = 0; + + std::exception_ptr processStderrReturn(Sink * sink = 0, Source * source = 0, bool flush = true); + + void processStderr(bool * daemonException, Sink * sink = 0, Source * source = 0, bool flush = true); + + /** + * Establishes connection, negotiating version. + * + * @return the version provided by the other side of the + * connection. + * + * @param to Taken by reference to allow for various error handling + * mechanisms. + * + * @param from Taken by reference to allow for various error + * handling mechanisms. + * + * @param localVersion Our version which is sent over + */ + static Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion); + + /** + * After calling handshake, must call this to exchange some basic + * information abou the connection. + */ + ClientHandshakeInfo postHandshake(const StoreDirConfig & store); + + /** + * Coercion to `WorkerProto::ReadConn`. This makes it easy to use the + * factored out serve protocol serializers with a + * `LegacySSHStore::Connection`. + * + * The serve protocol connection types are unidirectional, unlike + * this type. + */ + operator WorkerProto::ReadConn() + { + return WorkerProto::ReadConn{ + .from = from, + .version = daemonVersion, + }; + } + + /** + * Coercion to `WorkerProto::WriteConn`. This makes it easy to use the + * factored out serve protocol serializers with a + * `LegacySSHStore::Connection`. + * + * The serve protocol connection types are unidirectional, unlike + * this type. + */ + operator WorkerProto::WriteConn() + { + return WorkerProto::WriteConn{ + .to = to, + .version = daemonVersion, + }; + } + + void addTempRoot(const StoreDirConfig & remoteStore, bool * daemonException, const StorePath & path); + + StorePathSet queryValidPaths( + const StoreDirConfig & remoteStore, + bool * daemonException, + const StorePathSet & paths, + SubstituteFlag maybeSubstitute); + + UnkeyedValidPathInfo queryPathInfo(const StoreDirConfig & store, bool * daemonException, const StorePath & path); + + void putBuildDerivationRequest( + const StoreDirConfig & store, + bool * daemonException, + const StorePath & drvPath, + const BasicDerivation & drv, + BuildMode buildMode); + + /** + * Get the response, must be paired with + * `putBuildDerivationRequest`. + */ + BuildResult getBuildDerivationResponse(const StoreDirConfig & store, bool * daemonException); + + void narFromPath( + const StoreDirConfig & store, + bool * daemonException, + const StorePath & path, + std::function fun); + + void importPaths(const StoreDirConfig & store, bool * daemonException, Source & source); +}; + +struct WorkerProto::BasicServerConnection +{ + /** + * Send with this. + */ + FdSink & to; + + /** + * Receive with this. + */ + FdSource & from; + + /** + * Worker protocol version used for the connection. + * + * Despite its name, it is actually the maximum version both + * sides support. (If the maximum doesn't exist, we would fail to + * establish a connection and produce a value of this type.) + */ + WorkerProto::Version clientVersion; + + operator WorkerProto::ReadConn() + { + return WorkerProto::ReadConn{ + .from = from, + .version = clientVersion, + }; + } + + operator WorkerProto::WriteConn() + { + return WorkerProto::WriteConn{ + .to = to, + .version = clientVersion, + }; + } + + /** + * Establishes connection, negotiating version. + * + * @return the version provided by the other side of the + * connection. + * + * @param to Taken by reference to allow for various error handling + * mechanisms. + * + * @param from Taken by reference to allow for various error + * handling mechanisms. + * + * @param localVersion Our version which is sent over + */ + static WorkerProto::Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion); + + /** + * After calling handshake, must call this to exchange some basic + * information abou the connection. + */ + void postHandshake(const StoreDirConfig & store, const ClientHandshakeInfo & info); +}; + +} diff --git a/src/libstore/worker-protocol.cc b/src/libstore/worker-protocol.cc index a50259d24..c61540bd6 100644 --- a/src/libstore/worker-protocol.cc +++ b/src/libstore/worker-protocol.cc @@ -222,4 +222,35 @@ void WorkerProto::Serialise::write(const StoreDirConfig & } } + +WorkerProto::ClientHandshakeInfo WorkerProto::Serialise::read(const StoreDirConfig & store, ReadConn conn) +{ + WorkerProto::ClientHandshakeInfo res; + + if (GET_PROTOCOL_MINOR(conn.version) >= 33) { + res.daemonNixVersion = readString(conn.from); + } + + if (GET_PROTOCOL_MINOR(conn.version) >= 35) { + res.remoteTrustsUs = WorkerProto::Serialise>::read(store, conn); + } else { + // We don't know the answer; protocol to old. + res.remoteTrustsUs = std::nullopt; + } + + return res; +} + +void WorkerProto::Serialise::write(const StoreDirConfig & store, WriteConn conn, const WorkerProto::ClientHandshakeInfo & info) +{ + if (GET_PROTOCOL_MINOR(conn.version) >= 33) { + assert(info.daemonNixVersion); + conn.to << *info.daemonNixVersion; + } + + if (GET_PROTOCOL_MINOR(conn.version) >= 35) { + WorkerProto::write(store, conn, info.remoteTrustsUs); + } +} + } diff --git a/src/libstore/worker-protocol.hh b/src/libstore/worker-protocol.hh index 91d277b77..db61574aa 100644 --- a/src/libstore/worker-protocol.hh +++ b/src/libstore/worker-protocol.hh @@ -76,6 +76,19 @@ struct WorkerProto Version version; }; + /** + * Stripped down serialization logic suitable for sharing with Hydra. + * + * @todo remove once Hydra uses Store abstraction consistently. + */ + struct BasicClientConnection; + struct BasicServerConnection; + + /** + * Extra information provided as part of protocol negotation. + */ + struct ClientHandshakeInfo; + /** * Data type for canonical pairs of serialisers for the worker protocol. * @@ -166,6 +179,33 @@ enum struct WorkerProto::Op : uint64_t AddPermRoot = 47, }; +struct WorkerProto::ClientHandshakeInfo +{ + /** + * The version of the Nix daemon that is processing our requests +. + * + * Do note, it may or may not communicating with another daemon, + * rather than being an "end" `LocalStore` or similar. + */ + std::optional daemonNixVersion; + + /** + * Whether the remote side trusts us or not. + * + * 3 values: "yes", "no", or `std::nullopt` for "unknown". + * + * Note that the "remote side" might not be just the end daemon, but + * also an intermediary forwarder that can make its own trusting + * decisions. This would be the intersection of all their trust + * decisions, since it takes only one link in the chain to start + * denying operations. + */ + std::optional remoteTrustsUs; + + bool operator == (const ClientHandshakeInfo &) const = default; +}; + /** * Convenience for sending operation codes. * @@ -218,6 +258,8 @@ template<> DECLARE_WORKER_SERIALISER(std::optional); template<> DECLARE_WORKER_SERIALISER(std::optional); +template<> +DECLARE_WORKER_SERIALISER(WorkerProto::ClientHandshakeInfo); template DECLARE_WORKER_SERIALISER(std::vector); diff --git a/src/nix-store/nix-store.cc b/src/nix-store/nix-store.cc index b23d99ad6..6d028e0a7 100644 --- a/src/nix-store/nix-store.cc +++ b/src/nix-store/nix-store.cc @@ -7,6 +7,7 @@ #include "local-fs-store.hh" #include "log-store.hh" #include "serve-protocol.hh" +#include "serve-protocol-connection.hh" #include "serve-protocol-impl.hh" #include "shared.hh" #include "graphml.hh" diff --git a/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_30.bin b/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_30.bin new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_33.bin b/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_33.bin new file mode 100644 index 000000000..96c6efafc Binary files /dev/null and b/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_33.bin differ diff --git a/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_35.bin b/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_35.bin new file mode 100644 index 000000000..e877159aa Binary files /dev/null and b/tests/unit/libstore/data/worker-protocol/client-handshake-info_1_35.bin differ diff --git a/tests/unit/libstore/data/worker-protocol/handshake-to-client.bin b/tests/unit/libstore/data/worker-protocol/handshake-to-client.bin new file mode 100644 index 000000000..bee94fbe5 Binary files /dev/null and b/tests/unit/libstore/data/worker-protocol/handshake-to-client.bin differ diff --git a/tests/unit/libstore/serve-protocol.cc b/tests/unit/libstore/serve-protocol.cc index 61dc15528..ebf0c52b0 100644 --- a/tests/unit/libstore/serve-protocol.cc +++ b/tests/unit/libstore/serve-protocol.cc @@ -6,6 +6,7 @@ #include "serve-protocol.hh" #include "serve-protocol-impl.hh" +#include "serve-protocol-connection.hh" #include "build-result.hh" #include "file-descriptor.hh" #include "tests/protocol.hh" diff --git a/tests/unit/libstore/worker-protocol.cc b/tests/unit/libstore/worker-protocol.cc index 2b2e559a9..853d56b6a 100644 --- a/tests/unit/libstore/worker-protocol.cc +++ b/tests/unit/libstore/worker-protocol.cc @@ -4,6 +4,7 @@ #include #include "worker-protocol.hh" +#include "worker-protocol-connection.hh" #include "worker-protocol-impl.hh" #include "derived-path.hh" #include "build-result.hh" @@ -18,9 +19,9 @@ struct WorkerProtoTest : VersionedProtoTest { /** * For serializers that don't care about the minimum version, we - * used the oldest one: 1.0. + * used the oldest one: 1.10. */ - WorkerProto::Version defaultVersion = 1 << 8 | 0; + WorkerProto::Version defaultVersion = 1 << 8 | 10; }; @@ -591,4 +592,152 @@ VERSIONED_CHARACTERIZATION_TEST( }, })) +VERSIONED_CHARACTERIZATION_TEST( + WorkerProtoTest, + clientHandshakeInfo_1_30, + "client-handshake-info_1_30", + 1 << 8 | 30, + (std::tuple { + {}, + })) + +VERSIONED_CHARACTERIZATION_TEST( + WorkerProtoTest, + clientHandshakeInfo_1_33, + "client-handshake-info_1_33", + 1 << 8 | 33, + (std::tuple { + { + .daemonNixVersion = std::optional { "foo" }, + }, + { + .daemonNixVersion = std::optional { "bar" }, + }, + })) + +VERSIONED_CHARACTERIZATION_TEST( + WorkerProtoTest, + clientHandshakeInfo_1_35, + "client-handshake-info_1_35", + 1 << 8 | 35, + (std::tuple { + { + .daemonNixVersion = std::optional { "foo" }, + .remoteTrustsUs = std::optional { NotTrusted }, + }, + { + .daemonNixVersion = std::optional { "bar" }, + .remoteTrustsUs = std::optional { Trusted }, + }, + })) + +TEST_F(WorkerProtoTest, handshake_log) +{ + CharacterizationTest::writeTest("handshake-to-client", [&]() -> std::string { + StringSink toClientLog; + + Pipe toClient, toServer; + toClient.create(); + toServer.create(); + + WorkerProto::Version clientResult; + + auto thread = std::thread([&]() { + FdSink out { toServer.writeSide.get() }; + FdSource in0 { toClient.readSide.get() }; + TeeSource in { in0, toClientLog }; + clientResult = WorkerProto::BasicClientConnection::handshake( + out, in, defaultVersion); + }); + + { + FdSink out { toClient.writeSide.get() }; + FdSource in { toServer.readSide.get() }; + WorkerProto::BasicServerConnection::handshake( + out, in, defaultVersion); + }; + + thread.join(); + + return std::move(toClientLog.s); + }); +} + +/// Has to be a `BufferedSink` for handshake. +struct NullBufferedSink : BufferedSink { + void writeUnbuffered(std::string_view data) override { } +}; + +TEST_F(WorkerProtoTest, handshake_client_replay) +{ + CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) { + NullBufferedSink nullSink; + + StringSource in { toClientLog }; + auto clientResult = WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion); + + EXPECT_EQ(clientResult, defaultVersion); + }); +} + +TEST_F(WorkerProtoTest, handshake_client_truncated_replay_throws) +{ + CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) { + for (size_t len = 0; len < toClientLog.size(); ++len) { + NullBufferedSink nullSink; + StringSource in { + // truncate + toClientLog.substr(0, len) + }; + if (len < 8) { + EXPECT_THROW( + WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion), + EndOfFile); + } else { + // Not sure why cannot keep on checking for `EndOfFile`. + EXPECT_THROW( + WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion), + Error); + } + } + }); +} + +TEST_F(WorkerProtoTest, handshake_client_corrupted_throws) +{ + CharacterizationTest::readTest("handshake-to-client", [&](const std::string toClientLog) { + for (size_t idx = 0; idx < toClientLog.size(); ++idx) { + // corrupt a copy + std::string toClientLogCorrupt = toClientLog; + toClientLogCorrupt[idx] *= 4; + ++toClientLogCorrupt[idx]; + + NullBufferedSink nullSink; + StringSource in { toClientLogCorrupt }; + + if (idx < 4 || idx == 9) { + // magic bytes don't match + EXPECT_THROW( + WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion), + Error); + } else if (idx < 8 || idx >= 12) { + // Number out of bounds + EXPECT_THROW( + WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion), + SerialisationError); + } else { + auto ver = WorkerProto::BasicClientConnection::handshake( + nullSink, in, defaultVersion); + // `std::min` of this and the other version saves us + EXPECT_EQ(ver, defaultVersion); + } + } + }); +} + }