diff --git a/nse_init.cc b/nse_init.cc index 49cb863a3..e9ce04451 100644 --- a/nse_init.cc +++ b/nse_init.cc @@ -26,6 +26,8 @@ extern NmapOps o; extern int current_hosts; extern int errfunc; +#define REQUIRE_ERRORS "require_error" + /* int error_function (lua_State *L) * * Arguments: @@ -109,11 +111,26 @@ static int loadfile (lua_State *L) lua_setfenv(L, -2); // set it if (lua_pcall(L, 0, 0, 0) != 0) // Call the function (loads globals) { - error("%s: '%s' threw a run time error and could not be loaded.", - SCRIPT_ENGINE, filename); - SCRIPT_ENGINE_DEBUGGING( - error("%s", lua_tostring(L, -1)); - ) + // Check for dependency errors + lua_getfield(L, LUA_REGISTRYINDEX, REQUIRE_ERRORS); + lua_pushvalue(L, -2); // the error + lua_gettable(L, -2); + if (lua_toboolean(L, -1)) // The error was thrown by require + { + if (o.verbose > 3 && !o.debugging) + error("%s: '%s' could not be loaded due to missing dependency '%s'", + SCRIPT_ENGINE, filename, lua_tostring(L, -1)); + SCRIPT_ENGINE_DEBUGGING( + error("%s: '%s' threw a run time error and could not be loaded.\n%s", + SCRIPT_ENGINE, filename, lua_tostring(L, -3)); + ) + } else { + error("%s: '%s' threw a run time error and could not be loaded.", + SCRIPT_ENGINE, filename); + SCRIPT_ENGINE_DEBUGGING( + error("%s", lua_tostring(L, -3)); + ) + } return 0; } @@ -231,6 +248,30 @@ static int init_setpath (lua_State *L) return 0; } +/* int nse_require (lua_State *L) + * + * This hooks the standard require function to allow us to properly catch + * dependency errors. Basically an error message is saved in the error table + * (upvalue 1) that can be indexed later to check if it was unhandled by + * the script (see loadfile in particular). + */ +static int nse_require (lua_State *L) +{ + luaL_checkstring(L, 1); // ensure first argument is a string + lua_pushvalue(L, 1); + lua_insert(L, 1); // save a copy of the library name at stack bottom + lua_pushvalue(L, lua_upvalueindex(1)); // require function + lua_insert(L, 2); + if (lua_pcall(L, lua_gettop(L)-2, LUA_MULTRET, 0) != 0) + { + lua_pushvalue(L, -1); // the error message + lua_pushvalue(L, 1); // the library name that caused the error + lua_settable(L, lua_upvalueindex(2)); + return lua_error(L); + } + return lua_gettop(L)-1; // omit the saved first argument +} + /* int init_lua (lua_State *L) * * Initializes the Lua State. @@ -284,6 +325,17 @@ int init_lua (lua_State *L) lua_newtable(L); current_hosts = luaL_ref(L, LUA_REGISTRYINDEX); + lua_newtable(L); // nse_require error table + lua_createtable(L, 0, 1); // metatable + lua_pushliteral(L, "k"); + lua_setfield(L, -2, "__mode"); // weak keys + lua_setmetatable(L, -2); + lua_getglobal(L, "require"); + lua_pushvalue(L, -2); // nse_require error table + lua_pushcclosure(L, nse_require, 2); + lua_setglobal(L, "require"); + lua_setfield(L, LUA_REGISTRYINDEX, REQUIRE_ERRORS); // save nse_require table + return 0; }