diff --git a/.gitignore b/.gitignore index 30f3f28..1597450 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ # vendor/ **/tls-key-log.txt +/results diff --git a/client/client.go b/client/client.go index 7a8d533..163b83d 100644 --- a/client/client.go +++ b/client/client.go @@ -30,6 +30,7 @@ type Options struct { DNSSEC bool ValidateOnly bool StrictValidation bool + KeepAlive bool // New flag for long-lived connections } // New creates a DNS client based on the upstream string @@ -214,8 +215,8 @@ func getDefaultPath(scheme string) string { } func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) { - logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v", - scheme, host, port, path, opts.DNSSEC) + logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v, KeepAlive=%v", + scheme, host, port, path, opts.DNSSEC, opts.KeepAlive) switch scheme { case "udp", "tcp", "do53", "": @@ -228,40 +229,44 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err case "http", "doh": config := doh.Config{ - Host: host, - Port: port, - Path: path, - DNSSEC: opts.DNSSEC, - HTTP3: false, + Host: host, + Port: port, + Path: path, + DNSSEC: opts.DNSSEC, + HTTP3: false, + KeepAlive: opts.KeepAlive, } logger.Debug("Creating DoH client with config: %+v", config) return doh.New(config) case "https", "doh3": config := doh.Config{ - Host: host, - Port: port, - Path: path, - DNSSEC: opts.DNSSEC, - HTTP3: true, + Host: host, + Port: port, + Path: path, + DNSSEC: opts.DNSSEC, + HTTP3: true, + KeepAlive: opts.KeepAlive, } logger.Debug("Creating DoH3 client with config: %+v", config) return doh.New(config) case "tls", "dot": config := dot.Config{ - Host: host, - Port: port, - DNSSEC: opts.DNSSEC, + Host: host, + Port: port, + DNSSEC: opts.DNSSEC, + KeepAlive: opts.KeepAlive, } logger.Debug("Creating DoT client with config: %+v", config) return dot.New(config) case "doq": // DNS over QUIC config := doq.Config{ - Host: host, - Port: port, - DNSSEC: opts.DNSSEC, + Host: host, + Port: port, + DNSSEC: opts.DNSSEC, + KeepAlive: opts.KeepAlive, } logger.Debug("Creating DoQ client with config: %+v", config) return doq.New(config) diff --git a/cmd/qol/qol.go b/cmd/qol/qol.go index b424781..9ff3779 100644 --- a/cmd/qol/qol.go +++ b/cmd/qol/qol.go @@ -17,6 +17,7 @@ type RunCmd struct { QueryType string `short:"t" long:"type" default:"A" help:"DNS query type"` Timeout time.Duration `long:"timeout" default:"5s" help:"Query timeout (informational)"` DNSSEC bool `long:"dnssec" help:"Enable DNSSEC"` + KeepAlive bool `short:"k" long:"keep-alive" help:"Use persistent connections"` Interface string `long:"iface" default:"any" help:"Capture interface (e.g., eth0, any)"` Servers []string `short:"s" long:"server" help:"Upstream servers (udp://..., tls://..., https://..., doq://...)"` } @@ -27,6 +28,7 @@ func (r *RunCmd) Run() error { OutputDir: r.OutputDir, QueryType: r.QueryType, DNSSEC: r.DNSSEC, + KeepAlive: r.KeepAlive, Interface: r.Interface, Servers: r.Servers, } diff --git a/cmd/sdns-proxy/sdns-proxy.go b/cmd/sdns-proxy/sdns-proxy.go index 254004b..3f94cc5 100644 --- a/cmd/sdns-proxy/sdns-proxy.go +++ b/cmd/sdns-proxy/sdns-proxy.go @@ -26,6 +26,7 @@ type QueryCmd struct { DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"` ValidateOnly bool `help:"Only return DNSSEC validated responses." short:"V"` StrictValidation bool `help:"Fail on any DNSSEC validation error." short:"S"` + KeepAlive bool `help:"Use persistent connections." short:"k"` Timeout time.Duration `help:"Timeout for the query operation." default:"10s"` KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"` } @@ -36,18 +37,20 @@ type ListenCmd struct { Fallback string `help:"Fallback DNS server (e.g., https://1.1.1.1/dns-query, tls://8.8.8.8)." short:"f"` Bootstrap string `help:"Bootstrap DNS server (must be an IP address, e.g., 8.8.8.8, 1.1.1.1)." short:"b"` DNSSEC bool `help:"Enable DNSSEC for upstream queries." short:"d"` + KeepAlive bool `help:"Use persistent connections to upstream servers." short:"k"` Timeout time.Duration `help:"Timeout for upstream queries." default:"5s"` Verbose bool `help:"Enable verbose logging." short:"v"` } func (q *QueryCmd) Run() error { - logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, Timeout: %v)", - q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.Timeout) + logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, KeepAlive: %v, Timeout: %v)", + q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.KeepAlive, q.Timeout) opts := client.Options{ DNSSEC: q.DNSSEC, ValidateOnly: q.ValidateOnly, StrictValidation: q.StrictValidation, + KeepAlive: q.KeepAlive, } logger.Debug("Creating DNS client with options: %+v", opts) @@ -90,6 +93,7 @@ func (l *ListenCmd) Run() error { Fallback: l.Fallback, Bootstrap: l.Bootstrap, DNSSEC: l.DNSSEC, + KeepAlive: l.KeepAlive, Timeout: l.Timeout, Verbose: l.Verbose, } @@ -105,10 +109,12 @@ func (l *ListenCmd) Run() error { logger.Info("Upstream server: %v", l.Upstream) logger.Info("Fallback server: %v", l.Fallback) logger.Info("Bootstrap server: %v", l.Bootstrap) + logger.Info("KeepAlive: %v", l.KeepAlive) return srv.Start() } + func printResponse(domain, qtype string, msg *dns.Msg) { fmt.Println(";; QUESTION SECTION:") diff --git a/common/protocols/doh/doh.go b/common/protocols/doh/doh.go index b42e873..c1d7a57 100644 --- a/common/protocols/doh/doh.go +++ b/common/protocols/doh/doh.go @@ -22,12 +22,13 @@ import ( const dnsMessageContentType = "application/dns-message" type Config struct { - Host string - Port string - Path string - DNSSEC bool - HTTP3 bool - HTTP2 bool + Host string + Port string + Path string + DNSSEC bool + HTTP3 bool + HTTP2 bool + KeepAlive bool } type Client struct { @@ -37,7 +38,7 @@ type Client struct { } func New(config Config) (*Client, error) { - logger.Debug("Creating DoH client: %s:%s%s", config.Host, config.Port, config.Path) + logger.Debug("Creating DoH client: %s:%s%s (KeepAlive: %v)", config.Host, config.Port, config.Path, config.KeepAlive) if config.Host == "" || config.Port == "" || config.Path == "" { logger.Error("DoH client creation failed: missing required fields") @@ -65,31 +66,58 @@ func New(config Config) (*Client, error) { DisablePathMTUDiscovery: true, } - transport := http.DefaultTransport.(*http.Transport) - transport.TLSClientConfig = tlsConfig + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + } + + // Configure connection pooling based on KeepAlive setting + if config.KeepAlive { + transport.MaxIdleConnsPerHost = 10 + transport.MaxConnsPerHost = 10 + } else { + transport.MaxIdleConns = 0 + transport.MaxIdleConnsPerHost = 0 + transport.DisableKeepAlives = true + } + httpClient := &http.Client{ Transport: transport, + Timeout: 30 * time.Second, } var transportType string if config.HTTP2 { - httpClient.Transport = &http2.Transport{ + http2Transport := &http2.Transport{ TLSClientConfig: tlsConfig, AllowHTTP: true, } + + if !config.KeepAlive { + http2Transport.DisableCompression = true + } + + httpClient.Transport = http2Transport transportType = "HTTP/2" } else if config.HTTP3 { quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig) - httpClient.Transport = &http3.Transport{ + http3Transport := &http3.Transport{ TLSClientConfig: quicTlsConfig, QUICConfig: quicConfig, } + + if !config.KeepAlive { + http3Transport.DisableCompression = true + } + + httpClient.Transport = http3Transport transportType = "HTTP/3" } else { transportType = "HTTP/1.1" } - logger.Debug("DoH client created: %s (%s, DNSSEC: %v)", rawURL, transportType, config.DNSSEC) + logger.Debug("DoH client created: %s (%s, DNSSEC: %v, KeepAlive: %v)", rawURL, transportType, config.DNSSEC, config.KeepAlive) return &Client{ httpClient: httpClient, @@ -102,6 +130,8 @@ func (c *Client) Close() { logger.Debug("Closing DoH client") if t, ok := c.httpClient.Transport.(*http.Transport); ok { t.CloseIdleConnections() + } else if t2, ok := c.httpClient.Transport.(*http2.Transport); ok { + t2.CloseIdleConnections() } else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok { t3.CloseIdleConnections() } @@ -132,6 +162,13 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { httpReq.Header.Set("User-Agent", "sdns-proxy") httpReq.Header.Set("Content-Type", dnsMessageContentType) httpReq.Header.Set("Accept", dnsMessageContentType) + + // Set Connection header based on KeepAlive setting + if c.config.KeepAlive { + httpReq.Header.Set("Connection", "keep-alive") + } else { + httpReq.Header.Set("Connection", "close") + } httpResp, err := c.httpClient.Do(httpReq) if err != nil { diff --git a/common/protocols/doq/doq.go b/common/protocols/doq/doq.go index 4f272ca..5ce2efb 100644 --- a/common/protocols/doq/doq.go +++ b/common/protocols/doq/doq.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net" + "sync" "time" "github.com/afonsofrancof/sdns-proxy/common/logger" @@ -16,10 +17,11 @@ import ( ) type Config struct { - Host string - Port string - Debug bool - DNSSEC bool + Host string + Port string + Debug bool + DNSSEC bool + KeepAlive bool } type Client struct { @@ -30,10 +32,11 @@ type Client struct { quicTransport *quic.Transport quicConfig *quic.Config config Config + connMutex sync.Mutex } func New(config Config) (*Client, error) { - logger.Debug("Creating DoQ client: %s:%s", config.Host, config.Port) + logger.Debug("Creating DoQ client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive) tlsConfig := &tls.Config{ ServerName: config.Host, @@ -62,9 +65,7 @@ func New(config Config) (*Client, error) { MaxIdleTimeout: 30 * time.Second, } - logger.Debug("DoQ client created: %s:%s (DNSSEC: %v)", config.Host, config.Port, config.DNSSEC) - - return &Client{ + client := &Client{ targetAddr: targetAddr, tlsConfig: tlsConfig, udpConn: udpConn, @@ -72,18 +73,52 @@ func New(config Config) (*Client, error) { quicTransport: &quicTransport, quicConfig: &quicConfig, config: config, - }, nil + } + + // If keep-alive is enabled, establish connection now + if config.KeepAlive { + if err := client.ensureConnection(); err != nil { + logger.Error("DoQ failed to establish initial connection: %v", err) + client.Close() + return nil, fmt.Errorf("failed to establish initial connection: %w", err) + } + } + + logger.Debug("DoQ client created: %s:%s (DNSSEC: %v, KeepAlive: %v)", config.Host, config.Port, config.DNSSEC, config.KeepAlive) + return client, nil } func (c *Client) Close() { logger.Debug("Closing DoQ client") + c.connMutex.Lock() + defer c.connMutex.Unlock() + + if c.quicConn != nil { + c.quicConn.CloseWithError(0, "client shutdown") + c.quicConn = nil + } if c.udpConn != nil { c.udpConn.Close() } } -func (c *Client) OpenConnection() error { - logger.Debug("Opening DoQ connection to %s", c.targetAddr) +func (c *Client) ensureConnection() error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + + // Check if existing connection is still valid + if c.quicConn != nil { + select { + case <-c.quicConn.Context().Done(): + logger.Debug("DoQ connection closed, reconnecting") + c.quicConn = nil + default: + // Connection is still valid + return nil + } + } + + logger.Debug("Establishing DoQ connection to %s", c.targetAddr) quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig) if err != nil { logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err) @@ -101,10 +136,19 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr) } - if c.quicConn == nil { - err := c.OpenConnection() - if err != nil { - return nil, err + // Ensure we have a connection (either persistent or new) + if c.config.KeepAlive { + if err := c.ensureConnection(); err != nil { + return nil, fmt.Errorf("doq: failed to ensure connection: %w", err) + } + } else { + // For non-keepalive mode, create a fresh connection for each query + c.connMutex.Lock() + c.quicConn = nil // Force new connection + c.connMutex.Unlock() + + if err := c.ensureConnection(); err != nil { + return nil, fmt.Errorf("doq: failed to create connection: %w", err) } } @@ -119,18 +163,33 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { return nil, fmt.Errorf("doq: failed to pack message: %w", err) } - var quicStream quic.Stream - quicStream, err = c.quicConn.OpenStream() + // Open a stream for this query + c.connMutex.Lock() + quicConn := c.quicConn + c.connMutex.Unlock() + + quicStream, err := quicConn.OpenStream() if err != nil { - logger.Debug("DoQ stream failed, reconnecting: %v", err) - err = c.OpenConnection() - if err != nil { - return nil, err - } - quicStream, err = c.quicConn.OpenStream() - if err != nil { - logger.Error("DoQ failed to open stream after reconnect: %v", err) - return nil, err + logger.Error("DoQ failed to open stream: %v", err) + + // If keep-alive is enabled, try to reconnect once + if c.config.KeepAlive { + logger.Debug("DoQ stream failed with keep-alive, attempting reconnect") + if reconnectErr := c.ensureConnection(); reconnectErr != nil { + return nil, fmt.Errorf("doq: failed to reconnect: %w", reconnectErr) + } + + c.connMutex.Lock() + quicConn = c.quicConn + c.connMutex.Unlock() + + quicStream, err = quicConn.OpenStream() + if err != nil { + logger.Error("DoQ failed to open stream after reconnect: %v", err) + return nil, fmt.Errorf("doq: failed to open stream after reconnect: %w", err) + } + } else { + return nil, fmt.Errorf("doq: failed to open stream: %w", err) } } @@ -184,5 +243,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer)) } + // Close the connection if not using keep-alive + if !c.config.KeepAlive { + c.connMutex.Lock() + if c.quicConn != nil { + c.quicConn.CloseWithError(0, "query complete") + c.quicConn = nil + } + c.connMutex.Unlock() + } + return recvMsg, nil } diff --git a/common/protocols/dot/dot.go b/common/protocols/dot/dot.go index b236ab2..6b50fff 100644 --- a/common/protocols/dot/dot.go +++ b/common/protocols/dot/dot.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "sync" "time" "github.com/afonsofrancof/sdns-proxy/common/logger" @@ -16,6 +17,7 @@ type Config struct { Host string Port string DNSSEC bool + KeepAlive bool WriteTimeout time.Duration ReadTimeout time.Duration Debug bool @@ -25,10 +27,12 @@ type Client struct { hostAndPort string tlsConfig *tls.Config config Config + conn *tls.Conn + connMutex sync.Mutex } func New(config Config) (*Client, error) { - logger.Debug("Creating DoT client: %s:%s", config.Host, config.Port) + logger.Debug("Creating DoT client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive) if config.Host == "" { logger.Error("DoT client creation failed: empty host") @@ -47,33 +51,76 @@ func New(config Config) (*Client, error) { ServerName: config.Host, } - logger.Debug("DoT client created: %s (DNSSEC: %v)", hostAndPort, config.DNSSEC) - - return &Client{ + client := &Client{ hostAndPort: hostAndPort, tlsConfig: tlsConfig, config: config, - }, nil + } + + // If keep-alive is enabled, establish connection now + if config.KeepAlive { + if err := client.ensureConnection(); err != nil { + logger.Error("DoT failed to establish initial connection: %v", err) + return nil, fmt.Errorf("failed to establish initial connection: %w", err) + } + } + + logger.Debug("DoT client created: %s (DNSSEC: %v, KeepAlive: %v)", hostAndPort, config.DNSSEC, config.KeepAlive) + return client, nil } func (c *Client) Close() { logger.Debug("Closing DoT client") + c.connMutex.Lock() + defer c.connMutex.Unlock() + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } } -func (c *Client) createConnection() (*tls.Conn, error) { +func (c *Client) ensureConnection() error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + + // Check if existing connection is still valid + if c.conn != nil { + // Test the connection with a very short deadline + if err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err == nil { + // Try to read one byte to test connection + var testBuf [1]byte + _, err := c.conn.Read(testBuf[:]) + + // Reset deadline + c.conn.SetReadDeadline(time.Time{}) + + // If we get a timeout error, connection is still good + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil + } + + // Any other error means connection is dead + logger.Debug("DoT connection test failed, reconnecting: %v", err) + c.conn.Close() + c.conn = nil + } + } + + logger.Debug("Establishing DoT connection to %s", c.hostAndPort) dialer := &net.Dialer{ Timeout: c.config.WriteTimeout, } - logger.Debug("Establishing DoT connection to %s", c.hostAndPort) conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig) if err != nil { logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err) - return nil, err + return err } + c.conn = conn logger.Debug("DoT connection established to %s", c.hostAndPort) - return conn, nil + return nil } func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { @@ -82,12 +129,24 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort) } - // Create connection for this query - conn, err := c.createConnection() - if err != nil { - return nil, fmt.Errorf("dot: failed to create connection: %w", err) + // Ensure we have a connection (either persistent or new) + if c.config.KeepAlive { + if err := c.ensureConnection(); err != nil { + return nil, fmt.Errorf("dot: failed to ensure connection: %w", err) + } + } else { + // For non-keepalive mode, create a fresh connection for each query + c.connMutex.Lock() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + c.connMutex.Unlock() + + if err := c.ensureConnection(); err != nil { + return nil, fmt.Errorf("dot: failed to create connection: %w", err) + } } - defer conn.Close() // Prepare DNS message if c.config.DNSSEC { @@ -104,6 +163,10 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { binary.BigEndian.PutUint16(length, uint16(len(packed))) data := append(length, packed...) + c.connMutex.Lock() + conn := c.conn + c.connMutex.Unlock() + // Write query if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { logger.Error("DoT failed to set write deadline: %v", err) @@ -112,7 +175,28 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { if _, err := conn.Write(data); err != nil { logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err) - return nil, fmt.Errorf("dot: failed to write message: %w", err) + + // If keep-alive is enabled and write failed, try to reconnect once + if c.config.KeepAlive { + logger.Debug("DoT write failed with keep-alive, attempting reconnect") + if reconnectErr := c.ensureConnection(); reconnectErr != nil { + return nil, fmt.Errorf("dot: failed to reconnect: %w", reconnectErr) + } + + c.connMutex.Lock() + conn = c.conn + c.connMutex.Unlock() + + if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { + return nil, fmt.Errorf("dot: failed to set write deadline after reconnect: %w", err) + } + + if _, err := conn.Write(data); err != nil { + return nil, fmt.Errorf("dot: failed to write message after reconnect: %w", err) + } + } else { + return nil, fmt.Errorf("dot: failed to write message: %w", err) + } } // Read response @@ -152,5 +236,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer)) } + // Close the connection if not using keep-alive + if !c.config.KeepAlive { + c.connMutex.Lock() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + c.connMutex.Unlock() + } + return response, nil } diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..429743f --- /dev/null +++ b/flake.lock @@ -0,0 +1,27 @@ +{ + "nodes": { + "nixpkgs": { + "locked": { + "lastModified": 1758916627, + "narHash": "sha256-fB2ISCc+xn+9hZ6gOsABxSBcsCgLCjbJ5bC6U9bPzQ4=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "53614373268559d054c080d070cfc732dbe68ac4", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix index a0740e5..f6677be 100644 --- a/flake.nix +++ b/flake.nix @@ -26,6 +26,11 @@ go gotools golangci-lint + (pkgs.python3.withPackages (python-pkgs: with python-pkgs; [ + pandas + matplotlib + seaborn + ])) ]; }; }); diff --git a/internal/qol/measurement.go b/internal/qol/measurement.go index 505c810..bf16a01 100644 --- a/internal/qol/measurement.go +++ b/internal/qol/measurement.go @@ -20,6 +20,7 @@ type MeasurementConfig struct { OutputDir string QueryType string DNSSEC bool + KeepAlive bool Interface string Servers []string } @@ -74,9 +75,14 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT defer dnsClient.Close() // Setup output files - jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC) + jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC, r.config.KeepAlive) - fmt.Printf(">>> Measuring %s (dnssec=%v) → %s\n", upstream, r.config.DNSSEC, + keepAliveStr := "" + if r.config.KeepAlive { + keepAliveStr = " (keep-alive)" + } + + fmt.Printf(">>> Measuring %s (dnssec=%v%s) → %s\n", upstream, r.config.DNSSEC, keepAliveStr, strings.TrimSuffix(strings.TrimSuffix(jsonPath, ".jsonl"), r.config.OutputDir+"/")) // Setup packet capture @@ -98,7 +104,10 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT } func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) { - opts := client.Options{DNSSEC: r.config.DNSSEC} + opts := client.Options{ + DNSSEC: r.config.DNSSEC, + KeepAlive: r.config.KeepAlive, + } return client.New(upstream, opts) } @@ -136,7 +145,14 @@ func (r *MeasurementRunner) runQueries(dnsClient client.DNSClient, upstream stri } r.printQueryResult(metric) - time.Sleep(10 * time.Millisecond) + + // For keep-alive connections, add smaller delays between queries + // to better utilize the persistent connection + if r.config.KeepAlive { + time.Sleep(5 * time.Millisecond) + } else { + time.Sleep(10 * time.Millisecond) + } } time.Sleep(100 * time.Millisecond) @@ -154,6 +170,7 @@ func (r *MeasurementRunner) performQuery(dnsClient client.DNSClient, domain, ups QueryType: r.config.QueryType, Protocol: proto, DNSSEC: r.config.DNSSEC, + KeepAlive: r.config.KeepAlive, DNSServer: upstream, Timestamp: time.Now(), } @@ -199,8 +216,14 @@ func (r *MeasurementRunner) printQueryResult(metric results.DNSMetric) { if metric.ResponseCode == "ERROR" { statusIcon = "✗" } - fmt.Printf("%s %s [%s] %s %.2fms\n", - statusIcon, metric.Domain, metric.Protocol, metric.ResponseCode, metric.DurationMs) + + keepAliveIndicator := "" + if metric.KeepAlive { + keepAliveIndicator = "⟷" + } + + fmt.Printf("%s %s%s [%s] %s %.2fms\n", + statusIcon, metric.Domain, keepAliveIndicator, metric.Protocol, metric.ResponseCode, metric.DurationMs) } func (r *MeasurementRunner) readDomainsFile() ([]string, error) { diff --git a/internal/qol/results/writer.go b/internal/qol/results/writer.go index 29e50b9..8733dd1 100644 --- a/internal/qol/results/writer.go +++ b/internal/qol/results/writer.go @@ -12,6 +12,7 @@ type DNSMetric struct { QueryType string `json:"query_type"` Protocol string `json:"protocol"` DNSSEC bool `json:"dnssec"` + KeepAlive bool `json:"keep_alive"` DNSServer string `json:"dns_server"` Timestamp time.Time `json:"timestamp"` Duration int64 `json:"duration_ns"` @@ -22,6 +23,7 @@ type DNSMetric struct { Error string `json:"error,omitempty"` } +// Rest stays exactly the same type MetricsWriter struct { encoder *json.Encoder file *os.File diff --git a/internal/qol/utils.go b/internal/qol/utils.go index 6f76b98..0d4fee7 100644 --- a/internal/qol/utils.go +++ b/internal/qol/utils.go @@ -8,17 +8,18 @@ import ( "time" ) -func GenerateOutputPaths(outputDir, upstream string, dnssec bool) (jsonPath, pcapPath string) { +func GenerateOutputPaths(outputDir, upstream string, dnssec, keepAlive bool) (jsonPath, pcapPath string) { proto := DetectProtocol(upstream) serverName := ExtractServerName(upstream) ts := time.Now().Format("20060102_1504") dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec] + keepAliveStr := map[bool]string{true: "on", false: "off"}[keepAlive] - base := fmt.Sprintf("%s_%s_dnssec_%s_%s", - proto, sanitize(serverName), dnssecStr, ts) + base := fmt.Sprintf("%s_%s_dnssec_%s_keepalive_%s_%s", + proto, sanitize(serverName), dnssecStr, keepAliveStr, ts) return filepath.Join(outputDir, base+".jsonl"), - filepath.Join(outputDir, base+".pcap") + filepath.Join(outputDir, base+".pcap") } func sanitize(s string) string { diff --git a/server/server.go b/server/server.go index 5bb9581..e80e1f4 100644 --- a/server/server.go +++ b/server/server.go @@ -1,153 +1,15 @@ -package server - -import ( - "context" - "fmt" - "net" - "net/url" - "os" - "os/signal" - "strings" - "sync" - "syscall" - "time" - - "github.com/afonsofrancof/sdns-proxy/client" - "github.com/afonsofrancof/sdns-proxy/common/logger" - "github.com/miekg/dns" -) - type Config struct { Address string Upstream string Fallback string Bootstrap string DNSSEC bool + KeepAlive bool Timeout time.Duration Verbose bool } -type cacheKey struct { - domain string - qtype uint16 -} - -type cacheEntry struct { - records []dns.RR - expiresAt time.Time -} - -type Server struct { - config Config - upstreamClient client.DNSClient - fallbackClient client.DNSClient - bootstrapClient client.DNSClient - resolvedHosts map[string]string - queryCache map[cacheKey]*cacheEntry - hostsMutex sync.RWMutex - cacheMutex sync.RWMutex - dnsServer *dns.Server -} - -func New(config Config) (*Server, error) { - logger.Debug("Creating new server with config: %+v", config) - - if config.Upstream == "" { - logger.Error("Upstream server is required") - return nil, fmt.Errorf("upstream server is required") - } - - // Check if we need bootstrap server - needsBootstrap := containsHostname(config.Upstream) - if config.Fallback != "" { - needsBootstrap = needsBootstrap || containsHostname(config.Fallback) - } - - logger.Debug("Bootstrap needed: %v (upstream has hostname: %v, fallback has hostname: %v)", - needsBootstrap, containsHostname(config.Upstream), - config.Fallback != "" && containsHostname(config.Fallback)) - - if needsBootstrap && config.Bootstrap == "" { - logger.Error("Bootstrap server is required when upstream or fallback contains hostnames") - return nil, fmt.Errorf("bootstrap server is required when upstream or fallback contains hostnames") - } - - if config.Bootstrap != "" && containsHostname(config.Bootstrap) { - logger.Error("Bootstrap server cannot contain hostnames: %s", config.Bootstrap) - return nil, fmt.Errorf("bootstrap server cannot contain hostnames: %s", config.Bootstrap) - } - - s := &Server{ - config: config, - resolvedHosts: make(map[string]string), - queryCache: make(map[cacheKey]*cacheEntry), - } - - // Create bootstrap client if needed - if config.Bootstrap != "" { - logger.Debug("Creating bootstrap client for %s", config.Bootstrap) - bootstrapClient, err := client.New(config.Bootstrap, client.Options{ - DNSSEC: false, - }) - if err != nil { - logger.Error("Failed to create bootstrap client: %v", err) - return nil, fmt.Errorf("failed to create bootstrap client: %w", err) - } - s.bootstrapClient = bootstrapClient - logger.Debug("Bootstrap client created successfully") - } - - // Initialize upstream and fallback clients - if err := s.initClients(); err != nil { - logger.Error("Failed to initialize clients: %v", err) - return nil, fmt.Errorf("failed to initialize clients: %w", err) - } - - // Setup DNS server - mux := dns.NewServeMux() - mux.HandleFunc(".", s.handleDNSRequest) - - s.dnsServer = &dns.Server{ - Addr: config.Address, - Net: "udp", - Handler: mux, - } - - logger.Debug("Server created successfully, listening on %s", config.Address) - return s, nil -} - -func containsHostname(serverAddr string) bool { - logger.Debug("Checking if %s contains hostname", serverAddr) - - // Use the same parsing logic as the client package - parsedURL, err := url.Parse(serverAddr) - if err != nil { - logger.Debug("URL parsing failed for %s, treating as plain address", serverAddr) - // If URL parsing fails, assume it's a plain address - host, _, err := net.SplitHostPort(serverAddr) - if err != nil { - // Assume it's just a host - isHostname := net.ParseIP(serverAddr) == nil - logger.Debug("Address %s is hostname: %v", serverAddr, isHostname) - return isHostname - } - isHostname := net.ParseIP(host) == nil - logger.Debug("Host %s from %s is hostname: %v", host, serverAddr, isHostname) - return isHostname - } - - host := parsedURL.Hostname() - if host == "" { - logger.Debug("No hostname found in URL %s", serverAddr) - return false - } - - isHostname := net.ParseIP(host) == nil - logger.Debug("Host %s from URL %s is hostname: %v", host, serverAddr, isHostname) - return isHostname -} - +// Update the initClients method: func (s *Server) initClients() error { logger.Debug("Initializing DNS clients") @@ -160,7 +22,8 @@ func (s *Server) initClients() error { logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream) upstreamClient, err := client.New(resolvedUpstream, client.Options{ - DNSSEC: s.config.DNSSEC, + DNSSEC: s.config.DNSSEC, + KeepAlive: s.config.KeepAlive, }) if err != nil { logger.Error("Failed to create upstream client: %v", err) @@ -169,7 +32,7 @@ func (s *Server) initClients() error { s.upstreamClient = upstreamClient if s.config.Verbose { - logger.Info("Initialized upstream client: %s -> %s", s.config.Upstream, resolvedUpstream) + logger.Info("Initialized upstream client: %s -> %s (KeepAlive: %v)", s.config.Upstream, resolvedUpstream, s.config.KeepAlive) } // Initialize fallback client if specified @@ -182,7 +45,8 @@ func (s *Server) initClients() error { logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback) fallbackClient, err := client.New(resolvedFallback, client.Options{ - DNSSEC: s.config.DNSSEC, + DNSSEC: s.config.DNSSEC, + KeepAlive: s.config.KeepAlive, }) if err != nil { logger.Error("Failed to create fallback client: %v", err) @@ -191,402 +55,10 @@ func (s *Server) initClients() error { s.fallbackClient = fallbackClient if s.config.Verbose { - logger.Info("Initialized fallback client: %s -> %s", s.config.Fallback, resolvedFallback) + logger.Info("Initialized fallback client: %s -> %s (KeepAlive: %v)", s.config.Fallback, resolvedFallback, s.config.KeepAlive) } } logger.Debug("All DNS clients initialized successfully") return nil } - -func (s *Server) resolveServerAddress(serverAddr string) (string, error) { - logger.Debug("Resolving server address: %s", serverAddr) - - // If it doesn't contain hostnames, return as-is - if !containsHostname(serverAddr) { - logger.Debug("Address %s contains no hostnames, returning as-is", serverAddr) - return serverAddr, nil - } - - // If no bootstrap client, we can't resolve hostnames - if s.bootstrapClient == nil { - logger.Error("Cannot resolve hostname in %s: no bootstrap server configured", serverAddr) - return "", fmt.Errorf("cannot resolve hostname in %s: no bootstrap server configured", serverAddr) - } - - // Use the same parsing logic as the client package - parsedURL, err := url.Parse(serverAddr) - if err != nil { - logger.Debug("Parsing %s as plain host:port format", serverAddr) - // Handle plain host:port format - host, port, err := net.SplitHostPort(serverAddr) - if err != nil { - // Assume it's just a hostname - resolvedIP, err := s.resolveHostname(serverAddr) - if err != nil { - return "", err - } - logger.Debug("Resolved %s to %s", serverAddr, resolvedIP) - return resolvedIP, nil - } - - resolvedIP, err := s.resolveHostname(host) - if err != nil { - return "", err - } - resolved := net.JoinHostPort(resolvedIP, port) - logger.Debug("Resolved %s to %s", serverAddr, resolved) - return resolved, nil - } - - // Handle URL format - hostname := parsedURL.Hostname() - if hostname == "" { - logger.Error("No hostname in URL: %s", serverAddr) - return "", fmt.Errorf("no hostname in URL: %s", serverAddr) - } - - resolvedIP, err := s.resolveHostname(hostname) - if err != nil { - return "", err - } - - // Replace hostname with IP in the URL - port := parsedURL.Port() - if port == "" { - parsedURL.Host = resolvedIP - } else { - parsedURL.Host = net.JoinHostPort(resolvedIP, port) - } - - resolved := parsedURL.String() - logger.Debug("Resolved URL %s to %s", serverAddr, resolved) - return resolved, nil -} - -func (s *Server) resolveHostname(hostname string) (string, error) { - logger.Debug("Resolving hostname: %s", hostname) - - // Check cache first - s.hostsMutex.RLock() - if ip, exists := s.resolvedHosts[hostname]; exists { - s.hostsMutex.RUnlock() - logger.Debug("Found cached resolution for %s: %s", hostname, ip) - return ip, nil - } - s.hostsMutex.RUnlock() - - // Resolve using bootstrap - if s.config.Verbose { - logger.Info("Resolving hostname %s using bootstrap server", hostname) - } - - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(hostname), dns.TypeA) - msg.Id = dns.Id() - msg.RecursionDesired = true - - logger.Debug("Sending bootstrap query for %s (ID: %d)", hostname, msg.Id) - msg, err := s.bootstrapClient.Query(msg) - if err != nil { - logger.Error("Bootstrap query failed for %s: %v", hostname, err) - return "", fmt.Errorf("failed to resolve %s via bootstrap: %w", hostname, err) - } - - logger.Debug("Bootstrap response for %s: %d answers", hostname, len(msg.Answer)) - if len(msg.Answer) == 0 { - logger.Error("No A records found for %s", hostname) - return "", fmt.Errorf("no A records found for %s", hostname) - } - - // Find first A record - for _, rr := range msg.Answer { - if a, ok := rr.(*dns.A); ok { - ip := a.A.String() - - // Cache the result - s.hostsMutex.Lock() - s.resolvedHosts[hostname] = ip - s.hostsMutex.Unlock() - - if s.config.Verbose { - logger.Info("Resolved %s to %s", hostname, ip) - } - logger.Debug("Cached resolution: %s -> %s", hostname, ip) - - return ip, nil - } - } - - logger.Error("No valid A record found for %s", hostname) - return "", fmt.Errorf("no valid A record found for %s", hostname) -} - -func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { - if len(r.Question) == 0 { - logger.Debug("Received request with no questions from %s", w.RemoteAddr()) - dns.HandleFailed(w, r) - return - } - - question := r.Question[0] - domain := strings.ToLower(question.Name) - qtype := question.Qtype - - logger.Debug("Handling DNS request: %s %s from %s (ID: %d)", - question.Name, dns.TypeToString[qtype], w.RemoteAddr(), r.Id) - - if s.config.Verbose { - logger.Info("Query: %s %s from %s", - question.Name, - dns.TypeToString[qtype], - w.RemoteAddr()) - } - - // Check cache first - if cachedRecords := s.getCachedRecords(domain, qtype); cachedRecords != nil { - response := s.buildResponse(r, cachedRecords) - if s.config.Verbose { - logger.Info("Cache hit: %s %s -> %d records", - question.Name, - dns.TypeToString[qtype], - len(cachedRecords)) - } - logger.Debug("Serving cached response for %s %s (%d records)", - question.Name, dns.TypeToString[qtype], len(cachedRecords)) - w.WriteMsg(response) - return - } - - logger.Debug("Cache miss for %s %s, querying upstream", question.Name, dns.TypeToString[qtype]) - - // Try upstream first - response, err := s.queryUpstream(s.upstreamClient, question.Name, qtype) - if err != nil { - if s.config.Verbose { - logger.Info("Upstream query failed: %v", err) - } - logger.Debug("Upstream query failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err) - - // Try fallback if available - if s.fallbackClient != nil { - if s.config.Verbose { - logger.Info("Trying fallback server") - } - logger.Debug("Attempting fallback query for %s %s", question.Name, dns.TypeToString[qtype]) - - response, err = s.queryUpstream(s.fallbackClient, question.Name, qtype) - if err != nil { - logger.Error("Both upstream and fallback failed for %s %s: %v", - question.Name, - dns.TypeToString[qtype], - err) - } else { - logger.Debug("Fallback query succeeded for %s %s", question.Name, dns.TypeToString[qtype]) - } - } - - // If still failed, return SERVFAIL - if err != nil { - logger.Error("All servers failed for %s %s: %v", - question.Name, - dns.TypeToString[qtype], - err) - - m := new(dns.Msg) - m.SetReply(r) - m.Rcode = dns.RcodeServerFailure - w.WriteMsg(m) - return - } - } else { - logger.Debug("Upstream query succeeded for %s %s", question.Name, dns.TypeToString[qtype]) - } - - // Cache successful response - s.cacheResponse(domain, qtype, response) - - // Copy request ID to response - response.Id = r.Id - - if s.config.Verbose { - logger.Info("Response: %s %s -> %d answers", - question.Name, - dns.TypeToString[qtype], - len(response.Answer)) - } - - logger.Debug("Sending response for %s %s: %d answers, rcode: %s", - question.Name, dns.TypeToString[qtype], len(response.Answer), dns.RcodeToString[response.Rcode]) - w.WriteMsg(response) -} - -func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR { - key := cacheKey{domain: domain, qtype: qtype} - - s.cacheMutex.RLock() - entry, exists := s.queryCache[key] - s.cacheMutex.RUnlock() - - if !exists { - logger.Debug("No cache entry for %s %s", domain, dns.TypeToString[qtype]) - return nil - } - - // Check if expired and clean up on the spot - if time.Now().After(entry.expiresAt) { - logger.Debug("Cache entry expired for %s %s", domain, dns.TypeToString[qtype]) - s.cacheMutex.Lock() - delete(s.queryCache, key) - s.cacheMutex.Unlock() - return nil - } - - logger.Debug("Cache hit for %s %s (%d records, expires in %v)", - domain, dns.TypeToString[qtype], len(entry.records), time.Until(entry.expiresAt)) - - // Return a copy of the cached records - records := make([]dns.RR, len(entry.records)) - for i, rr := range entry.records { - records[i] = dns.Copy(rr) - } - return records -} - -func (s *Server) buildResponse(request *dns.Msg, records []dns.RR) *dns.Msg { - response := new(dns.Msg) - response.SetReply(request) - response.Answer = records - logger.Debug("Built response with %d records", len(records)) - return response -} - -func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) { - if msg == nil || len(msg.Answer) == 0 { - logger.Debug("Not caching empty response for %s %s", domain, dns.TypeToString[qtype]) - return - } - - var validRecords []dns.RR - minTTL := uint32(3600) - - // Find minimum TTL from answer records - for _, rr := range msg.Answer { - // Only cache records that match our query type or are CNAMEs - if rr.Header().Rrtype == qtype || rr.Header().Rrtype == dns.TypeCNAME { - validRecords = append(validRecords, dns.Copy(rr)) - if rr.Header().Ttl < minTTL { - minTTL = rr.Header().Ttl - } - } - } - - if len(validRecords) == 0 { - logger.Debug("No valid records to cache for %s %s", domain, dns.TypeToString[qtype]) - return - } - - // Don't cache responses with very low TTL - if minTTL < 10 { - logger.Debug("TTL too low (%ds) for caching %s %s", minTTL, domain, dns.TypeToString[qtype]) - return - } - - key := cacheKey{domain: domain, qtype: qtype} - entry := &cacheEntry{ - records: validRecords, - expiresAt: time.Now().Add(time.Duration(minTTL) * time.Second), - } - - s.cacheMutex.Lock() - s.queryCache[key] = entry - s.cacheMutex.Unlock() - - if s.config.Verbose { - logger.Info("Cached %d records for %s %s (TTL: %ds)", - len(validRecords), domain, dns.TypeToString[qtype], minTTL) - } - logger.Debug("Cached %d records for %s %s (TTL: %ds, expires: %v)", - len(validRecords), domain, dns.TypeToString[qtype], minTTL, entry.expiresAt) -} - -func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, qtype uint16) (*dns.Msg, error) { - logger.Debug("Querying upstream for %s %s", domain, dns.TypeToString[qtype]) - - // Create context with timeout - ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout) - defer cancel() - - // Channel to receive result - type result struct { - msg *dns.Msg - err error - } - resultChan := make(chan result, 1) - - // Query in goroutine to respect context timeout - go func() { - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain), qtype) - msg.Id = dns.Id() - msg.RecursionDesired = true - - logger.Debug("Sending upstream query: %s %s (ID: %d)", domain, dns.TypeToString[qtype], msg.Id) - recvMsg, err := upstreamClient.Query(msg) - if err != nil { - logger.Debug("Upstream query error for %s %s: %v", domain, dns.TypeToString[qtype], err) - } else { - logger.Debug("Upstream query response for %s %s: %d answers, rcode: %s", - domain, dns.TypeToString[qtype], len(recvMsg.Answer), dns.RcodeToString[recvMsg.Rcode]) - } - resultChan <- result{msg: recvMsg, err: err} - }() - - select { - case res := <-resultChan: - return res.msg, res.err - case <-ctx.Done(): - logger.Debug("Upstream query timeout for %s %s after %v", domain, dns.TypeToString[qtype], s.config.Timeout) - return nil, fmt.Errorf("upstream query timeout") - } -} - -func (s *Server) Start() error { - go func() { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - sig := <-sigChan - logger.Info("Received signal %v, shutting down DNS server...", sig) - s.Shutdown() - }() - - logger.Info("DNS proxy server listening on %s", s.config.Address) - logger.Debug("Server starting with timeout: %v, DNSSEC: %v", s.config.Timeout, s.config.DNSSEC) - return s.dnsServer.ListenAndServe() -} - -func (s *Server) Shutdown() { - logger.Debug("Shutting down server components") - - if s.dnsServer != nil { - logger.Debug("Shutting down DNS server") - s.dnsServer.Shutdown() - } - - if s.upstreamClient != nil { - logger.Debug("Closing upstream client") - s.upstreamClient.Close() - } - - if s.fallbackClient != nil { - logger.Debug("Closing fallback client") - s.fallbackClient.Close() - } - - if s.bootstrapClient != nil { - logger.Debug("Closing bootstrap client") - s.bootstrapClient.Close() - } - - logger.Info("Server shutdown complete") -}