Skip to content

Commit 6f0aa32

Browse files
committed
move address resolution into lua
1 parent 93348e2 commit 6f0aa32

File tree

6 files changed

+173
-55
lines changed

6 files changed

+173
-55
lines changed

src/main.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <fcntl.h>
77
#include <getopt.h>
88
#include <math.h>
9-
#include <netdb.h>
109
#include <netinet/in.h>
1110
#include <netinet/tcp.h>
1211
#include <stdarg.h>

src/script.c

Lines changed: 127 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,28 @@
44
#include <string.h>
55
#include "script.h"
66
#include "http_parser.h"
7+
#include "zmalloc.h"
78

89
typedef struct {
910
char *name;
1011
int type;
1112
void *value;
1213
} table_field;
1314

15+
static int script_addr_tostring(lua_State *);
16+
static int script_addr_gc(lua_State *);
1417
static int script_stats_len(lua_State *);
1518
static int script_stats_get(lua_State *);
19+
static int script_wrk_lookup(lua_State *);
20+
static int script_wrk_connect(lua_State *);
1621
static void set_fields(lua_State *, int index, const table_field *);
1722

23+
static const struct luaL_reg addrlib[] = {
24+
{ "__tostring", script_addr_tostring },
25+
{ "__gc" , script_addr_gc },
26+
{ NULL, NULL }
27+
};
28+
1829
static const struct luaL_reg statslib[] = {
1930
{ "__index", script_stats_get },
2031
{ "__len", script_stats_len },
@@ -26,16 +37,22 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) {
2637
luaL_openlibs(L);
2738
luaL_dostring(L, "wrk = require \"wrk\"");
2839

40+
luaL_newmetatable(L, "wrk.addr");
41+
luaL_register(L, NULL, addrlib);
42+
lua_pop(L, 1);
43+
2944
luaL_newmetatable(L, "wrk.stats");
3045
luaL_register(L, NULL, statslib);
3146
lua_pop(L, 1);
3247

3348
const table_field fields[] = {
34-
{ "scheme", LUA_TSTRING, scheme },
35-
{ "host", LUA_TSTRING, host },
36-
{ "port", LUA_TSTRING, port },
37-
{ "path", LUA_TSTRING, path },
38-
{ NULL, 0, NULL },
49+
{ "scheme", LUA_TSTRING, scheme },
50+
{ "host", LUA_TSTRING, host },
51+
{ "port", LUA_TSTRING, port },
52+
{ "path", LUA_TSTRING, path },
53+
{ "lookup", LUA_TFUNCTION, script_wrk_lookup },
54+
{ "connect", LUA_TFUNCTION, script_wrk_connect },
55+
{ NULL, 0, NULL },
3956
};
4057

4158
lua_getglobal(L, "wrk");
@@ -45,6 +62,36 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) {
4562
return L;
4663
}
4764

