Skip to content
This repository was archived by the owner on Jan 31, 2025. It is now read-only.

Commit d49a361

Browse files
committed
reporter: initial server-side capture
Rename wrappedConn to peekableConn and, for recorded subtests, wrap it in a CaptureConn. Move the TCP keep-alive setup into handleConnection, because a wrapped connection is no longer a *net.TCPConn. To do: save the TLS version and success state.
1 parent f3199e5 commit d49a361

File tree

6 files changed

+237
-26
lines changed

6 files changed

+237
-26
lines changed

capture_conn.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ func (c *CaptureConn) captureFrame(b []byte, isRead bool) {
4545
*c.frames = append(*c.frames, frame)
4646
}
4747

48-
func (c *CaptureConn) StopCapture() {
48+
// StopCapture stops recording more frames. Returns true if a non-empty capture
49+
// was just closed and false otherwise.
50+
func (c *CaptureConn) StopCapture() bool {
51+
if c.frames == nil || len(*c.frames) == 0 {
52+
return false
53+
}
4954
c.frames = nil
55+
return true
5056
}

reporter/capture_conn.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../capture_conn.go

reporter/conn.go

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
// A net.Conn implementation which allows for buffering the initial read such
2-
// that it can be peeked into without consuming it in the actual read buffer.
1+
// Various net.Conn implementations to support peeking or capturing data.
32
package main
43

54
import (
65
"net"
76
"sync"
7+
"time"
88
)
99

10+
// A net.Conn implementation which allows for buffering the initial read such
11+
// that it can be peeked into without consuming it in the actual read buffer.
1012
type conn struct {
1113
net.Conn
1214
readBuffer []byte
@@ -51,3 +53,18 @@ func (c *conn) Read(b []byte) (int, error) {
5153
}
5254
return c.Conn.Read(b)
5355
}
56+
57+
type serverCaptureConn struct {
58+
*CaptureConn
59+
info *ServerCapture
60+
ServerCaptureReady func(*ServerCapture)
61+
}
62+
63+
func (c *serverCaptureConn) Close() error {
64+
err := c.CaptureConn.Close()
65+
if c.CaptureConn.StopCapture() {
66+
c.info.EndTime = time.Now().UTC()
67+
c.ServerCaptureReady(c.info)
68+
}
69+
return err
70+
}

reporter/db.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,93 @@ func (model *ClientCapture) Create(tx *sql.Tx) error {
213213
)
214214
return err
215215
}
216+
217+
// Create a new ServerCapture model. Required field: SubtestID, Frames,
218+
// ClientIP, ServerIP. Fields that are updated: ID, CreatedAt.
219+
func (model *ServerCapture) Create(tx *sql.Tx) error {
220+
if model.SubtestID == 0 {
221+
return errors.New("SubtestID must be initialized!")
222+
}
223+
if model.Frames == nil {
224+
return errors.New("Frames must be initialized")
225+
}
226+
if model.ClientIP == nil {
227+
return errors.New("client IP must be initialized")
228+
}
229+
if model.ServerIP == nil {
230+
return errors.New("server IP must be initialized")
231+
}
232+
clientIP := model.ClientIP.String()
233+
serverIP := model.ServerIP.String()
234+
frames, err := json.Marshal(model.Frames)
235+
if err != nil {
236+
return err
237+
}
238+
err = tx.QueryRow(`
239+
INSERT INTO server_captures (
240+
-- id,
241+
subtest_id,
242+
created_at,
243+
begin_time,
244+
end_time,
245+
actual_tls_version,
246+
frames,
247+
key_log,
248+
has_failed,
249+
client_ip,
250+
server_ip
251+
) VALUES (
252+
-- -- id,
253+
$1, -- subtest_id,
254+
now(), -- created_at,
255+
$2, -- begin_time,
256+
$3, -- end_time,
257+
$4, -- actual_tls_version,
258+
$5, -- frames,
259+
$6, -- key_log,
260+
$7, -- has_failed,
261+
$8, -- client_ip,
262+
$9 -- server_ip
263+
) RETURNING
264+
id,
265+
created_at
266+
`,
267+
//&model.ID,
268+
&model.SubtestID,
269+
//&model.CreatedAt,
270+
&model.BeginTime,
271+
&model.EndTime,
272+
&model.ActualTLSVersion,
273+
&frames,
274+
&model.KeyLog,
275+
&model.HasFailed,
276+
&clientIP,
277+
&serverIP,
278+
).Scan(
279+
&model.ID,
280+
&model.CreatedAt,
281+
)
282+
return err
283+
}
284+
285+
// QuerySubtest returns the ID of the subtest identified by (testID, number):
// testID is matched against tests.test_id and number against subtests.number,
// with subtests joined to tests on subtests.test_id = tests.id. Only rows that
// are still pending (is_pending) and for which now() - created_at is below
// mutableTestPeriodSecs qualify.
//
// When nothing qualifies (e.g. the test has already concluded), it returns
// (0, nil); this is not an error. Any other query error is returned with a
// zero ID.
//
// NOTE(review): `now() - created_at < $3` compares an interval with the
// parameter. This presumably relies on PostgreSQL inferring $3 as an interval
// and reading the bare number as seconds; confirm with the driver in use.
// created_at and is_pending are also unqualified. If both tables define such a
// column, the query is ambiguous, so verify this against the schema.
func QuerySubtest(db *sql.DB, testID string, number int, mutableTestPeriodSecs int) (int, error) {
	const selectPendingSubtest = `
SELECT
subtests.id
FROM subtests
JOIN tests
ON subtests.test_id = tests.id
WHERE
tests.test_id = $1 AND
subtests.number = $2 AND
is_pending AND
now() - created_at < $3
`
	var id int
	switch err := db.QueryRow(selectPendingSubtest, testID, number, mutableTestPeriodSecs).Scan(&id); {
	case err == sql.ErrNoRows:
		return 0, nil
	case err != nil:
		return 0, err
	}
	return id, nil
}

