diff --git a/nse_main.cc b/nse_main.cc index 4a2f43a3d..d10177931 100644 --- a/nse_main.cc +++ b/nse_main.cc @@ -1,11 +1,5 @@ #include "nse_main.h" -extern "C" { - #include "lua.h" - #include "lualib.h" - #include "lauxlib.h" -} - #include "nse_init.h" #include "nse_nsock.h" #include "nse_nmaplib.h" diff --git a/nse_main.h b/nse_main.h index 00390525a..8b0247c8d 100644 --- a/nse_main.h +++ b/nse_main.h @@ -1,6 +1,12 @@ #ifndef NMAP_LUA_H #define NMAP_LUA_H +extern "C" { + #include "lua.h" + #include "lualib.h" + #include "lauxlib.h" +} + #include #include #include @@ -20,5 +26,7 @@ int script_updatedb(); //parses the arguments provided to scripts via nmap's --script-args option int script_check_args(); + +int process_waiting2running(lua_State *, int); #endif diff --git a/nse_nmaplib.cc b/nse_nmaplib.cc index df71cc5f7..52dddb787 100644 --- a/nse_nmaplib.cc +++ b/nse_nmaplib.cc @@ -2,6 +2,7 @@ #include "nse_nsock.h" #include "nse_macros.h" #include "nse_debug.h" +#include "nse_main.h" #include "nmap.h" #include "nmap_error.h" @@ -45,6 +46,72 @@ static int l_get_timing_level(lua_State *L); int l_clock_ms(lua_State *L); +static int mutex_i; + +static int aux_mutex (lua_State *L) +{ + static const char * op[] = {"lock", "done", "trylock", "running", NULL}; + switch (luaL_checkoption(L, 1, NULL, op)) + { + case 0: // lock + if (lua_isnil(L, lua_upvalueindex(2))) // check running + { + lua_pushthread(L); + lua_replace(L, lua_upvalueindex(2)); // set running + return 0; + } + lua_pushthread(L); + lua_rawseti(L, lua_upvalueindex(1), lua_objlen(L, lua_upvalueindex(1))+1); + return lua_yield(L, 0); + case 1: // done + lua_pushthread(L); + if (!lua_equal(L, -1, lua_upvalueindex(2))) + luaL_error(L, "%s", "Do not have a lock on this mutex"); + lua_getfield(L, LUA_REGISTRYINDEX, "table.remove"); + lua_pushvalue(L, lua_upvalueindex(1)); + lua_pushinteger(L, 1); + lua_call(L, 2, 1); + lua_replace(L, lua_upvalueindex(2)); + if (!lua_isnil(L, lua_upvalueindex(2))) // waiting threads had a thread + process_waiting2running(lua_tothread(L, lua_upvalueindex(2)), 0); + return 0; + case 2: // trylock + if (lua_isnil(L, lua_upvalueindex(2))) + { + lua_pushthread(L); + lua_replace(L, lua_upvalueindex(2)); + lua_pushboolean(L, true); + } + else + lua_pushboolean(L, false); + return 1; + case 3: // running + lua_pushvalue(L, lua_upvalueindex(2)); + return 1; + } + return 0; +} + +static int l_mutex (lua_State *L) +{ + int t = lua_type(L, 1); + if (t == LUA_TNONE || t == LUA_TNIL || t == LUA_TBOOLEAN || t == LUA_TNUMBER) + luaL_argerror(L, 1, "Object expected"); + lua_rawgeti(L, LUA_REGISTRYINDEX, mutex_i); + lua_pushvalue(L, 1); + lua_gettable(L, -2); + if (lua_isnil(L, -1)) + { + lua_newtable(L); // waiting threads + lua_pushnil(L); // running thread + lua_pushcclosure(L, aux_mutex, 2); + lua_pushvalue(L, 1); // "mutex object" + lua_pushvalue(L, -2); // function + lua_settable(L, -5); // Add to mutex table + } + return 1; // aux_mutex closure +} + int luaopen_nmap (lua_State *L) { static luaL_reg nmaplib [] = { @@ -62,6 +129,7 @@ int luaopen_nmap (lua_State *L) {"have_ssl", l_get_have_ssl}, {"fetchfile", l_fetchfile}, {"timing_level", l_get_timing_level}, + {"mutex", l_mutex}, {NULL, NULL} }; @@ -70,6 +138,13 @@ int luaopen_nmap (lua_State *L) lua_newtable(L); lua_setfield(L, -2, "registry"); + lua_newtable(L); + lua_createtable(L, 0, 1); + lua_pushliteral(L, "v"); + lua_setfield(L, -2, "__mode"); + lua_setmetatable(L, -2); // Allow closures to be collected (see l_mutex) + mutex_i = luaL_ref(L, LUA_REGISTRYINDEX); + SCRIPT_ENGINE_TRY(l_nsock_open(L)); SCRIPT_ENGINE_TRY(l_dnet_open(L));