From 33bfd251b4dd9eebdf8b254adb6288bbf8d31801 Mon Sep 17 00:00:00 2001 From: dmiller Date: Mon, 11 Nov 2024 21:07:00 +0000 Subject: [PATCH] Macro for common idiom in nse_libssh2 --- nse_libssh2.cc | 142 +++++++++++++++---------------------------------- 1 file changed, 44 insertions(+), 98 deletions(-) diff --git a/nse_libssh2.cc b/nse_libssh2.cc index 10e960968..0584b7f7e 100644 --- a/nse_libssh2.cc +++ b/nse_libssh2.cc @@ -277,6 +277,13 @@ static int filter (lua_State *L) { return finish_read(L, 0, 0); } +#define DO_OR_YIELD(_Stmt, _Sshu_index, _Func, _Ctx) \ + while ((_Stmt) == LIBSSH2_ERROR_EAGAIN) { \ + luaL_getmetafield(L, (_Sshu_index), "filter"); \ + lua_pushvalue(L, (_Sshu_index)); \ + lua_callk(L, 1, 0, (_Ctx), (_Func)); \ + } + static int do_session_handshake (lua_State *L, int status, lua_KContext ctx) { int rc; struct ssh_userdata *sshu = NULL; @@ -284,13 +291,8 @@ static int do_session_handshake (lua_State *L, int status, lua_KContext ctx) { assert(lua_gettop(L) == 4); sshu = (struct ssh_userdata *) nseU_checkudata(L, 3, SSH2_UDATA, "ssh2"); - while ((rc = libssh2_session_handshake(sshu->session, sshu->sp[0])) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 3, "filter"); - lua_pushvalue(L, 3); - - assert(lua_status(L) == LUA_OK); - lua_callk(L, 1, 0, 0, do_session_handshake); - } + DO_OR_YIELD((rc = libssh2_session_handshake(sshu->session, sshu->sp[0])), + 3, do_session_handshake, ctx); if (rc) { libssh2_session_free(sshu->session); @@ -447,14 +449,9 @@ static int userauth_list (lua_State *L, int status, lua_KContext ctx) { state = (struct ssh_userdata *) nseU_checkudata(L, 1, SSH2_UDATA, "ssh2"); assert(state->session != NULL); - while ((auth_list = libssh2_userauth_list(state->session, username, lua_rawlen(L, 2))) == NULL - && libssh2_session_last_errno(state->session) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - - assert(lua_status(L) == LUA_OK); - lua_callk(L, 1, 0, 0, userauth_list); - } + DO_OR_YIELD(((auth_list = libssh2_userauth_list(state->session, username, lua_rawlen(L, 2))) == NULL ? + libssh2_session_last_errno(state->session) : LIBSSH2_ERROR_NONE), + 1, userauth_list, ctx); if (auth_list) { const char *auth = strtok(auth_list, ","); @@ -530,16 +527,11 @@ static void validate_publickey_params(lua_State *L, struct publickey_ctx *ctx) { static int userauth_publickey (lua_State *L, int status, lua_KContext ctx) { struct publickey_ctx *context = (struct publickey_ctx *)ctx; int rc; - while ((rc = libssh2_userauth_publickey_fromfile_ex( + DO_OR_YIELD((rc = libssh2_userauth_publickey_fromfile_ex( context->state->session, context->username, context->username_len, context->pubkey, context->privkey, context->passphrase - )) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - - assert(lua_status(L) == LUA_OK); - lua_callk(L, 1, 0, ctx, userauth_publickey); - } + )), + 1, userauth_publickey, ctx); lua_pushboolean(L, (rc == 0)); @@ -555,17 +547,12 @@ static int l_userauth_publickey (lua_State *L) { static int userauth_publickey_frommemory (lua_State *L, int status, lua_KContext ctx) { struct publickey_ctx *context = (struct publickey_ctx *)ctx; int rc; - while ((rc = libssh2_userauth_publickey_frommemory( + DO_OR_YIELD((rc = libssh2_userauth_publickey_frommemory( context->state->session, context->username, context->username_len, context->pubkey, context->pubkey_len, context->privkey, context->privkey_len, context->passphrase - )) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - - assert(lua_status(L) == LUA_OK); - lua_callk(L, 1, 0, ctx, userauth_publickey_frommemory); - } + )), + 1, userauth_publickey_frommemory, ctx); lua_pushboolean(L, (rc == 0)); @@ -630,14 +617,9 @@ static int publickey_canauth (lua_State *L, int status, lua_KContext ctx) { else return luaL_error(L, "Invalid public key"); - while ((rc = libssh2_userauth_publickey(state->session, - username, publickey_data, len, &publickey_canauth_cb, NULL)) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - - assert(lua_status(L) == LUA_OK); - lua_callk(L, 1, 0, 0, publickey_canauth); - } + DO_OR_YIELD((rc = libssh2_userauth_publickey(state->session, + username, publickey_data, len, &publickey_canauth_cb, NULL)), + 1, publickey_canauth, ctx); libssh2_session_last_error(state->session, &errmsg, NULL, 0); @@ -675,14 +657,8 @@ static int userauth_password (lua_State *L, int status, lua_KContext ctx) { password = luaL_checkstring(L, 3); assert(state->session != NULL); - while ((rc = libssh2_userauth_password(state->session, - username, password)) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - - assert(lua_status(L) == LUA_OK); - lua_callk(L, 1, 0, 0, userauth_password); - } + DO_OR_YIELD((rc = libssh2_userauth_password(state->session, username, password)), + 1, userauth_password, ctx); if (rc == 0) lua_pushboolean(L, 1); @@ -703,14 +679,8 @@ static int session_close (lua_State *L, int status, lua_KContext ctx) { state = (struct ssh_userdata *) nseU_checkudata(L, 1, SSH2_UDATA, "ssh2"); if (state->session != NULL) { - while ((rc = libssh2_session_disconnect( - state->session, "Normal Shutdown")) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - - assert(lua_status(L) == LUA_OK); - lua_callk(L, 1, 0, 0, session_close); - } + DO_OR_YIELD((rc = libssh2_session_disconnect(state->session, "Normal Shutdown")), + 1, session_close, ctx); if (rc < 0) return luaL_error(L, "unable to disconnect session"); @@ -735,11 +705,8 @@ static int channel_read (lua_State *L, int status, lua_KContext ctx) { LIBSSH2_CHANNEL **channel = (LIBSSH2_CHANNEL **) lua_touserdata(L, 2); int stream_id = luaL_checkinteger(L, 3); - while ((rc = libssh2_channel_read_ex(*channel, stream_id, buf, buflen)) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, channel_read); - } + DO_OR_YIELD((rc = libssh2_channel_read_ex(*channel, stream_id, buf, buflen)), + 1, channel_read, ctx); if (rc > 0) { lua_pushlstring(L, buf, rc); @@ -774,11 +741,8 @@ static int channel_write (lua_State *L, int status, lua_KContext ctx) { else return luaL_error(L, "Invalid buffer"); - while ((rc = libssh2_channel_write(*channel, buf, buflen)) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, channel_write); - } + DO_OR_YIELD((rc = libssh2_channel_write(*channel, buf, buflen)), + 1, channel_write, ctx); if (rc < 0) return luaL_error(L, "Writing to channel"); @@ -797,11 +761,8 @@ static int channel_exec (lua_State *L, int status, lua_KContext ctx) { LIBSSH2_CHANNEL **channel = (LIBSSH2_CHANNEL **) lua_touserdata(L, 2); const char *cmd = luaL_checkstring(L, 3); - while ((rc = libssh2_channel_exec(*channel, cmd)) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, channel_exec); - } + DO_OR_YIELD((rc = libssh2_channel_exec(*channel, cmd)), + 1, channel_exec, ctx); if (rc != 0) return luaL_error(L, "Error executing command"); @@ -830,11 +791,8 @@ static int channel_send_eof(lua_State *L, int status, lua_KContext ctx) { // ssh_userdata *state = (ssh_userdata *)lua_touserdata(L, 1); LIBSSH2_CHANNEL **channel = (LIBSSH2_CHANNEL **) lua_touserdata(L, 2); - while ((rc = libssh2_channel_send_eof(*channel)) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, channel_send_eof); - } + DO_OR_YIELD((rc = libssh2_channel_send_eof(*channel)), + 1, channel_send_eof, ctx); if (rc != 0) return luaL_error(L, "Error sending EOF"); @@ -850,11 +808,8 @@ static int setup_channel(lua_State *L, int status, lua_KContext ctx) { // ssh_userdata *state = (ssh_userdata *)lua_touserdata(L, 1); LIBSSH2_CHANNEL **channel = (LIBSSH2_CHANNEL **) lua_touserdata(L, 2); - while ((rc = libssh2_channel_request_pty(*channel, "vanilla")) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, setup_channel); - } + DO_OR_YIELD((rc = libssh2_channel_request_pty(*channel, "vanilla")), + 1, setup_channel, ctx); if (rc != 0) return luaL_error(L, "Requesting pty"); @@ -869,13 +824,10 @@ static int finish_open_channel (lua_State *L, int status, lua_KContext ctx) { ssh_userdata *state = (ssh_userdata *)lua_touserdata(L, 1); LIBSSH2_CHANNEL **channel = (LIBSSH2_CHANNEL **) lua_touserdata(L, 2); - while ((*channel = libssh2_channel_open_session(state->session)) == NULL - && libssh2_session_last_errno(state->session) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, finish_open_channel); - } - if (channel == NULL) + DO_OR_YIELD(((*channel = libssh2_channel_open_session(state->session)) == NULL ? + libssh2_session_last_errno(state->session) : LIBSSH2_ERROR_NONE), + 1, finish_open_channel, ctx); + if (*channel == NULL) return luaL_error(L, "Opening channel"); return setup_channel(L, 0, 0); @@ -885,12 +837,9 @@ static int l_open_channel (lua_State *L) { ssh_userdata *state = (ssh_userdata *)lua_touserdata(L, 1); LIBSSH2_CHANNEL **channel = (LIBSSH2_CHANNEL **)lua_newuserdatauv(L, sizeof(LIBSSH2_CHANNEL *), 0); - while ((*channel = libssh2_channel_open_session(state->session)) == NULL - && libssh2_session_last_errno(state->session) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, finish_open_channel); - } + DO_OR_YIELD(((*channel = libssh2_channel_open_session(state->session)) == NULL ? + libssh2_session_last_errno(state->session) : LIBSSH2_ERROR_NONE), + 1, finish_open_channel, 0); return l_setup_channel(L); } @@ -900,11 +849,8 @@ static int channel_close (lua_State *L, int status, lua_KContext ctx) { // ssh_userdata *state = (ssh_userdata *)lua_touserdata(L, 1); LIBSSH2_CHANNEL **channel = (LIBSSH2_CHANNEL **) lua_touserdata(L, 2); - while ((rc = libssh2_channel_close(*channel)) == LIBSSH2_ERROR_EAGAIN) { - luaL_getmetafield(L, 1, "filter"); - lua_pushvalue(L, 1); - lua_callk(L, 1, 0, 0, channel_close); - } + DO_OR_YIELD((rc = libssh2_channel_close(*channel)), + 1, channel_close, ctx); if (rc != 0) return luaL_error(L, "Error closing channel");;