65+
void script_prepare_setup(lua_State *L, char *script) {
66+
if (script && luaL_dofile(L, script)) {
67+
const char *cause = lua_tostring(L, -1);
68+
fprintf(stderr, "%s: %s\n", script, cause);
69+
}
70+
}
71+
72+
bool script_resolve(lua_State *L, char *host, char *service) {
73+
lua_getglobal(L, "wrk");
74+
75+
lua_getfield(L, -1, "resolve");
76+
lua_pushstring(L, host);
77+
lua_pushstring(L, service);
78+
lua_call(L, 2, 0);
79+
80+
lua_getfield(L, -1, "addrs");
81+
size_t count = lua_objlen(L, -1);
82+
lua_pop(L, 2);
83+
return count > 0;
84+
}
85+
86+
struct addrinfo *script_peek_addr(lua_State *L) {
87+
lua_getglobal(L, "wrk");
88+
lua_getfield(L, -1, "addrs");
89+
lua_rawgeti(L, -1, 1);
90+
struct addrinfo *addr = lua_touserdata(L, -1);
91+
lua_pop(L, 3);
92+
return addr;
93+
}
94+
4895
void script_headers(lua_State *L, char **headers) {
4996
lua_getglobal(L, "wrk");
5097
lua_getfield(L, 1, "headers");
@@ -225,6 +272,34 @@ size_t script_verify_request(lua_State *L) {
225272
return count;
226273
}
227274

275+
static struct addrinfo *checkaddr(lua_State *L) {
276+
struct addrinfo *addr = luaL_checkudata(L, -1, "wrk.addr");
277+
luaL_argcheck(L, addr != NULL, 1, "`addr' expected");
278+
return addr;
279+
}
280+
281+
static int script_addr_tostring(lua_State *L) {
282+
struct addrinfo *addr = checkaddr(L);
283+
char host[NI_MAXHOST];
284+
char service[NI_MAXSERV];
285+
286+
int flags = NI_NUMERICHOST | NI_NUMERICSERV;
287+
int rc = getnameinfo(addr->ai_addr, addr->ai_addrlen, host, NI_MAXHOST, service, NI_MAXSERV, flags);
288+
if (rc != 0) {
289+
const char *msg = gai_strerror(rc);
290+
return luaL_error(L, "addr tostring failed %s", msg);
291+
}
292+
293+
lua_pushfstring(L, "%s:%s", host, service);
294+
return 1;
295+
}
296+
297+
static int script_addr_gc(lua_State *L) {
298+
struct addrinfo *addr = checkaddr(L);
299+
zfree(addr->ai_addr);
300+
return 0;
301+
}
302+
228303
static stats *checkstats(lua_State *L) {
229304
stats **s = luaL_checkudata(L, 1, "wrk.stats");
230305
luaL_argcheck(L, s != NULL, 1, "`stats' expected");
@@ -262,10 +337,57 @@ static int script_stats_len(lua_State *L) {
262337
return 1;
263338
}
264339

340+
static int script_wrk_lookup(lua_State *L) {
341+
struct addrinfo *addrs;
342+
struct addrinfo hints = {
343+
.ai_family = AF_UNSPEC,
344+
.ai_socktype = SOCK_STREAM
345+
};
346+
int rc, index = 1;
347+
348+
const char *host = lua_tostring(L, -2);
349+
const char *service = lua_tostring(L, -1);
350+
351+
if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) {
352+
const char *msg = gai_strerror(rc);
353+
fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg);
354+
exit(1);
355+
}
356+
357+
lua_newtable(L);
358+
for (struct addrinfo *addr = addrs; addr != NULL; addr = addr->ai_next) {
359+
struct addrinfo *udata = lua_newuserdata(L, sizeof(*udata));
360+
luaL_getmetatable(L, "wrk.addr");
361+
lua_setmetatable(L, -2);
362+
363+
*udata = *addr;
364+
udata->ai_addr = zmalloc(addr->ai_addrlen);
365+
memcpy(udata->ai_addr, addr->ai_addr, addr->ai_addrlen);
366+
lua_rawseti(L, -2, index++);
367+
}
368+
369+
freeaddrinfo(addrs);
370+
return 1;
371+
}
372+
373+
static int script_wrk_connect(lua_State *L) {
374+
struct addrinfo *addr = checkaddr(L);
375+
int fd, connected = 0;
376+
if ((fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol)) != -1) {
377+
connected = connect(fd, addr->ai_addr, addr->ai_addrlen) == 0;
378+
close(fd);
379+
}
380+
lua_pushboolean(L, connected);
381+
return 1;
382+
}
383+
265384
static void set_fields(lua_State *L, int index, const table_field *fields) {
266385
for (int i = 0; fields[i].name; i++) {
267386
table_field f = fields[i];
268387
switch (f.value == NULL ? LUA_TNIL : f.type) {
388+
case LUA_TFUNCTION:
389+
lua_pushcfunction(L, (lua_CFunction) f.value);
390+
break;
269391
case LUA_TNUMBER:
270392
lua_pushinteger(L, *((lua_Integer *) f.value));
271393
break;

src/script.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#include <lua.h>
66
#include <lualib.h>
77
#include <lauxlib.h>
8+
#include <sys/types.h>
9+
#include <netdb.h>
10+
#include <unistd.h>
811
#include "stats.h"
912

1013
typedef struct {
@@ -14,6 +17,9 @@ typedef struct {
1417
} buffer;
1518

1619
lua_State *script_create(char *, char *, char *, char *);
20+
void script_prepare_setup(lua_State *, char *);
21+
bool script_resolve(lua_State *, char *, char *);
22+
struct addrinfo *script_peek_addr(lua_State *);
1723
void script_headers(lua_State *, char **);
1824
size_t script_verify_request(lua_State *L);
1925

src/wrk.c

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "main.h"
55

66
static struct config {
7-
struct addrinfo addr;
87
uint64_t threads;
98
uint64_t connections;
109
uint64_t duration;
@@ -58,10 +57,8 @@ static void usage() {
5857
}
5958

6059
int main(int argc, char **argv) {
61-
struct addrinfo *addrs, *addr;
6260
struct http_parser_url parser_url;
6361
char *url, **headers;
64-
int rc;
6562

6663
headers = zmalloc((argc / 2) * sizeof(char *));
6764

@@ -85,26 +82,9 @@ int main(int argc, char **argv) {
8582
path = &url[parser_url.field_data[UF_PATH].off];
8683
}
8784

88-
struct addrinfo hints = {
89-
.ai_family = AF_UNSPEC,
90-
.ai_socktype = SOCK_STREAM
91-
};
92-
93-
if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) {
94-
const char *msg = gai_strerror(rc);
95-
fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg);
96-
exit(1);
97-
}
98-
99-
for (addr = addrs; addr != NULL; addr = addr->ai_next) {
100-
int fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
101-
if (fd == -1) continue;
102-
rc = connect(fd, addr->ai_addr, addr->ai_addrlen);
103-
close(fd);
104-
if (rc == 0) break;
105-
}
106-
107-
if (addr == NULL) {
85+
lua_State *L = script_create(schema, host, port, path);
86+
script_prepare_setup(L, cfg.script);
87+
if (!script_resolve(L, host, service)) {
10888
char *msg = strerror(errno);
10989
fprintf(stderr, "unable to connect to %s:%s %s\n", host, service, msg);
11090
exit(1);
@@ -125,7 +105,6 @@ int main(int argc, char **argv) {
125105

126106
signal(SIGPIPE, SIG_IGN);
127107
signal(SIGINT, SIG_IGN);
128-
cfg.addr = *addr;
129108

130109
pthread_mutex_init(&statistics.mutex, NULL);
131110
statistics.latency = stats_alloc(SAMPLES);
@@ -138,6 +117,7 @@ int main(int argc, char **argv) {
138117
for (uint64_t i = 0; i < cfg.threads; i++) {
139118
thread *t = &threads[i];
140119
t->loop = aeCreateEventLoop(10 + cfg.connections * 3);
120+
t->addr = script_peek_addr(L);
141121
t->connections = connections;
142122
t->stop_at = stop_at;
143123

@@ -217,7 +197,6 @@ int main(int argc, char **argv) {
217197
printf("Requests/sec: %9.2Lf\n", req_per_s);
218198
printf("Transfer/sec: %10sB\n", format_binary(bytes_per_s));
219199

220-
lua_State *L = threads[0].L;
221200
if (script_has_done(L)) {
222201
script_summary(L, runtime_us, complete, bytes);
223202
script_errors(L, &errors);
@@ -274,16 +253,16 @@ void *thread_main(void *arg) {
274253
}
275254

276255
static int connect_socket(thread *thread, connection *c) {
277-
struct addrinfo addr = cfg.addr;
256+
struct addrinfo *addr = thread->addr;
278257
struct aeEventLoop *loop = thread->loop;
279258
int fd, flags;
280259

281-
fd = socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
260+
fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
282261

283262
flags = fcntl(fd, F_GETFL, 0);
284263
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
285264

286-
if (connect(fd, addr.ai_addr, addr.ai_addrlen) == -1) {
265+
if (connect(fd, addr->ai_addr, addr->ai_addrlen) == -1) {
287266
if (errno != EINPROGRESS) goto error;
288267
}
289268

src/wrk.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <pthread.h>
66
#include <inttypes.h>
77
#include <sys/types.h>
8+
#include <netdb.h>
89

910
#include <openssl/ssl.h>
1011
#include <openssl/err.h>
@@ -25,6 +26,7 @@
2526
typedef struct {
2627
pthread_t thread;
2728
aeEventLoop *loop;
29+
struct addrinfo *addr;
2830
uint64_t connections;
2931
int interval;
3032
uint64_t stop_at;

src/wrk.lua

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,14 @@ local wrk = {
88
body = nil
99
}
1010

11-
function wrk.format(method, path, headers, body)
12-
local method = method or wrk.method
13-
local path = path or wrk.path
14-
local headers = headers or wrk.headers
15-
local body = body or wrk.body
16-
local s = {}
17-
18-
if not headers["Host"] then
19-
headers["Host"] = wrk.headers["Host"]
11+
function wrk.resolve(host, service)
12+
local addrs = wrk.lookup(host, service)
13+
for i = #addrs, 1, -1 do
14+
if not wrk.connect(addrs[i]) then
15+
table.remove(addrs, i)
16+
end
2017
end
21-
22-
headers["Content-Length"] = body and string.len(body)
23-
24-
s[1] = string.format("%s %s HTTP/1.1", method, path)
25-
for name, value in pairs(headers) do
26-
s[#s+1] = string.format("%s: %s", name, value)
27-
end
28-
29-
s[#s+1] = ""
30-
s[#s+1] = body or ""
31-
32-
return table.concat(s, "\r\n")
18+
wrk.addrs = addrs
3319
end
3420

3521
function wrk.init(args)
@@ -53,4 +39,28 @@ function wrk.init(args)
5339
end
5440
end
5541

42+
function wrk.format(method, path, headers, body)
43+
local method = method or wrk.method
44+
local path = path or wrk.path
45+
local headers = headers or wrk.headers
46+
local body = body or wrk.body
47+
local s = {}
48+
49+
if not headers["Host"] then
50+
headers["Host"] = wrk.headers["Host"]
51+
end
52+
53+
headers["Content-Length"] = body and string.len(body)
54+
55+
s[1] = string.format("%s %s HTTP/1.1", method, path)
56+
for name, value in pairs(headers) do
57+
s[#s+1] = string.format("%s: %s", name, value)
58+
end
59+
60+
s[#s+1] = ""
61+
s[#s+1] = body or ""
62+
63+
return table.concat(s, "\r\n")
64+
end
65+
5666
return wrk

0 commit comments

Comments
 (0)