feat(keep-alive): add keep-alive option

This commit is contained in:
2025-09-27 21:49:17 +01:00
parent ddb0d2ca4e
commit 2240d18f0b
13 changed files with 363 additions and 619 deletions

1
.gitignore vendored
View File

@@ -15,3 +15,4 @@
# vendor/ # vendor/
**/tls-key-log.txt **/tls-key-log.txt
/results

View File

@@ -30,6 +30,7 @@ type Options struct {
DNSSEC bool DNSSEC bool
ValidateOnly bool ValidateOnly bool
StrictValidation bool StrictValidation bool
KeepAlive bool // New flag for long-lived connections
} }
// New creates a DNS client based on the upstream string // New creates a DNS client based on the upstream string
@@ -214,8 +215,8 @@ func getDefaultPath(scheme string) string {
} }
func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) { func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) {
logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v", logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v, KeepAlive=%v",
scheme, host, port, path, opts.DNSSEC) scheme, host, port, path, opts.DNSSEC, opts.KeepAlive)
switch scheme { switch scheme {
case "udp", "tcp", "do53", "": case "udp", "tcp", "do53", "":
@@ -228,40 +229,44 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err
case "http", "doh": case "http", "doh":
config := doh.Config{ config := doh.Config{
Host: host, Host: host,
Port: port, Port: port,
Path: path, Path: path,
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
HTTP3: false, HTTP3: false,
KeepAlive: opts.KeepAlive,
} }
logger.Debug("Creating DoH client with config: %+v", config) logger.Debug("Creating DoH client with config: %+v", config)
return doh.New(config) return doh.New(config)
case "https", "doh3": case "https", "doh3":
config := doh.Config{ config := doh.Config{
Host: host, Host: host,
Port: port, Port: port,
Path: path, Path: path,
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
HTTP3: true, HTTP3: true,
KeepAlive: opts.KeepAlive,
} }
logger.Debug("Creating DoH3 client with config: %+v", config) logger.Debug("Creating DoH3 client with config: %+v", config)
return doh.New(config) return doh.New(config)
case "tls", "dot": case "tls", "dot":
config := dot.Config{ config := dot.Config{
Host: host, Host: host,
Port: port, Port: port,
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
KeepAlive: opts.KeepAlive,
} }
logger.Debug("Creating DoT client with config: %+v", config) logger.Debug("Creating DoT client with config: %+v", config)
return dot.New(config) return dot.New(config)
case "doq": // DNS over QUIC case "doq": // DNS over QUIC
config := doq.Config{ config := doq.Config{
Host: host, Host: host,
Port: port, Port: port,
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
KeepAlive: opts.KeepAlive,
} }
logger.Debug("Creating DoQ client with config: %+v", config) logger.Debug("Creating DoQ client with config: %+v", config)
return doq.New(config) return doq.New(config)

View File

@@ -17,6 +17,7 @@ type RunCmd struct {
QueryType string `short:"t" long:"type" default:"A" help:"DNS query type"` QueryType string `short:"t" long:"type" default:"A" help:"DNS query type"`
Timeout time.Duration `long:"timeout" default:"5s" help:"Query timeout (informational)"` Timeout time.Duration `long:"timeout" default:"5s" help:"Query timeout (informational)"`
DNSSEC bool `long:"dnssec" help:"Enable DNSSEC"` DNSSEC bool `long:"dnssec" help:"Enable DNSSEC"`
KeepAlive bool `short:"k" long:"keep-alive" help:"Use persistent connections"`
Interface string `long:"iface" default:"any" help:"Capture interface (e.g., eth0, any)"` Interface string `long:"iface" default:"any" help:"Capture interface (e.g., eth0, any)"`
Servers []string `short:"s" long:"server" help:"Upstream servers (udp://..., tls://..., https://..., doq://...)"` Servers []string `short:"s" long:"server" help:"Upstream servers (udp://..., tls://..., https://..., doq://...)"`
} }
@@ -27,6 +28,7 @@ func (r *RunCmd) Run() error {
OutputDir: r.OutputDir, OutputDir: r.OutputDir,
QueryType: r.QueryType, QueryType: r.QueryType,
DNSSEC: r.DNSSEC, DNSSEC: r.DNSSEC,
KeepAlive: r.KeepAlive,
Interface: r.Interface, Interface: r.Interface,
Servers: r.Servers, Servers: r.Servers,
} }

View File