reporter/listener.go

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ import (
1414
const maxHttpsQueueSize = 1024
1515

1616
// RequestClaimer is given a hostname and reports whether the listener should
// claim the request (queue it for local serving rather than proxying it to the
// origin). It also returns the ID of the subtest the connection should be
// recorded for as a server capture. A subtestID of zero means the connection
// is not recorded.
type RequestClaimer func(host string) (claimed bool, subtestID int)
1919

2020
type listener struct {
2121
net.Listener
@@ -25,19 +25,23 @@ type listener struct {
2525

2626
ClaimRequest RequestClaimer
2727

28+
// invoked when a server capture is ready.
29+
ServerCaptureReady func(*ServerCapture)
30+
2831
// queue for new connections, intended for reporter or test services.
2932
newc chan net.Conn
3033
// used to synchronize closing newc.
3134
connectionsWg sync.WaitGroup
3235
}
3336

34-
func newListener(ln net.Listener, initialReadTimeout time.Duration, originAddress string, claimer RequestClaimer) *listener {
37+
func newListener(ln net.Listener, initialReadTimeout time.Duration, originAddress string, claimer RequestClaimer, serverCaptureReady func(*ServerCapture)) *listener {
3538
newc := make(chan net.Conn, maxHttpsQueueSize)
3639
return &listener{
3740
Listener: ln,
3841
initialReadTimeout: initialReadTimeout,
3942
originAddress: originAddress,
4043
ClaimRequest: claimer,
44+
ServerCaptureReady: serverCaptureReady,
4145
newc: newc,
4246
}
4347
}
@@ -63,9 +67,16 @@ func (ln *listener) handleConnection(c net.Conn) {
6367
startTime := time.Now()
6468
c.SetReadDeadline(startTime.Add(ln.initialReadTimeout))
6569

70+
// let dead connections eventually go away (mimic
71+
// http.ListenAndServe behavior).
72+
if tc, ok := c.(*net.TCPConn); ok {
73+
tc.SetKeepAlive(true)
74+
tc.SetKeepAlivePeriod(3 * time.Minute)
75+
}
76+
6677
remoteAddr := c.RemoteAddr().String()
67-
wrappedConn := wrapConn(c)
68-
buffer, err := wrappedConn.peek(4096)
78+
peekableConn := NewPeekableConn(c)
79+
buffer, err := peekableConn.peek(4096)
6980
if len(buffer) == 0 {
7081
log.Printf("%s - failed to read a record: %v\n", remoteAddr, err)
7182
return
@@ -77,17 +88,37 @@ func (ln *listener) handleConnection(c net.Conn) {
7788
// server configuration.
7889
c.SetReadDeadline(time.Time{})
7990

80-
servedByUs, _ = ln.ClaimRequest(sni)
81-
// TODO handle TCP logging
91+
servedByUs, subtestID := ln.ClaimRequest(sni)
8292
switch {
8393
case servedByUs:
84-
ln.newc <- wrappedConn
85-
ln.connectionsWg.Done()
94+
defer ln.connectionsWg.Done()
95+
if subtestID != 0 {
96+
// TODO refactor this to have the logic in one place,
97+
// instead of scattered through conn.go and server.go
98+
serverCapture := &ServerCapture{
99+
Capture: Capture{
100+
SubtestID: subtestID,
101+
BeginTime: startTime.UTC(),
102+
Frames: []Frame{},
103+
HasFailed: true,
104+
},
105+
ClientIP: net.ParseIP(parseHost(remoteAddr)),
106+
ServerIP: net.ParseIP(parseHost(c.LocalAddr().String())),
107+
}
108+
capturedConn := &serverCaptureConn{
109+
CaptureConn: NewCaptureConn(peekableConn, &serverCapture.Frames),
110+
info: serverCapture,
111+
ServerCaptureReady: ln.ServerCaptureReady,
112+
}
113+
ln.newc <- capturedConn
114+
} else {
115+
ln.newc <- peekableConn
116+
}
86117
case ln.originAddress == "":
87118
log.Printf("%s - no upstream configured", remoteAddr)
88119
c.Write(tlsRecordUnrecognizedName)
89120
default:
90-
if err := proxyConnection(wrappedConn, ln.originAddress); err != nil {
121+
if err := proxyConnection(peekableConn, ln.originAddress); err != nil {
91122
log.Printf("%s - error proxying connection: %v\n", remoteAddr, err)
92123
}
93124
}
@@ -130,12 +161,5 @@ func (ln *listener) Accept() (net.Conn, error) {
130161
return nil, http.ErrServerClosed
131162
}
132163

133-
// let dead connections eventually go away (mimic
134-
// http.ListenAndServe behavior).
135-
if tc, ok := c.(*net.TCPConn); ok {
136-
tc.SetKeepAlive(true)
137-
tc.SetKeepAlivePeriod(3 * time.Minute)
138-
}
139-
140164
return c, nil
141165
}

reporter/server.go

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net"
1111
"net/http"
1212
"os"
13+
"strconv"
1314
"strings"
1415
"time"
1516

@@ -45,19 +46,91 @@ func isTestHost(host string, config *Config) bool {
4546
return strings.HasSuffix(host, config.HostSuffixIPv4) || strings.HasSuffix(host, config.HostSuffixIPv6)
4647
}
4748

48-
func makeIsOurHost(config *Config) RequestClaimer {
49-
return func(host string) (bool, bool) {
49+
func makeIsOurHost(db *sql.DB, config *Config) RequestClaimer {
50+
return func(host string) (bool, int) {
5051
host = strings.ToLower(host)
5152
if host == config.HostReporter {
5253
// pass to HTTP handler, handle API requests.
53-
return true, false
54+
return true, 0
5455
}
5556
if isTestHost(host, config) {
5657
// pass to HTTP handler, handling a basic response.
5758
// Logging is tentatively enabled.
58-
return true, true
59+
return true, prepareServerCapture(db, config, host)
5960
}
60-
return false, false
61+
return false, 0
62+
}
63+
}
64+
65+
func prepareServerCapture(db *sql.DB, config *Config, host string) int {
66+
testID, number := parseTestHost(config, host)
67+
if testID == "" {
68+
log.Printf("Host \"%s\" is not a valid test domain, ignoring", host)
69+
return 0
70+
}
71+
subtestID, err := QuerySubtest(db, testID, number, config.MutableTestPeriodSecs)
72+
if err != nil {
73+
log.Printf("Failed to query subtest for \"%s\": %s", host, err)
74+
return 0
75+
}
76+
if subtestID == 0 {
77+
log.Printf("Not accepting server capture for \"%s\"", host)
78+
return 0
79+
}
80+
return subtestID
81+
}
82+
83+
// parses a host name of the form "<testID>-<number><suffix>", returning the
84+
// TestID and subtest number. On error, the testID is empty.
85+
func parseTestHost(config *Config, host string) (string, int) {
86+
var prefix string
87+
switch {
88+
case strings.HasSuffix(host, config.HostSuffixIPv4):
89+
prefix = host[:len(host)-len(config.HostSuffixIPv4)]
90+
case strings.HasSuffix(host, config.HostSuffixIPv6):
91+
prefix = host[:len(host)-len(config.HostSuffixIPv6)]
92+
default:
93+
return "", 0
94+
}
95+
96+
// testID UUID is always 36 chars followed by "-" and number.
97+
if len(prefix) < 36+2 || prefix[36] != '-' {
98+
return "", 0
99+
}
100+
testID, numberStr := prefix[0:36], prefix[37:]
101+
if !ValidateUUID(testID) {
102+
return "", 0
103+
}
104+
number, err := strconv.Atoi(numberStr)
105+
if err != nil || number <= 0 {
106+
return "", 0
107+
}
108+
109+
return testID, number
110+
}
111+
112+
func newServerCaptureReady(db *sql.DB) func(*ServerCapture) {
113+
return func(serverCapture *ServerCapture) {
114+
tx, err := db.Begin()
115+
if err != nil {
116+
log.Printf("Failed to begin transaction: %s", err)
117+
return
118+
}
119+
defer func() {
120+
if tx != nil {
121+
tx.Rollback()
122+
}
123+
}()
124+
125+
err = serverCapture.Create(tx)
126+
if err != nil {
127+
log.Printf("Failed to create server capture: %s", err)
128+
return
129+
}
130+
131+
log.Printf("Stored server capture: %d", serverCapture.ID)
132+
tx.Commit()
133+
tx = nil
61134
}
62135
}
63136

@@ -142,7 +215,7 @@ func main() {
142215
panic(err)
143216
}
144217
initialReadTimeout := time.Duration(config.InitialReadTimeoutSecs) * time.Second
145-
wl := newListener(l, initialReadTimeout, config.OriginAddress, makeIsOurHost(config))
218+
wl := newListener(l, initialReadTimeout, config.OriginAddress, makeIsOurHost(db, config), newServerCaptureReady(db))
146219
go wl.Serve()
147220

148221
hostRouter := &hostHandler{

0 commit comments

Comments
 (0)