Merge environments of nested functions

Previously an expression like 'x: y: ...' would create two
environments with one value each. Now it creates a single environment
with two values. This reduces the number of allocations and the length
of the environment chain that variable lookups need to traverse.

On

  $ nix-instantiate --dry-run '<nixpkgs/nixos/release-combined.nix>' -A nixos.tests.simple.x86_64-linux

this gives a ~30% reduction in the number of Env allocations.
This commit is contained in:
Eelco Dolstra 2021-11-05 13:05:03 +01:00
parent a1c1b0e553
commit 904d0ec5c0
9 changed files with 276 additions and 136 deletions

View File

@ -127,6 +127,7 @@ void printValue(std::ostream & str, std::set<const Value *> & active, const Valu
break;
case tThunk:
case tApp:
case tPartialApp:
str << "<CODE>";
break;
case tLambda:
@ -1275,35 +1276,28 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value &
}
};
while (nrArgs > 0) {
auto callLambda = [&](Env * env, ExprLambda & lambda, Value * * args)
{
Env & env2(allocEnv(lambda.envSize));
env2.up = env;
if (vCur.isLambda()) {
Displacement displ = 0;
ExprLambda & lambda(*vCur.lambda.fun);
for (auto & arg : lambda.args) {
auto vArg = *args++;
auto size =
(lambda.arg.empty() ? 0 : 1) +
(lambda.hasFormals() ? lambda.formals->formals.size() : 0);
Env & env2(allocEnv(size));
env2.up = vCur.lambda.env;
if (arg.arg != sEpsilon)
env2.values[displ++] = vArg;
Displacement displ = 0;
if (!lambda.hasFormals())
env2.values[displ++] = args[0];
else {
forceAttrs(*args[0], pos);
if (!lambda.arg.empty())
env2.values[displ++] = args[0];
if (arg.formals) {
forceAttrs(*vArg, pos);
/* For each formal argument, get the actual argument. If
there is no matching actual argument but the formal
argument has a default, use the default. */
size_t attrsUsed = 0;
for (auto & i : lambda.formals->formals) {
auto j = args[0]->attrs->get(i.name);
for (auto & i : arg.formals->formals) {
auto j = vArg->attrs->get(i.name);
if (!j) {
if (!i.def) throwTypeError(pos, "%1% called without required argument '%2%'",
lambda, i.name);
@ -1316,35 +1310,96 @@ void EvalState::callFunction(Value & fun, size_t nrArgs, Value * * args, Value &
/* Check that each actual argument is listed as a formal
argument (unless the attribute match specifies a `...'). */
if (!lambda.formals->ellipsis && attrsUsed != args[0]->attrs->size()) {
if (!arg.formals->ellipsis && attrsUsed != vArg->attrs->size()) {
/* Nope, so show the first unexpected argument to the
user. */
for (auto & i : *args[0]->attrs)
if (lambda.formals->argNames.find(i.name) == lambda.formals->argNames.end())
for (auto & i : *vArg->attrs)
if (arg.formals->argNames.find(i.name) == arg.formals->argNames.end())
throwTypeError(pos, "%1% called with unexpected argument '%2%'", lambda, i.name);
abort(); // can't happen
}
}
}
nrFunctionCalls++;
if (countCalls) incrFunctionCall(&lambda);
assert(displ == lambda.envSize);
/* Evaluate the body. */
try {
lambda.body->eval(*this, env2, vCur);
} catch (Error & e) {
if (loggerSettings.showTrace.get()) {
addErrorTrace(e, lambda.pos, "while evaluating %s",
(lambda.name.set()
? "'" + (string) lambda.name + "'"
: "anonymous lambda"));
addErrorTrace(e, pos, "from call site%s", "");
}
throw;
nrFunctionCalls++;
if (countCalls) incrFunctionCall(&lambda);
/* Evaluate the body. */
try {
lambda.body->eval(*this, env2, vCur);
} catch (Error & e) {
if (loggerSettings.showTrace) {
addErrorTrace(e, lambda.pos, "while evaluating %s",
(lambda.name.set()
? "'" + (string) lambda.name + "'"
: "anonymous lambda"));
addErrorTrace(e, pos, "from call site%s", "");
}
throw;
}
};
nrArgs--;
args += 1;
while (nrArgs > 0) {
if (vCur.isLambda()) {
ExprLambda & lambda(*vCur.lambda.fun);
if (nrArgs < lambda.args.size()) {
vRes = vCur;
for (size_t i = 0; i < nrArgs; ++i) {
auto fun2 = allocValue();
*fun2 = vRes;
vRes.mkPartialApp(fun2, args[i]);
}
return;
} else {
callLambda(vCur.lambda.env, lambda, args);
nrArgs -= lambda.args.size();
args += lambda.args.size();
}
}
else if (vCur.isPartialApp()) {
/* Figure out the number of arguments still needed. */
size_t argsDone = 0;
Value * lambda = &vCur;
while (lambda->isPartialApp()) {
argsDone++;
lambda = lambda->app.left;
}
assert(lambda->isLambda());
auto arity = lambda->lambda.fun->args.size();
auto argsLeft = arity - argsDone;
if (nrArgs < argsLeft) {
/* We still don't have enough arguments, so extend the tPartialApp chain. */
vRes = vCur;
for (size_t i = 0; i < nrArgs; ++i) {
auto fun2 = allocValue();
*fun2 = vRes;
vRes.mkPartialApp(fun2, args[i]);
}
return;
} else {
/* We have all the arguments, so call the function
with the previous and new arguments. */
Value * vArgs[arity];
auto n = argsDone;
for (Value * arg = &vCur; arg->isPartialApp(); arg = arg->app.left)
vArgs[--n] = arg->app.right;
for (size_t i = 0; i < argsLeft; ++i)
vArgs[argsDone + i] = args[i];
nrArgs -= argsLeft;
args += argsLeft;
callLambda(lambda->lambda.env, *lambda->lambda.fun, vArgs);
}
}
else if (vCur.isPrimOp()) {
@ -1458,42 +1513,48 @@ void EvalState::autoCallFunction(Bindings & args, Value & fun, Value & res)
}
}
if (!fun.isLambda() || !fun.lambda.fun->hasFormals()) {
if (!fun.isLambda()) {
res = fun;
return;
}
Value * actualArgs = allocValue();
mkAttrs(*actualArgs, std::max(static_cast<uint32_t>(fun.lambda.fun->formals->formals.size()), args.size()));
Value * actualArgs[fun.lambda.fun->args.size()];
if (fun.lambda.fun->formals->ellipsis) {
// If the formals have an ellipsis (eg the function accepts extra args) pass
// all available automatic arguments (which includes arguments specified on
// the command line via --arg/--argstr)
for (auto& v : args) {
actualArgs->attrs->push_back(v);
for (const auto & [i, arg] : enumerate(fun.lambda.fun->args)) {
if (!arg.formals) {
res = fun;
return;
}
} else {
// Otherwise, only pass the arguments that the function accepts
for (auto & i : fun.lambda.fun->formals->formals) {
Bindings::iterator j = args.find(i.name);
if (j != args.end()) {
actualArgs->attrs->push_back(*j);
} else if (!i.def) {
throwMissingArgumentError(i.pos, R"(cannot evaluate a function that has an argument without a value ('%1%')
actualArgs[i] = allocValue();
mkAttrs(*actualArgs[i], std::max(arg.formals->formals.size(), static_cast<size_t>(args.size())));
if (arg.formals->ellipsis) {
/* If the formals have an ellipsis (i.e. the function
accepts extra args), pass all available automatic
arguments. */
for (auto & v : args)
actualArgs[i]->attrs->push_back(v);
} else {
/* Otherwise, only pass the arguments that the function
accepts. */
for (auto & j : arg.formals->formals) {
if (auto attr = args.get(j.name))
actualArgs[i]->attrs->push_back(*attr);
else if (!j.def)
throwMissingArgumentError(j.pos, R"(cannot evaluate a function that has an argument without a value ('%1%')
Nix attempted to evaluate a function as a top level expression; in
this case it must have its arguments supplied either by default
values, or passed explicitly with '--arg' or '--argstr'. See
https://nixos.org/manual/nix/stable/#ss-functions.)", i.name);
https://nixos.org/manual/nix/stable/#ss-functions.)", j.name);
}
}
actualArgs[i]->attrs->sort();
}
actualArgs->attrs->sort();
callFunction(fun, *actualArgs, res, noPos);
callFunction(fun, fun.lambda.fun->args.size(), actualArgs, res, noPos);
}

View File

@ -230,8 +230,13 @@ static Flake getFlake(
if (auto outputs = vInfo.attrs->get(sOutputs)) {
expectType(state, nFunction, *outputs->value, *outputs->pos);
if (outputs->value->isLambda() && outputs->value->lambda.fun->hasFormals()) {
for (auto & formal : outputs->value->lambda.fun->formals->formals) {
if (outputs->value->lambda.fun->args.size() != 1)
throw Error("the 'outputs' attribute of flake '%s' is not a unary function", lockedRef);
auto & arg = outputs->value->lambda.fun->args[0];
if (arg.formals) {
for (auto & formal : arg.formals->formals) {
if (formal.name != state.sSelf)
flake.inputs.emplace(formal.name, FlakeInput {
.ref = parseFlakeRef(formal.name)

View File

@ -124,23 +124,26 @@ void ExprList::show(std::ostream & str) const
void ExprLambda::show(std::ostream & str) const
{
str << "(";
if (hasFormals()) {
str << "{ ";
bool first = true;
for (auto & i : formals->formals) {
if (first) first = false; else str << ", ";
str << i.name;
if (i.def) str << " ? " << *i.def;
for (auto & arg : args) {
if (arg.formals) {
str << "{ ";
bool first = true;
for (auto & i : arg.formals->formals) {
if (first) first = false; else str << ", ";
str << i.name;
if (i.def) str << " ? " << *i.def;
}
if (arg.formals->ellipsis) {
if (!first) str << ", ";
str << "...";
}
str << " }";
if (!arg.arg.empty()) str << " @ ";
}
if (formals->ellipsis) {
if (!first) str << ", ";
str << "...";
}
str << " }";
if (!arg.empty()) str << " @ ";
if (!arg.arg.empty()) str << arg.arg;
str << ": ";
}
if (!arg.empty()) str << arg;
str << ": " << *body << ")";
str << *body << ")";
}
void ExprCall::show(std::ostream & str) const
@ -279,8 +282,7 @@ void ExprVar::bindVars(const StaticEnv & env)
if (curEnv->isWith) {
if (withLevel == -1) withLevel = level;
} else {
auto i = curEnv->find(name);
if (i != curEnv->vars.end()) {
if (auto i = curEnv->get(name)) {
fromWith = false;
this->level = level;
displ = i->second;
@ -354,25 +356,48 @@ void ExprList::bindVars(const StaticEnv & env)
void ExprLambda::bindVars(const StaticEnv & env)
{
StaticEnv newEnv(
false, &env,
(hasFormals() ? formals->formals.size() : 0) +
(arg.empty() ? 0 : 1));
/* The parser adds arguments in reverse order. Let's fix that
now. */
std::reverse(args.begin(), args.end());
envSize = 0;
for (auto & arg :args) {
if (!arg.arg.empty()) envSize++;
if (arg.formals) envSize += arg.formals->formals.size();
}
StaticEnv newEnv(false, &env, envSize);
Displacement displ = 0;
if (!arg.empty()) newEnv.vars.emplace_back(arg, displ++);
for (auto & arg : args) {
if (!arg.arg.empty()) {
if (auto i = const_cast<StaticEnv::Vars::value_type *>(newEnv.get(arg.arg)))
i->second = displ++;
else
newEnv.vars.emplace_back(arg.arg, displ++);
}
if (hasFormals()) {
for (auto & i : formals->formals)
newEnv.vars.emplace_back(i.name, displ++);
if (arg.formals) {
for (auto & i : arg.formals->formals) {
if (auto j = const_cast<StaticEnv::Vars::value_type *>(newEnv.get(i.name)))
j->second = displ++;
else
newEnv.vars.emplace_back(i.name, displ++);
}
newEnv.sort();
newEnv.sort();
for (auto & i : formals->formals)
if (i.def) i.def->bindVars(newEnv);
for (auto & i : arg.formals->formals)
if (i.def) i.def->bindVars(newEnv);
}
}
assert(displ == envSize);
newEnv.sort();
body->bindVars(newEnv);
}

View File

@ -233,21 +233,24 @@ struct ExprLambda : Expr
{
Pos pos;
Symbol name;
Symbol arg;
Formals * formals;
Expr * body;
ExprLambda(const Pos & pos, const Symbol & arg, Formals * formals, Expr * body)
: pos(pos), arg(arg), formals(formals), body(body)
struct Arg
{
if (!arg.empty() && formals && formals->argNames.find(arg) != formals->argNames.end())
throw ParseError({
.msg = hintfmt("duplicate formal function argument '%1%'", arg),
.errPos = pos
});
Symbol arg;
Formals * formals;
};
std::vector<Arg> args;
Expr * body;
Displacement envSize = 0; // initialized by bindVars()
ExprLambda(const Pos & pos, Expr * body)
: pos(pos), body(body)
{ };
void setName(Symbol & name);
string showNamePos() const;
inline bool hasFormals() const { return formals != nullptr; }
COMMON_METHODS
};
@ -368,12 +371,12 @@ struct StaticEnv
[](const Vars::value_type & a, const Vars::value_type & b) { return a.first < b.first; });
}
Vars::const_iterator find(const Symbol & name) const
const Vars::value_type * get(const Symbol & name) const
{
Vars::value_type key(name, 0);
auto i = std::lower_bound(vars.begin(), vars.end(), key);
if (i != vars.end() && i->first == name) return i;
return vars.end();
if (i != vars.end() && i->first == name) return &*i;
return {};
}
};

View File

@ -160,6 +160,24 @@ static void addFormal(const Pos & pos, Formals * formals, const Formal & formal)
}
/* Attach the function argument `arg` to the expression `e` and return
   the resulting lambda.

   If `e` is already an ExprLambda, the argument is appended to that
   lambda's `args` vector instead of wrapping it in a new lambda, so a
   curried function such as `x: y: body` ends up as one ExprLambda
   with several arguments. Otherwise a new ExprLambda with body `e` is
   created. The grammar rules call this with the already-parsed inner
   expression_function, so arguments are appended innermost-first;
   ExprLambda::bindVars() notes that the parser adds arguments in
   reverse order and reverses them.

   Throws ParseError if `arg` has both a name and formals and the name
   also occurs among the formals (e.g. `x @ { x }: ...`). */
static Expr * addArg(const Pos & pos, Expr * e, ExprLambda::Arg && arg)
{
/* The name bound to the whole argument must not clash with one of
   the argument's own formals. */
if (!arg.arg.empty() && arg.formals && arg.formals->argNames.count(arg.arg))
throw ParseError({
.msg = hintfmt("duplicate formal function argument '%1%'", arg.arg),
.errPos = pos
});
/* Reuse `e` if it is a lambda; otherwise start a new lambda around
   it. */
auto e2 = dynamic_cast<ExprLambda *>(e); // FIXME: slow?
if (!e2)
e2 = new ExprLambda(pos, e);
else
/* The merged lambda takes the position passed for this (outer)
   argument. */
e2->pos = pos;
e2->args.emplace_back(std::move(arg));
return e2;
}
static Expr * stripIndentation(const Pos & pos, SymbolTable & symbols, vector<Expr *> & es)
{
if (es.empty()) return new ExprString(symbols.create(""));
@ -332,13 +350,13 @@ expr: expr_function;
expr_function
: ID ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create($1), 0, $3); }
{ $$ = addArg(CUR_POS, $3, {data->symbols.create($1), nullptr}); }
| '{' formals '}' ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create(""), $2, $5); }
{ $$ = addArg(CUR_POS, $5, {data->state.sEpsilon, $2}); }
| '{' formals '}' '@' ID ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create($5), $2, $7); }
{ $$ = addArg(CUR_POS, $7, {data->symbols.create($5), $2}); }
| ID '@' '{' formals '}' ':' expr_function
{ $$ = new ExprLambda(CUR_POS, data->symbols.create($1), $4, $7); }
{ $$ = addArg(CUR_POS, $7, {data->symbols.create($1), $4}); }
| ASSERT expr ';' expr_function
{ $$ = new ExprAssert(CUR_POS, $2, $4); }
| WITH expr ';' expr_function
@ -456,7 +474,7 @@ expr_simple
string_parts
: STR
| string_parts_interpolated { $$ = new ExprConcatStrings(CUR_POS, true, $1); }
| { $$ = new ExprString(data->symbols.create("")); }
| { $$ = new ExprString(data->state.sEpsilon); }
;
string_parts_interpolated

View File

@ -2386,23 +2386,38 @@ static RegisterPrimOp primop_catAttrs({
static void prim_functionArgs(EvalState & state, const Pos & pos, Value * * args, Value & v)
{
state.forceValue(*args[0], pos);
if (args[0]->isPrimOpApp() || args[0]->isPrimOp()) {
state.mkAttrs(v, 0);
return;
}
if (!args[0]->isLambda())
if (!args[0]->isLambda() && !args[0]->isPartialApp())
throw TypeError({
.msg = hintfmt("'functionArgs' requires a function"),
.errPos = pos
});
if (!args[0]->lambda.fun->hasFormals()) {
size_t argsDone = 0;
auto lambda = args[0];
while (lambda->isPartialApp()) {
argsDone++;
lambda = lambda->app.left;
}
assert(lambda->isLambda());
assert(argsDone < lambda->lambda.fun->args.size());
// FIXME: handle partially applied functions
auto formals = lambda->lambda.fun->args[argsDone].formals;
if (!formals) {
state.mkAttrs(v, 0);
return;
}
state.mkAttrs(v, args[0]->lambda.fun->formals->formals.size());
for (auto & i : args[0]->lambda.fun->formals->formals) {
state.mkAttrs(v, formals->formals.size());
for (auto & i : formals->formals) {
// !!! should optimise booleans (allocate only once)
Value * value = state.allocValue();
v.attrs->push_back(Attr(i.name, value, ptr(&i.pos)));

View File

@ -126,24 +126,28 @@ static void printValueAsXML(EvalState & state, bool strict, bool location,
}
case nFunction: {
if (!v.isLambda()) {
// FIXME: Serialize primops and primopapps
// FIXME: Serialize primops and partial apps
doc.writeEmptyElement("unevaluated");
break;
}
XMLAttrs xmlAttrs;
if (location) posToXML(xmlAttrs, v.lambda.fun->pos);
XMLOpenElement _(doc, "function", xmlAttrs);
if (v.lambda.fun->hasFormals()) {
auto & arg = v.lambda.fun->args[0];
if (arg.formals) {
XMLAttrs attrs;
if (!v.lambda.fun->arg.empty()) attrs["name"] = v.lambda.fun->arg;
if (v.lambda.fun->formals->ellipsis) attrs["ellipsis"] = "1";
if (arg.arg != state.sEpsilon) attrs["name"] = arg.arg;
if (arg.formals->ellipsis) attrs["ellipsis"] = "1";
XMLOpenElement _(doc, "attrspat", attrs);
for (auto & i : v.lambda.fun->formals->formals)
for (auto & i : arg.formals->formals)
doc.writeEmptyElement("attr", singletonAttrs("name", i.name));
} else
doc.writeEmptyElement("varpat", singletonAttrs("name", v.lambda.fun->arg));
doc.writeEmptyElement("varpat", singletonAttrs("name", arg.arg));
break;
}

View File

@ -21,6 +21,7 @@ typedef enum {
tListN,
tThunk,
tApp,
tPartialApp,
tLambda,
tBlackhole,
tPrimOp,
@ -125,6 +126,7 @@ public:
// type() == nFunction
inline bool isLambda() const { return internalType == tLambda; };
inline bool isPartialApp() const { return internalType == tPartialApp; };
inline bool isPrimOp() const { return internalType == tPrimOp; };
inline bool isPrimOpApp() const { return internalType == tPrimOpApp; };
@ -196,7 +198,7 @@ public:
case tNull: return nNull;
case tAttrs: return nAttrs;
case tList1: case tList2: case tListN: return nList;
case tLambda: case tPrimOp: case tPrimOpApp: return nFunction;
case tLambda: case tPartialApp: case tPrimOp: case tPrimOpApp: return nFunction;
case tExternal: return nExternal;
case tFloat: return nFloat;
case tThunk: case tApp: case tBlackhole: return nThunk;
@ -307,6 +309,13 @@ public:
app.right = r;
}
/* Turn this value into a partial application (tPartialApp) with `l`
   stored in app.left and `r` in app.right.
   In EvalState::callFunction(), `l` is the function value being
   applied (a lambda or another partial application) and `r` is the
   argument supplied to it; the underlying lambda is found by
   following app.left.
   NOTE(review): unlike mkExternal() below, this does not call
   clearValue() first -- presumably fine because both union fields
   are overwritten; confirm. */
inline void mkPartialApp(Value * l, Value * r)
{
internalType = tPartialApp;
app.left = l;
app.right = r;
}
inline void mkExternal(ExternalValueBase * e)
{
clearValue();

View File

@ -355,14 +355,12 @@ struct CmdFlakeCheck : FlakeCommand
try {
state->forceValue(v, pos);
if (!v.isLambda()
|| v.lambda.fun->hasFormals()
|| !argHasName(v.lambda.fun->arg, "final"))
throw Error("overlay does not take an argument named 'final'");
auto body = dynamic_cast<ExprLambda *>(v.lambda.fun->body);
if (!body
|| body->hasFormals()
|| !argHasName(body->arg, "prev"))
throw Error("overlay does not take an argument named 'prev'");
|| v.lambda.fun->args.size() != 2
|| v.lambda.fun->args[0].formals
|| !argHasName(v.lambda.fun->args[0].arg, "final")
|| v.lambda.fun->args[1].formals
|| !argHasName(v.lambda.fun->args[1].arg, "prev"))
throw Error("overlay is not a binary function with arguments 'final' and 'prev'");
// FIXME: if we have a 'nixpkgs' input, use it to
// evaluate the overlay.
} catch (Error & e) {
@ -375,7 +373,9 @@ struct CmdFlakeCheck : FlakeCommand
try {
state->forceValue(v, pos);
if (v.isLambda()) {
if (!v.lambda.fun->hasFormals() || !v.lambda.fun->formals->ellipsis)
if (v.lambda.fun->args.size() != 1
|| !v.lambda.fun->args[0].formals
|| !v.lambda.fun->args[0].formals->ellipsis)
throw Error("module must match an open attribute set ('{ config, ... }')");
} else if (v.type() == nAttrs) {
for (auto & attr : *v.attrs)
@ -473,12 +473,12 @@ struct CmdFlakeCheck : FlakeCommand
auto checkBundler = [&](const std::string & attrPath, Value & v, const Pos & pos) {
try {
state->forceValue(v, pos);
if (!v.isLambda())
throw Error("bundler must be a function");
if (!v.lambda.fun->formals ||
!v.lambda.fun->formals->argNames.count(state->symbols.create("program")) ||
!v.lambda.fun->formals->argNames.count(state->symbols.create("system")))
throw Error("bundler must take formal arguments 'program' and 'system'");
if (!v.isLambda()
|| v.lambda.fun->args.size() != 1
|| !v.lambda.fun->args[0].formals
|| !v.lambda.fun->args[0].formals->argNames.count(state->symbols.create("program"))
|| !v.lambda.fun->args[0].formals->argNames.count(state->symbols.create("system")))
throw Error("bundler must be a function that takes take arguments 'program' and 'system'");
} catch (Error & e) {
e.addTrace(pos, hintfmt("while checking the template '%s'", attrPath));
reportError(e);