@@ -26,6 +26,7 @@ type QueryCmd struct {
DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"` DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"`
ValidateOnly bool `help:"Only return DNSSEC validated responses." short:"V"` ValidateOnly bool `help:"Only return DNSSEC validated responses." short:"V"`
StrictValidation bool `help:"Fail on any DNSSEC validation error." short:"S"` StrictValidation bool `help:"Fail on any DNSSEC validation error." short:"S"`
KeepAlive bool `help:"Use persistent connections." short:"k"`
Timeout time.Duration `help:"Timeout for the query operation." default:"10s"` Timeout time.Duration `help:"Timeout for the query operation." default:"10s"`
KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"` KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"`
} }
@@ -36,18 +37,20 @@ type ListenCmd struct {
Fallback string `help:"Fallback DNS server (e.g., https://1.1.1.1/dns-query, tls://8.8.8.8)." short:"f"` Fallback string `help:"Fallback DNS server (e.g., https://1.1.1.1/dns-query, tls://8.8.8.8)." short:"f"`
Bootstrap string `help:"Bootstrap DNS server (must be an IP address, e.g., 8.8.8.8, 1.1.1.1)." short:"b"` Bootstrap string `help:"Bootstrap DNS server (must be an IP address, e.g., 8.8.8.8, 1.1.1.1)." short:"b"`
DNSSEC bool `help:"Enable DNSSEC for upstream queries." short:"d"` DNSSEC bool `help:"Enable DNSSEC for upstream queries." short:"d"`
KeepAlive bool `help:"Use persistent connections to upstream servers." short:"k"`
Timeout time.Duration `help:"Timeout for upstream queries." default:"5s"` Timeout time.Duration `help:"Timeout for upstream queries." default:"5s"`
Verbose bool `help:"Enable verbose logging." short:"v"` Verbose bool `help:"Enable verbose logging." short:"v"`
} }
func (q *QueryCmd) Run() error { func (q *QueryCmd) Run() error {
logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, Timeout: %v)", logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, KeepAlive: %v, Timeout: %v)",
q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.Timeout) q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.KeepAlive, q.Timeout)
opts := client.Options{ opts := client.Options{
DNSSEC: q.DNSSEC, DNSSEC: q.DNSSEC,
ValidateOnly: q.ValidateOnly, ValidateOnly: q.ValidateOnly,
StrictValidation: q.StrictValidation, StrictValidation: q.StrictValidation,
KeepAlive: q.KeepAlive,
} }
logger.Debug("Creating DNS client with options: %+v", opts) logger.Debug("Creating DNS client with options: %+v", opts)
@@ -90,6 +93,7 @@ func (l *ListenCmd) Run() error {
Fallback: l.Fallback, Fallback: l.Fallback,
Bootstrap: l.Bootstrap, Bootstrap: l.Bootstrap,
DNSSEC: l.DNSSEC, DNSSEC: l.DNSSEC,
KeepAlive: l.KeepAlive,
Timeout: l.Timeout, Timeout: l.Timeout,
Verbose: l.Verbose, Verbose: l.Verbose,
} }
@@ -105,10 +109,12 @@ func (l *ListenCmd) Run() error {
logger.Info("Upstream server: %v", l.Upstream) logger.Info("Upstream server: %v", l.Upstream)
logger.Info("Fallback server: %v", l.Fallback) logger.Info("Fallback server: %v", l.Fallback)
logger.Info("Bootstrap server: %v", l.Bootstrap) logger.Info("Bootstrap server: %v", l.Bootstrap)
logger.Info("KeepAlive: %v", l.KeepAlive)
return srv.Start() return srv.Start()
} }
func printResponse(domain, qtype string, msg *dns.Msg) { func printResponse(domain, qtype string, msg *dns.Msg) {
fmt.Println(";; QUESTION SECTION:") fmt.Println(";; QUESTION SECTION:")

View File

@@ -22,12 +22,13 @@ import (
const dnsMessageContentType = "application/dns-message" const dnsMessageContentType = "application/dns-message"
type Config struct { type Config struct {
Host string Host string
Port string Port string
Path string Path string
DNSSEC bool DNSSEC bool
HTTP3 bool HTTP3 bool
HTTP2 bool HTTP2 bool
KeepAlive bool
} }
type Client struct { type Client struct {
@@ -37,7 +38,7 @@ type Client struct {
} }
func New(config Config) (*Client, error) { func New(config Config) (*Client, error) {
logger.Debug("Creating DoH client: %s:%s%s", config.Host, config.Port, config.Path) logger.Debug("Creating DoH client: %s:%s%s (KeepAlive: %v)", config.Host, config.Port, config.Path, config.KeepAlive)
if config.Host == "" || config.Port == "" || config.Path == "" { if config.Host == "" || config.Port == "" || config.Path == "" {
logger.Error("DoH client creation failed: missing required fields") logger.Error("DoH client creation failed: missing required fields")
@@ -65,31 +66,58 @@ func New(config Config) (*Client, error) {
DisablePathMTUDiscovery: true, DisablePathMTUDiscovery: true,
} }
transport := http.DefaultTransport.(*http.Transport) transport := &http.Transport{
transport.TLSClientConfig = tlsConfig TLSClientConfig: tlsConfig,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
}
// Configure connection pooling based on KeepAlive setting
if config.KeepAlive {
transport.MaxIdleConnsPerHost = 10
transport.MaxConnsPerHost = 10
} else {
transport.MaxIdleConns = 0
transport.MaxIdleConnsPerHost = 0
transport.DisableKeepAlives = true
}
httpClient := &http.Client{ httpClient := &http.Client{
Transport: transport, Transport: transport,
Timeout: 30 * time.Second,
} }
var transportType string var transportType string
if config.HTTP2 { if config.HTTP2 {
httpClient.Transport = &http2.Transport{ http2Transport := &http2.Transport{
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
AllowHTTP: true, AllowHTTP: true,
} }
if !config.KeepAlive {
http2Transport.DisableCompression = true
}
httpClient.Transport = http2Transport
transportType = "HTTP/2" transportType = "HTTP/2"
} else if config.HTTP3 { } else if config.HTTP3 {
quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig) quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig)
httpClient.Transport = &http3.Transport{ http3Transport := &http3.Transport{
TLSClientConfig: quicTlsConfig, TLSClientConfig: quicTlsConfig,
QUICConfig: quicConfig, QUICConfig: quicConfig,
} }
if !config.KeepAlive {
http3Transport.DisableCompression = true
}
httpClient.Transport = http3Transport
transportType = "HTTP/3" transportType = "HTTP/3"
} else { } else {
transportType = "HTTP/1.1" transportType = "HTTP/1.1"
} }
logger.Debug("DoH client created: %s (%s, DNSSEC: %v)", rawURL, transportType, config.DNSSEC) logger.Debug("DoH client created: %s (%s, DNSSEC: %v, KeepAlive: %v)", rawURL, transportType, config.DNSSEC, config.KeepAlive)
return &Client{ return &Client{
httpClient: httpClient, httpClient: httpClient,
@@ -102,6 +130,8 @@ func (c *Client) Close() {
logger.Debug("Closing DoH client") logger.Debug("Closing DoH client")
if t, ok := c.httpClient.Transport.(*http.Transport); ok { if t, ok := c.httpClient.Transport.(*http.Transport); ok {
t.CloseIdleConnections() t.CloseIdleConnections()
} else if t2, ok := c.httpClient.Transport.(*http2.Transport); ok {
t2.CloseIdleConnections()
} else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok { } else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok {
t3.CloseIdleConnections() t3.CloseIdleConnections()
} }
@@ -132,6 +162,13 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
httpReq.Header.Set("User-Agent", "sdns-proxy") httpReq.Header.Set("User-Agent", "sdns-proxy")
httpReq.Header.Set("Content-Type", dnsMessageContentType) httpReq.Header.Set("Content-Type", dnsMessageContentType)
httpReq.Header.Set("Accept", dnsMessageContentType) httpReq.Header.Set("Accept", dnsMessageContentType)
// Set Connection header based on KeepAlive setting
if c.config.KeepAlive {
httpReq.Header.Set("Connection", "keep-alive")
} else {
httpReq.Header.Set("Connection", "close")
}
httpResp, err := c.httpClient.Do(httpReq) httpResp, err := c.httpClient.Do(httpReq)
if err != nil { if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/afonsofrancof/sdns-proxy/common/logger"
@@ -16,10 +17,11 @@ import (
) )
type Config struct { type Config struct {
Host string Host string
Port string Port string
Debug bool Debug bool
DNSSEC bool DNSSEC bool
KeepAlive bool
} }
type Client struct { type Client struct {
@@ -30,10 +32,11 @@ type Client struct {
quicTransport *quic.Transport quicTransport *quic.Transport
quicConfig *quic.Config quicConfig *quic.Config
config Config config Config
connMutex sync.Mutex
} }
func New(config Config) (*Client, error) { func New(config Config) (*Client, error) {
logger.Debug("Creating DoQ client: %s:%s", config.Host, config.Port) logger.Debug("Creating DoQ client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: config.Host, ServerName: config.Host,
@@ -62,9 +65,7 @@ func New(config Config) (*Client, error) {
MaxIdleTimeout: 30 * time.Second, MaxIdleTimeout: 30 * time.Second,
} }
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v)", config.Host, config.Port, config.DNSSEC) client := &Client{
return &Client{
targetAddr: targetAddr, targetAddr: targetAddr,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
udpConn: udpConn, udpConn: udpConn,
@@ -72,18 +73,52 @@ func New(config Config) (*Client, error) {
quicTransport: &quicTransport, quicTransport: &quicTransport,
quicConfig: &quicConfig, quicConfig: &quicConfig,
config: config, config: config,
}, nil }
// If keep-alive is enabled, establish connection now
if config.KeepAlive {
if err := client.ensureConnection(); err != nil {
logger.Error("DoQ failed to establish initial connection: %v", err)
client.Close()
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
}
}
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v, KeepAlive: %v)", config.Host, config.Port, config.DNSSEC, config.KeepAlive)
return client, nil
} }
func (c *Client) Close() { func (c *Client) Close() {
logger.Debug("Closing DoQ client") logger.Debug("Closing DoQ client")
c.connMutex.Lock()
defer c.connMutex.Unlock()
if c.quicConn != nil {
c.quicConn.CloseWithError(0, "client shutdown")
c.quicConn = nil
}
if c.udpConn != nil { if c.udpConn != nil {
c.udpConn.Close() c.udpConn.Close()
} }
} }
func (c *Client) OpenConnection() error { func (c *Client) ensureConnection() error {
logger.Debug("Opening DoQ connection to %s", c.targetAddr) c.connMutex.Lock()
defer c.connMutex.Unlock()
// Check if existing connection is still valid
if c.quicConn != nil {
select {
case <-c.quicConn.Context().Done():
logger.Debug("DoQ connection closed, reconnecting")
c.quicConn = nil
default:
// Connection is still valid
return nil
}
}
logger.Debug("Establishing DoQ connection to %s", c.targetAddr)
quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig) quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig)
if err != nil { if err != nil {
logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err) logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err)
@@ -101,10 +136,19 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr) logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr)
} }
if c.quicConn == nil { // Ensure we have a connection (either persistent or new)
err := c.OpenConnection() if c.config.KeepAlive {
if err != nil { if err := c.ensureConnection(); err != nil {
return nil, err return nil, fmt.Errorf("doq: failed to ensure connection: %w", err)
}
} else {
// For non-keepalive mode, create a fresh connection for each query
c.connMutex.Lock()
c.quicConn = nil // Force new connection
c.connMutex.Unlock()
if err := c.ensureConnection(); err != nil {
return nil, fmt.Errorf("doq: failed to create connection: %w", err)
} }
} }
@@ -119,18 +163,33 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
return nil, fmt.Errorf("doq: failed to pack message: %w", err) return nil, fmt.Errorf("doq: failed to pack message: %w", err)
} }
var quicStream quic.Stream // Open a stream for this query
quicStream, err = c.quicConn.OpenStream() c.connMutex.Lock()
quicConn := c.quicConn
c.connMutex.Unlock()
quicStream, err := quicConn.OpenStream()
if err != nil { if err != nil {
logger.Debug("DoQ stream failed, reconnecting: %v", err) logger.Error("DoQ failed to open stream: %v", err)
err = c.OpenConnection()
if err != nil { // If keep-alive is enabled, try to reconnect once
return nil, err if c.config.KeepAlive {
} logger.Debug("DoQ stream failed with keep-alive, attempting reconnect")
quicStream, err = c.quicConn.OpenStream() if reconnectErr := c.ensureConnection(); reconnectErr != nil {
if err != nil { return nil, fmt.Errorf("doq: failed to reconnect: %w", reconnectErr)
logger.Error("DoQ failed to open stream after reconnect: %v", err) }
return nil, err
c.connMutex.Lock()
quicConn = c.quicConn
c.connMutex.Unlock()
quicStream, err = quicConn.OpenStream()
if err != nil {
logger.Error("DoQ failed to open stream after reconnect: %v", err)
return nil, fmt.Errorf("doq: failed to open stream after reconnect: %w", err)
}
} else {
return nil, fmt.Errorf("doq: failed to open stream: %w", err)
} }
} }
@@ -184,5 +243,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer)) logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer))
} }
// Close the connection if not using keep-alive
if !c.config.KeepAlive {
c.connMutex.Lock()
if c.quicConn != nil {
c.quicConn.CloseWithError(0, "query complete")
c.quicConn = nil
}
c.connMutex.Unlock()
}
return recvMsg, nil return recvMsg, nil
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/afonsofrancof/sdns-proxy/common/logger"
@@ -16,6 +17,7 @@ type Config struct {
Host string Host string
Port string Port string
DNSSEC bool DNSSEC bool
KeepAlive bool
WriteTimeout time.Duration WriteTimeout time.Duration
ReadTimeout time.Duration ReadTimeout time.Duration
Debug bool Debug bool
@@ -25,10 +27,12 @@ type Client struct {
hostAndPort string hostAndPort string
tlsConfig *tls.Config tlsConfig *tls.Config
config Config config Config
conn *tls.Conn
connMutex sync.Mutex
} }
func New(config Config) (*Client, error) { func New(config Config) (*Client, error) {
logger.Debug("Creating DoT client: %s:%s", config.Host, config.Port) logger.Debug("Creating DoT client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
if config.Host == "" { if config.Host == "" {
logger.Error("DoT client creation failed: empty host") logger.Error("DoT client creation failed: empty host")
@@ -47,33 +51,76 @@ func New(config Config) (*Client, error) {
ServerName: config.Host, ServerName: config.Host,
} }
logger.Debug("DoT client created: %s (DNSSEC: %v)", hostAndPort, config.DNSSEC) client := &Client{
return &Client{
hostAndPort: hostAndPort, hostAndPort: hostAndPort,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
config: config, config: config,
}, nil }
// If keep-alive is enabled, establish connection now
if config.KeepAlive {
if err := client.ensureConnection(); err != nil {
logger.Error("DoT failed to establish initial connection: %v", err)
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
}
}
logger.Debug("DoT client created: %s (DNSSEC: %v, KeepAlive: %v)", hostAndPort, config.DNSSEC, config.KeepAlive)
return client, nil
} }
func (c *Client) Close() { func (c *Client) Close() {
logger.Debug("Closing DoT client") logger.Debug("Closing DoT client")
c.connMutex.Lock()
defer c.connMutex.Unlock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
} }
func (c *Client) createConnection() (*tls.Conn, error) { func (c *Client) ensureConnection() error {
c.connMutex.Lock()
defer c.connMutex.Unlock()
// Check if existing connection is still valid
if c.conn != nil {
// Test the connection with a very short deadline
if err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err == nil {
// Try to read one byte to test connection
var testBuf [1]byte
_, err := c.conn.Read(testBuf[:])
// Reset deadline
c.conn.SetReadDeadline(time.Time{})
// If we get a timeout error, connection is still good
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil
}
// Any other error means connection is dead
logger.Debug("DoT connection test failed, reconnecting: %v", err)
c.conn.Close()
c.conn = nil
}
}
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
dialer := &net.Dialer{ dialer := &net.Dialer{
Timeout: c.config.WriteTimeout, Timeout: c.config.WriteTimeout,
} }
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig) conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig)
if err != nil { if err != nil {
logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err) logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err)
return nil, err return err
} }
c.conn = conn
logger.Debug("DoT connection established to %s", c.hostAndPort) logger.Debug("DoT connection established to %s", c.hostAndPort)
return conn, nil return nil
} }
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
@@ -82,12 +129,24 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort) logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
} }
// Create connection for this query // Ensure we have a connection (either persistent or new)
conn, err := c.createConnection() if c.config.KeepAlive {
if err != nil { if err := c.ensureConnection(); err != nil {
return nil, fmt.Errorf("dot: failed to create connection: %w", err) return nil, fmt.Errorf("dot: failed to ensure connection: %w", err)
}
} else {
// For non-keepalive mode, create a fresh connection for each query
c.connMutex.Lock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.connMutex.Unlock()
if err := c.ensureConnection(); err != nil {
return nil, fmt.Errorf("dot: failed to create connection: %w", err)
}
} }
defer conn.Close()
// Prepare DNS message // Prepare DNS message
if c.config.DNSSEC { if c.config.DNSSEC {
@@ -104,6 +163,10 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
binary.BigEndian.PutUint16(length, uint16(len(packed))) binary.BigEndian.PutUint16(length, uint16(len(packed)))
data := append(length, packed...) data := append(length, packed...)
c.connMutex.Lock()
conn := c.conn
c.connMutex.Unlock()
// Write query // Write query
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
logger.Error("DoT failed to set write deadline: %v", err) logger.Error("DoT failed to set write deadline: %v", err)
@@ -112,7 +175,28 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
if _, err := conn.Write(data); err != nil { if _, err := conn.Write(data); err != nil {
logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err) logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("dot: failed to write message: %w", err)
// If keep-alive is enabled and write failed, try to reconnect once
if c.config.KeepAlive {
logger.Debug("DoT write failed with keep-alive, attempting reconnect")
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
return nil, fmt.Errorf("dot: failed to reconnect: %w", reconnectErr)
}
c.connMutex.Lock()
conn = c.conn
c.connMutex.Unlock()
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
return nil, fmt.Errorf("dot: failed to set write deadline after reconnect: %w", err)
}
if _, err := conn.Write(data); err != nil {
return nil, fmt.Errorf("dot: failed to write message after reconnect: %w", err)
}
} else {
return nil, fmt.Errorf("dot: failed to write message: %w", err)
}
} }
// Read response // Read response
@@ -152,5 +236,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer)) logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer))
} }
// Close the connection if not using keep-alive
if !c.config.KeepAlive {
c.connMutex.Lock()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.connMutex.Unlock()
}
return response, nil return response, nil
} }

27
flake.lock generated Normal file
View File

@@ -0,0 +1,27 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1758916627,
"narHash": "sha256-fB2ISCc+xn+9hZ6gOsABxSBcsCgLCjbJ5bC6U9bPzQ4=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "53614373268559d054c080d070cfc732dbe68ac4",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixpkgs-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs"
}
}
},
"root": "root",
"version": 7
}

View File

@@ -26,6 +26,11 @@
go go
gotools gotools
golangci-lint golangci-lint
(pkgs.python3.withPackages (python-pkgs: with python-pkgs; [
pandas
matplotlib
seaborn
]))
]; ];
}; };
}); });

View File

@@ -20,6 +20,7 @@ type MeasurementConfig struct {
OutputDir string OutputDir string
QueryType string QueryType string
DNSSEC bool DNSSEC bool
KeepAlive bool
Interface string Interface string
Servers []string Servers []string
} }
@@ -74,9 +75,14 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT
defer dnsClient.Close() defer dnsClient.Close()
// Setup output files // Setup output files
jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC) jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC, r.config.KeepAlive)
fmt.Printf(">>> Measuring %s (dnssec=%v) → %s\n", upstream, r.config.DNSSEC, keepAliveStr := ""
if r.config.KeepAlive {
keepAliveStr = " (keep-alive)"
}
fmt.Printf(">>> Measuring %s (dnssec=%v%s) → %s\n", upstream, r.config.DNSSEC, keepAliveStr,
strings.TrimSuffix(strings.TrimSuffix(jsonPath, ".jsonl"), r.config.OutputDir+"/")) strings.TrimSuffix(strings.TrimSuffix(jsonPath, ".jsonl"), r.config.OutputDir+"/"))
// Setup packet capture // Setup packet capture
@@ -98,7 +104,10 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT
} }
func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) { func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) {
opts := client.Options{DNSSEC: r.config.DNSSEC} opts := client.Options{
DNSSEC: r.config.DNSSEC,
KeepAlive: r.config.KeepAlive,
}
return client.New(upstream, opts) return client.New(upstream, opts)
} }
@@ -136,7 +145,14 @@ func (r *MeasurementRunner) runQueries(dnsClient client.DNSClient, upstream stri
} }
r.printQueryResult(metric) r.printQueryResult(metric)
time.Sleep(10 * time.Millisecond)
// For keep-alive connections, add smaller delays between queries
// to better utilize the persistent connection
if r.config.KeepAlive {
time.Sleep(5 * time.Millisecond)
} else {
time.Sleep(10 * time.Millisecond)
}
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -154,6 +170,7 @@ func (r *MeasurementRunner) performQuery(dnsClient client.DNSClient, domain, ups
QueryType: r.config.QueryType, QueryType: r.config.QueryType,
Protocol: proto, Protocol: proto,
DNSSEC: r.config.DNSSEC, DNSSEC: r.config.DNSSEC,
KeepAlive: r.config.KeepAlive,
DNSServer: upstream, DNSServer: upstream,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
@@ -199,8 +216,14 @@ func (r *MeasurementRunner) printQueryResult(metric results.DNSMetric) {
if metric.ResponseCode == "ERROR" { if metric.ResponseCode == "ERROR" {
statusIcon = "✗" statusIcon = "✗"
} }
fmt.Printf("%s %s [%s] %s %.2fms\n",
statusIcon, metric.Domain, metric.Protocol, metric.ResponseCode, metric.DurationMs) keepAliveIndicator := ""
if metric.KeepAlive {
keepAliveIndicator = "⟷"
}
fmt.Printf("%s %s%s [%s] %s %.2fms\n",
statusIcon, metric.Domain, keepAliveIndicator, metric.Protocol, metric.ResponseCode, metric.DurationMs)
} }
func (r *MeasurementRunner) readDomainsFile() ([]string, error) { func (r *MeasurementRunner) readDomainsFile() ([]string, error) {

View File

@@ -12,6 +12,7 @@ type DNSMetric struct {
QueryType string `json:"query_type"` QueryType string `json:"query_type"`
Protocol string `json:"protocol"` Protocol string `json:"protocol"`
DNSSEC bool `json:"dnssec"` DNSSEC bool `json:"dnssec"`
KeepAlive bool `json:"keep_alive"`
DNSServer string `json:"dns_server"` DNSServer string `json:"dns_server"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Duration int64 `json:"duration_ns"` Duration int64 `json:"duration_ns"`
@@ -22,6 +23,7 @@ type DNSMetric struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
} }
// Rest stays exactly the same
type MetricsWriter struct { type MetricsWriter struct {
encoder *json.Encoder encoder *json.Encoder
file *os.File file *os.File

View File

@@ -8,17 +8,18 @@ import (
"time" "time"
) )
func GenerateOutputPaths(outputDir, upstream string, dnssec bool) (jsonPath, pcapPath string) { func GenerateOutputPaths(outputDir, upstream string, dnssec, keepAlive bool) (jsonPath, pcapPath string) {
proto := DetectProtocol(upstream) proto := DetectProtocol(upstream)
serverName := ExtractServerName(upstream) serverName := ExtractServerName(upstream)
ts := time.Now().Format("20060102_1504") ts := time.Now().Format("20060102_1504")
dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec] dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec]
keepAliveStr := map[bool]string{true: "on", false: "off"}[keepAlive]
base := fmt.Sprintf("%s_%s_dnssec_%s_%s", base := fmt.Sprintf("%s_%s_dnssec_%s_keepalive_%s_%s",
proto, sanitize(serverName), dnssecStr, ts) proto, sanitize(serverName), dnssecStr, keepAliveStr, ts)
return filepath.Join(outputDir, base+".jsonl"), return filepath.Join(outputDir, base+".jsonl"),
filepath.Join(outputDir, base+".pcap") filepath.Join(outputDir, base+".pcap")
} }
func sanitize(s string) string { func sanitize(s string) string {

View File

@@ -1,153 +1,15 @@
package server
import (
"context"
"fmt"
"net"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"github.com/afonsofrancof/sdns-proxy/client"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns"
)
type Config struct { type Config struct {
Address string Address string
Upstream string Upstream string
Fallback string Fallback string
Bootstrap string Bootstrap string
DNSSEC bool DNSSEC bool
KeepAlive bool
Timeout time.Duration Timeout time.Duration
Verbose bool Verbose bool
} }
type cacheKey struct { // Update the initClients method:
domain string
qtype uint16
}
type cacheEntry struct {
records []dns.RR
expiresAt time.Time
}
type Server struct {
config Config
upstreamClient client.DNSClient
fallbackClient client.DNSClient
bootstrapClient client.DNSClient
resolvedHosts map[string]string
queryCache map[cacheKey]*cacheEntry
hostsMutex sync.RWMutex
cacheMutex sync.RWMutex
dnsServer *dns.Server
}
func New(config Config) (*Server, error) {
logger.Debug("Creating new server with config: %+v", config)
if config.Upstream == "" {
logger.Error("Upstream server is required")
return nil, fmt.Errorf("upstream server is required")
}
// Check if we need bootstrap server
needsBootstrap := containsHostname(config.Upstream)
if config.Fallback != "" {
needsBootstrap = needsBootstrap || containsHostname(config.Fallback)
}
logger.Debug("Bootstrap needed: %v (upstream has hostname: %v, fallback has hostname: %v)",
needsBootstrap, containsHostname(config.Upstream),
config.Fallback != "" && containsHostname(config.Fallback))
if needsBootstrap && config.Bootstrap == "" {
logger.Error("Bootstrap server is required when upstream or fallback contains hostnames")
return nil, fmt.Errorf("bootstrap server is required when upstream or fallback contains hostnames")
}
if config.Bootstrap != "" && containsHostname(config.Bootstrap) {
logger.Error("Bootstrap server cannot contain hostnames: %s", config.Bootstrap)
return nil, fmt.Errorf("bootstrap server cannot contain hostnames: %s", config.Bootstrap)
}
s := &Server{
config: config,
resolvedHosts: make(map[string]string),
queryCache: make(map[cacheKey]*cacheEntry),
}
// Create bootstrap client if needed
if config.Bootstrap != "" {
logger.Debug("Creating bootstrap client for %s", config.Bootstrap)
bootstrapClient, err := client.New(config.Bootstrap, client.Options{
DNSSEC: false,
})
if err != nil {
logger.Error("Failed to create bootstrap client: %v", err)
return nil, fmt.Errorf("failed to create bootstrap client: %w", err)
}
s.bootstrapClient = bootstrapClient
logger.Debug("Bootstrap client created successfully")
}
// Initialize upstream and fallback clients
if err := s.initClients(); err != nil {
logger.Error("Failed to initialize clients: %v", err)
return nil, fmt.Errorf("failed to initialize clients: %w", err)
}
// Setup DNS server
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleDNSRequest)
s.dnsServer = &dns.Server{
Addr: config.Address,
Net: "udp",
Handler: mux,
}
logger.Debug("Server created successfully, listening on %s", config.Address)
return s, nil
}
func containsHostname(serverAddr string) bool {
logger.Debug("Checking if %s contains hostname", serverAddr)
// Use the same parsing logic as the client package
parsedURL, err := url.Parse(serverAddr)
if err != nil {
logger.Debug("URL parsing failed for %s, treating as plain address", serverAddr)
// If URL parsing fails, assume it's a plain address
host, _, err := net.SplitHostPort(serverAddr)
if err != nil {
// Assume it's just a host
isHostname := net.ParseIP(serverAddr) == nil
logger.Debug("Address %s is hostname: %v", serverAddr, isHostname)
return isHostname
}
isHostname := net.ParseIP(host) == nil
logger.Debug("Host %s from %s is hostname: %v", host, serverAddr, isHostname)
return isHostname
}
host := parsedURL.Hostname()
if host == "" {
logger.Debug("No hostname found in URL %s", serverAddr)
return false
}
isHostname := net.ParseIP(host) == nil
logger.Debug("Host %s from URL %s is hostname: %v", host, serverAddr, isHostname)
return isHostname
}
func (s *Server) initClients() error { func (s *Server) initClients() error {
logger.Debug("Initializing DNS clients") logger.Debug("Initializing DNS clients")
@@ -160,7 +22,8 @@ func (s *Server) initClients() error {
logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream) logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream)
upstreamClient, err := client.New(resolvedUpstream, client.Options{ upstreamClient, err := client.New(resolvedUpstream, client.Options{
DNSSEC: s.config.DNSSEC, DNSSEC: s.config.DNSSEC,
KeepAlive: s.config.KeepAlive,
}) })
if err != nil { if err != nil {
logger.Error("Failed to create upstream client: %v", err) logger.Error("Failed to create upstream client: %v", err)
@@ -169,7 +32,7 @@ func (s *Server) initClients() error {
s.upstreamClient = upstreamClient s.upstreamClient = upstreamClient
if s.config.Verbose { if s.config.Verbose {
logger.Info("Initialized upstream client: %s -> %s", s.config.Upstream, resolvedUpstream) logger.Info("Initialized upstream client: %s -> %s (KeepAlive: %v)", s.config.Upstream, resolvedUpstream, s.config.KeepAlive)
} }
// Initialize fallback client if specified // Initialize fallback client if specified
@@ -182,7 +45,8 @@ func (s *Server) initClients() error {
logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback) logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback)
fallbackClient, err := client.New(resolvedFallback, client.Options{ fallbackClient, err := client.New(resolvedFallback, client.Options{
DNSSEC: s.config.DNSSEC, DNSSEC: s.config.DNSSEC,
KeepAlive: s.config.KeepAlive,
}) })
if err != nil { if err != nil {
logger.Error("Failed to create fallback client: %v", err) logger.Error("Failed to create fallback client: %v", err)
@@ -191,402 +55,10 @@ func (s *Server) initClients() error {
s.fallbackClient = fallbackClient s.fallbackClient = fallbackClient
if s.config.Verbose { if s.config.Verbose {
logger.Info("Initialized fallback client: %s -> %s", s.config.Fallback, resolvedFallback) logger.Info("Initialized fallback client: %s -> %s (KeepAlive: %v)", s.config.Fallback, resolvedFallback, s.config.KeepAlive)
} }
} }
logger.Debug("All DNS clients initialized successfully") logger.Debug("All DNS clients initialized successfully")
return nil return nil
} }
func (s *Server) resolveServerAddress(serverAddr string) (string, error) {
logger.Debug("Resolving server address: %s", serverAddr)
// If it doesn't contain hostnames, return as-is
if !containsHostname(serverAddr) {
logger.Debug("Address %s contains no hostnames, returning as-is", serverAddr)
return serverAddr, nil
}
// If no bootstrap client, we can't resolve hostnames
if s.bootstrapClient == nil {
logger.Error("Cannot resolve hostname in %s: no bootstrap server configured", serverAddr)
return "", fmt.Errorf("cannot resolve hostname in %s: no bootstrap server configured", serverAddr)
}
// Use the same parsing logic as the client package
parsedURL, err := url.Parse(serverAddr)
if err != nil {
logger.Debug("Parsing %s as plain host:port format", serverAddr)
// Handle plain host:port format
host, port, err := net.SplitHostPort(serverAddr)
if err != nil {
// Assume it's just a hostname
resolvedIP, err := s.resolveHostname(serverAddr)
if err != nil {
return "", err
}
logger.Debug("Resolved %s to %s", serverAddr, resolvedIP)
return resolvedIP, nil
}
resolvedIP, err := s.resolveHostname(host)
if err != nil {
return "", err
}
resolved := net.JoinHostPort(resolvedIP, port)
logger.Debug("Resolved %s to %s", serverAddr, resolved)
return resolved, nil
}
// Handle URL format
hostname := parsedURL.Hostname()
if hostname == "" {
logger.Error("No hostname in URL: %s", serverAddr)
return "", fmt.Errorf("no hostname in URL: %s", serverAddr)
}
resolvedIP, err := s.resolveHostname(hostname)
if err != nil {
return "", err
}
// Replace hostname with IP in the URL
port := parsedURL.Port()
if port == "" {
parsedURL.Host = resolvedIP
} else {
parsedURL.Host = net.JoinHostPort(resolvedIP, port)
}
resolved := parsedURL.String()
logger.Debug("Resolved URL %s to %s", serverAddr, resolved)
return resolved, nil
}
func (s *Server) resolveHostname(hostname string) (string, error) {
logger.Debug("Resolving hostname: %s", hostname)
// Check cache first
s.hostsMutex.RLock()
if ip, exists := s.resolvedHosts[hostname]; exists {
s.hostsMutex.RUnlock()
logger.Debug("Found cached resolution for %s: %s", hostname, ip)
return ip, nil
}
s.hostsMutex.RUnlock()
// Resolve using bootstrap
if s.config.Verbose {
logger.Info("Resolving hostname %s using bootstrap server", hostname)
}
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(hostname), dns.TypeA)
msg.Id = dns.Id()
msg.RecursionDesired = true
logger.Debug("Sending bootstrap query for %s (ID: %d)", hostname, msg.Id)
msg, err := s.bootstrapClient.Query(msg)
if err != nil {
logger.Error("Bootstrap query failed for %s: %v", hostname, err)
return "", fmt.Errorf("failed to resolve %s via bootstrap: %w", hostname, err)
}
logger.Debug("Bootstrap response for %s: %d answers", hostname, len(msg.Answer))
if len(msg.Answer) == 0 {
logger.Error("No A records found for %s", hostname)
return "", fmt.Errorf("no A records found for %s", hostname)
}
// Find first A record
for _, rr := range msg.Answer {
if a, ok := rr.(*dns.A); ok {
ip := a.A.String()
// Cache the result
s.hostsMutex.Lock()
s.resolvedHosts[hostname] = ip
s.hostsMutex.Unlock()
if s.config.Verbose {
logger.Info("Resolved %s to %s", hostname, ip)
}
logger.Debug("Cached resolution: %s -> %s", hostname, ip)
return ip, nil
}
}
logger.Error("No valid A record found for %s", hostname)
return "", fmt.Errorf("no valid A record found for %s", hostname)
}
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
logger.Debug("Received request with no questions from %s", w.RemoteAddr())
dns.HandleFailed(w, r)
return
}
question := r.Question[0]
domain := strings.ToLower(question.Name)
qtype := question.Qtype
logger.Debug("Handling DNS request: %s %s from %s (ID: %d)",
question.Name, dns.TypeToString[qtype], w.RemoteAddr(), r.Id)
if s.config.Verbose {
logger.Info("Query: %s %s from %s",
question.Name,
dns.TypeToString[qtype],
w.RemoteAddr())
}
// Check cache first
if cachedRecords := s.getCachedRecords(domain, qtype); cachedRecords != nil {
response := s.buildResponse(r, cachedRecords)
if s.config.Verbose {
logger.Info("Cache hit: %s %s -> %d records",
question.Name,
dns.TypeToString[qtype],
len(cachedRecords))
}
logger.Debug("Serving cached response for %s %s (%d records)",
question.Name, dns.TypeToString[qtype], len(cachedRecords))
w.WriteMsg(response)
return
}
logger.Debug("Cache miss for %s %s, querying upstream", question.Name, dns.TypeToString[qtype])
// Try upstream first
response, err := s.queryUpstream(s.upstreamClient, question.Name, qtype)
if err != nil {
if s.config.Verbose {
logger.Info("Upstream query failed: %v", err)
}
logger.Debug("Upstream query failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err)
// Try fallback if available
if s.fallbackClient != nil {
if s.config.Verbose {
logger.Info("Trying fallback server")
}
logger.Debug("Attempting fallback query for %s %s", question.Name, dns.TypeToString[qtype])
response, err = s.queryUpstream(s.fallbackClient, question.Name, qtype)
if err != nil {
logger.Error("Both upstream and fallback failed for %s %s: %v",
question.Name,
dns.TypeToString[qtype],
err)
} else {
logger.Debug("Fallback query succeeded for %s %s", question.Name, dns.TypeToString[qtype])
}
}
// If still failed, return SERVFAIL
if err != nil {
logger.Error("All servers failed for %s %s: %v",
question.Name,
dns.TypeToString[qtype],
err)
m := new(dns.Msg)
m.SetReply(r)
m.Rcode = dns.RcodeServerFailure
w.WriteMsg(m)
return
}
} else {
logger.Debug("Upstream query succeeded for %s %s", question.Name, dns.TypeToString[qtype])
}
// Cache successful response
s.cacheResponse(domain, qtype, response)
// Copy request ID to response
response.Id = r.Id
if s.config.Verbose {
logger.Info("Response: %s %s -> %d answers",
question.Name,
dns.TypeToString[qtype],
len(response.Answer))
}
logger.Debug("Sending response for %s %s: %d answers, rcode: %s",
question.Name, dns.TypeToString[qtype], len(response.Answer), dns.RcodeToString[response.Rcode])
w.WriteMsg(response)
}
func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR {
key := cacheKey{domain: domain, qtype: qtype}
s.cacheMutex.RLock()
entry, exists := s.queryCache[key]
s.cacheMutex.RUnlock()
if !exists {
logger.Debug("No cache entry for %s %s", domain, dns.TypeToString[qtype])
return nil
}
// Check if expired and clean up on the spot
if time.Now().After(entry.expiresAt) {
logger.Debug("Cache entry expired for %s %s", domain, dns.TypeToString[qtype])
s.cacheMutex.Lock()
delete(s.queryCache, key)
s.cacheMutex.Unlock()
return nil
}
logger.Debug("Cache hit for %s %s (%d records, expires in %v)",
domain, dns.TypeToString[qtype], len(entry.records), time.Until(entry.expiresAt))
// Return a copy of the cached records
records := make([]dns.RR, len(entry.records))
for i, rr := range entry.records {
records[i] = dns.Copy(rr)
}
return records
}
func (s *Server) buildResponse(request *dns.Msg, records []dns.RR) *dns.Msg {
response := new(dns.Msg)
response.SetReply(request)
response.Answer = records
logger.Debug("Built response with %d records", len(records))
return response
}
func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) {
if msg == nil || len(msg.Answer) == 0 {
logger.Debug("Not caching empty response for %s %s", domain, dns.TypeToString[qtype])
return
}
var validRecords []dns.RR
minTTL := uint32(3600)
// Find minimum TTL from answer records
for _, rr := range msg.Answer {
// Only cache records that match our query type or are CNAMEs
if rr.Header().Rrtype == qtype || rr.Header().Rrtype == dns.TypeCNAME {
validRecords = append(validRecords, dns.Copy(rr))
if rr.Header().Ttl < minTTL {
minTTL = rr.Header().Ttl
}
}
}
if len(validRecords) == 0 {
logger.Debug("No valid records to cache for %s %s", domain, dns.TypeToString[qtype])
return
}
// Don't cache responses with very low TTL
if minTTL < 10 {
logger.Debug("TTL too low (%ds) for caching %s %s", minTTL, domain, dns.TypeToString[qtype])
return
}
key := cacheKey{domain: domain, qtype: qtype}
entry := &cacheEntry{
records: validRecords,
expiresAt: time.Now().Add(time.Duration(minTTL) * time.Second),
}
s.cacheMutex.Lock()
s.queryCache[key] = entry
s.cacheMutex.Unlock()
if s.config.Verbose {
logger.Info("Cached %d records for %s %s (TTL: %ds)",
len(validRecords), domain, dns.TypeToString[qtype], minTTL)
}
logger.Debug("Cached %d records for %s %s (TTL: %ds, expires: %v)",
len(validRecords), domain, dns.TypeToString[qtype], minTTL, entry.expiresAt)
}
func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, qtype uint16) (*dns.Msg, error) {
logger.Debug("Querying upstream for %s %s", domain, dns.TypeToString[qtype])
// Create context with timeout
ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout)
defer cancel()
// Channel to receive result
type result struct {
msg *dns.Msg
err error
}
resultChan := make(chan result, 1)
// Query in goroutine to respect context timeout
go func() {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), qtype)
msg.Id = dns.Id()
msg.RecursionDesired = true
logger.Debug("Sending upstream query: %s %s (ID: %d)", domain, dns.TypeToString[qtype], msg.Id)
recvMsg, err := upstreamClient.Query(msg)
if err != nil {
logger.Debug("Upstream query error for %s %s: %v", domain, dns.TypeToString[qtype], err)
} else {
logger.Debug("Upstream query response for %s %s: %d answers, rcode: %s",
domain, dns.TypeToString[qtype], len(recvMsg.Answer), dns.RcodeToString[recvMsg.Rcode])
}
resultChan <- result{msg: recvMsg, err: err}
}()
select {
case res := <-resultChan:
return res.msg, res.err
case <-ctx.Done():
logger.Debug("Upstream query timeout for %s %s after %v", domain, dns.TypeToString[qtype], s.config.Timeout)
return nil, fmt.Errorf("upstream query timeout")
}
}
func (s *Server) Start() error {
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigChan
logger.Info("Received signal %v, shutting down DNS server...", sig)
s.Shutdown()
}()
logger.Info("DNS proxy server listening on %s", s.config.Address)
logger.Debug("Server starting with timeout: %v, DNSSEC: %v", s.config.Timeout, s.config.DNSSEC)
return s.dnsServer.ListenAndServe()
}
func (s *Server) Shutdown() {
logger.Debug("Shutting down server components")
if s.dnsServer != nil {
logger.Debug("Shutting down DNS server")
s.dnsServer.Shutdown()
}
if s.upstreamClient != nil {
logger.Debug("Closing upstream client")
s.upstreamClient.Close()
}
if s.fallbackClient != nil {
logger.Debug("Closing fallback client")
s.fallbackClient.Close()
}
if s.bootstrapClient != nil {
logger.Debug("Closing bootstrap client")
s.bootstrapClient.Close()
}
logger.Info("Server shutdown complete")
}