feat(keep-alive): add keep-alive option
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,3 +15,4 @@
|
|||||||
# vendor/
|
# vendor/
|
||||||
|
|
||||||
**/tls-key-log.txt
|
**/tls-key-log.txt
|
||||||
|
/results
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ type Options struct {
|
|||||||
DNSSEC bool
|
DNSSEC bool
|
||||||
ValidateOnly bool
|
ValidateOnly bool
|
||||||
StrictValidation bool
|
StrictValidation bool
|
||||||
|
KeepAlive bool // New flag for long-lived connections
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a DNS client based on the upstream string
|
// New creates a DNS client based on the upstream string
|
||||||
@@ -214,8 +215,8 @@ func getDefaultPath(scheme string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) {
|
func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) {
|
||||||
logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v",
|
logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v, KeepAlive=%v",
|
||||||
scheme, host, port, path, opts.DNSSEC)
|
scheme, host, port, path, opts.DNSSEC, opts.KeepAlive)
|
||||||
|
|
||||||
switch scheme {
|
switch scheme {
|
||||||
case "udp", "tcp", "do53", "":
|
case "udp", "tcp", "do53", "":
|
||||||
@@ -228,40 +229,44 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err
|
|||||||
|
|
||||||
case "http", "doh":
|
case "http", "doh":
|
||||||
config := doh.Config{
|
config := doh.Config{
|
||||||
Host: host,
|
Host: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
Path: path,
|
Path: path,
|
||||||
DNSSEC: opts.DNSSEC,
|
DNSSEC: opts.DNSSEC,
|
||||||
HTTP3: false,
|
HTTP3: false,
|
||||||
|
KeepAlive: opts.KeepAlive,
|
||||||
}
|
}
|
||||||
logger.Debug("Creating DoH client with config: %+v", config)
|
logger.Debug("Creating DoH client with config: %+v", config)
|
||||||
return doh.New(config)
|
return doh.New(config)
|
||||||
|
|
||||||
case "https", "doh3":
|
case "https", "doh3":
|
||||||
config := doh.Config{
|
config := doh.Config{
|
||||||
Host: host,
|
Host: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
Path: path,
|
Path: path,
|
||||||
DNSSEC: opts.DNSSEC,
|
DNSSEC: opts.DNSSEC,
|
||||||
HTTP3: true,
|
HTTP3: true,
|
||||||
|
KeepAlive: opts.KeepAlive,
|
||||||
}
|
}
|
||||||
logger.Debug("Creating DoH3 client with config: %+v", config)
|
logger.Debug("Creating DoH3 client with config: %+v", config)
|
||||||
return doh.New(config)
|
return doh.New(config)
|
||||||
|
|
||||||
case "tls", "dot":
|
case "tls", "dot":
|
||||||
config := dot.Config{
|
config := dot.Config{
|
||||||
Host: host,
|
Host: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
DNSSEC: opts.DNSSEC,
|
DNSSEC: opts.DNSSEC,
|
||||||
|
KeepAlive: opts.KeepAlive,
|
||||||
}
|
}
|
||||||
logger.Debug("Creating DoT client with config: %+v", config)
|
logger.Debug("Creating DoT client with config: %+v", config)
|
||||||
return dot.New(config)
|
return dot.New(config)
|
||||||
|
|
||||||
case "doq": // DNS over QUIC
|
case "doq": // DNS over QUIC
|
||||||
config := doq.Config{
|
config := doq.Config{
|
||||||
Host: host,
|
Host: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
DNSSEC: opts.DNSSEC,
|
DNSSEC: opts.DNSSEC,
|
||||||
|
KeepAlive: opts.KeepAlive,
|
||||||
}
|
}
|
||||||
logger.Debug("Creating DoQ client with config: %+v", config)
|
logger.Debug("Creating DoQ client with config: %+v", config)
|
||||||
return doq.New(config)
|
return doq.New(config)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ type RunCmd struct {
|
|||||||
QueryType string `short:"t" long:"type" default:"A" help:"DNS query type"`
|
QueryType string `short:"t" long:"type" default:"A" help:"DNS query type"`
|
||||||
Timeout time.Duration `long:"timeout" default:"5s" help:"Query timeout (informational)"`
|
Timeout time.Duration `long:"timeout" default:"5s" help:"Query timeout (informational)"`
|
||||||
DNSSEC bool `long:"dnssec" help:"Enable DNSSEC"`
|
DNSSEC bool `long:"dnssec" help:"Enable DNSSEC"`
|
||||||
|
KeepAlive bool `short:"k" long:"keep-alive" help:"Use persistent connections"`
|
||||||
Interface string `long:"iface" default:"any" help:"Capture interface (e.g., eth0, any)"`
|
Interface string `long:"iface" default:"any" help:"Capture interface (e.g., eth0, any)"`
|
||||||
Servers []string `short:"s" long:"server" help:"Upstream servers (udp://..., tls://..., https://..., doq://...)"`
|
Servers []string `short:"s" long:"server" help:"Upstream servers (udp://..., tls://..., https://..., doq://...)"`
|
||||||
}
|
}
|
||||||
@@ -27,6 +28,7 @@ func (r *RunCmd) Run() error {
|
|||||||
OutputDir: r.OutputDir,
|
OutputDir: r.OutputDir,
|
||||||
QueryType: r.QueryType,
|
QueryType: r.QueryType,
|
||||||
DNSSEC: r.DNSSEC,
|
DNSSEC: r.DNSSEC,
|
||||||
|
KeepAlive: r.KeepAlive,
|
||||||
Interface: r.Interface,
|
Interface: r.Interface,
|
||||||
Servers: r.Servers,
|
Servers: r.Servers,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ type QueryCmd struct {
|
|||||||
DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"`
|
DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"`
|
||||||
ValidateOnly bool `help:"Only return DNSSEC validated responses." short:"V"`
|
ValidateOnly bool `help:"Only return DNSSEC validated responses." short:"V"`
|
||||||
StrictValidation bool `help:"Fail on any DNSSEC validation error." short:"S"`
|
StrictValidation bool `help:"Fail on any DNSSEC validation error." short:"S"`
|
||||||
|
KeepAlive bool `help:"Use persistent connections." short:"k"`
|
||||||
Timeout time.Duration `help:"Timeout for the query operation." default:"10s"`
|
Timeout time.Duration `help:"Timeout for the query operation." default:"10s"`
|
||||||
KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"`
|
KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"`
|
||||||
}
|
}
|
||||||
@@ -36,18 +37,20 @@ type ListenCmd struct {
|
|||||||
Fallback string `help:"Fallback DNS server (e.g., https://1.1.1.1/dns-query, tls://8.8.8.8)." short:"f"`
|
Fallback string `help:"Fallback DNS server (e.g., https://1.1.1.1/dns-query, tls://8.8.8.8)." short:"f"`
|
||||||
Bootstrap string `help:"Bootstrap DNS server (must be an IP address, e.g., 8.8.8.8, 1.1.1.1)." short:"b"`
|
Bootstrap string `help:"Bootstrap DNS server (must be an IP address, e.g., 8.8.8.8, 1.1.1.1)." short:"b"`
|
||||||
DNSSEC bool `help:"Enable DNSSEC for upstream queries." short:"d"`
|
DNSSEC bool `help:"Enable DNSSEC for upstream queries." short:"d"`
|
||||||
|
KeepAlive bool `help:"Use persistent connections to upstream servers." short:"k"`
|
||||||
Timeout time.Duration `help:"Timeout for upstream queries." default:"5s"`
|
Timeout time.Duration `help:"Timeout for upstream queries." default:"5s"`
|
||||||
Verbose bool `help:"Enable verbose logging." short:"v"`
|
Verbose bool `help:"Enable verbose logging." short:"v"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QueryCmd) Run() error {
|
func (q *QueryCmd) Run() error {
|
||||||
logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, Timeout: %v)",
|
logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, KeepAlive: %v, Timeout: %v)",
|
||||||
q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.Timeout)
|
q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.KeepAlive, q.Timeout)
|
||||||
|
|
||||||
opts := client.Options{
|
opts := client.Options{
|
||||||
DNSSEC: q.DNSSEC,
|
DNSSEC: q.DNSSEC,
|
||||||
ValidateOnly: q.ValidateOnly,
|
ValidateOnly: q.ValidateOnly,
|
||||||
StrictValidation: q.StrictValidation,
|
StrictValidation: q.StrictValidation,
|
||||||
|
KeepAlive: q.KeepAlive,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Creating DNS client with options: %+v", opts)
|
logger.Debug("Creating DNS client with options: %+v", opts)
|
||||||
@@ -90,6 +93,7 @@ func (l *ListenCmd) Run() error {
|
|||||||
Fallback: l.Fallback,
|
Fallback: l.Fallback,
|
||||||
Bootstrap: l.Bootstrap,
|
Bootstrap: l.Bootstrap,
|
||||||
DNSSEC: l.DNSSEC,
|
DNSSEC: l.DNSSEC,
|
||||||
|
KeepAlive: l.KeepAlive,
|
||||||
Timeout: l.Timeout,
|
Timeout: l.Timeout,
|
||||||
Verbose: l.Verbose,
|
Verbose: l.Verbose,
|
||||||
}
|
}
|
||||||
@@ -105,10 +109,12 @@ func (l *ListenCmd) Run() error {
|
|||||||
logger.Info("Upstream server: %v", l.Upstream)
|
logger.Info("Upstream server: %v", l.Upstream)
|
||||||
logger.Info("Fallback server: %v", l.Fallback)
|
logger.Info("Fallback server: %v", l.Fallback)
|
||||||
logger.Info("Bootstrap server: %v", l.Bootstrap)
|
logger.Info("Bootstrap server: %v", l.Bootstrap)
|
||||||
|
logger.Info("KeepAlive: %v", l.KeepAlive)
|
||||||
|
|
||||||
return srv.Start()
|
return srv.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func printResponse(domain, qtype string, msg *dns.Msg) {
|
func printResponse(domain, qtype string, msg *dns.Msg) {
|
||||||
fmt.Println(";; QUESTION SECTION:")
|
fmt.Println(";; QUESTION SECTION:")
|
||||||
|
|
||||||
|
|||||||
@@ -22,12 +22,13 @@ import (
|
|||||||
const dnsMessageContentType = "application/dns-message"
|
const dnsMessageContentType = "application/dns-message"
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Host string
|
Host string
|
||||||
Port string
|
Port string
|
||||||
Path string
|
Path string
|
||||||
DNSSEC bool
|
DNSSEC bool
|
||||||
HTTP3 bool
|
HTTP3 bool
|
||||||
HTTP2 bool
|
HTTP2 bool
|
||||||
|
KeepAlive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -37,7 +38,7 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) (*Client, error) {
|
func New(config Config) (*Client, error) {
|
||||||
logger.Debug("Creating DoH client: %s:%s%s", config.Host, config.Port, config.Path)
|
logger.Debug("Creating DoH client: %s:%s%s (KeepAlive: %v)", config.Host, config.Port, config.Path, config.KeepAlive)
|
||||||
|
|
||||||
if config.Host == "" || config.Port == "" || config.Path == "" {
|
if config.Host == "" || config.Port == "" || config.Path == "" {
|
||||||
logger.Error("DoH client creation failed: missing required fields")
|
logger.Error("DoH client creation failed: missing required fields")
|
||||||
@@ -65,31 +66,58 @@ func New(config Config) (*Client, error) {
|
|||||||
DisablePathMTUDiscovery: true,
|
DisablePathMTUDiscovery: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
transport := http.DefaultTransport.(*http.Transport)
|
transport := &http.Transport{
|
||||||
transport.TLSClientConfig = tlsConfig
|
TLSClientConfig: tlsConfig,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure connection pooling based on KeepAlive setting
|
||||||
|
if config.KeepAlive {
|
||||||
|
transport.MaxIdleConnsPerHost = 10
|
||||||
|
transport.MaxConnsPerHost = 10
|
||||||
|
} else {
|
||||||
|
transport.MaxIdleConns = 0
|
||||||
|
transport.MaxIdleConnsPerHost = 0
|
||||||
|
transport.DisableKeepAlives = true
|
||||||
|
}
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
var transportType string
|
var transportType string
|
||||||
if config.HTTP2 {
|
if config.HTTP2 {
|
||||||
httpClient.Transport = &http2.Transport{
|
http2Transport := &http2.Transport{
|
||||||
TLSClientConfig: tlsConfig,
|
TLSClientConfig: tlsConfig,
|
||||||
AllowHTTP: true,
|
AllowHTTP: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !config.KeepAlive {
|
||||||
|
http2Transport.DisableCompression = true
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient.Transport = http2Transport
|
||||||
transportType = "HTTP/2"
|
transportType = "HTTP/2"
|
||||||
} else if config.HTTP3 {
|
} else if config.HTTP3 {
|
||||||
quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig)
|
quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig)
|
||||||
httpClient.Transport = &http3.Transport{
|
http3Transport := &http3.Transport{
|
||||||
TLSClientConfig: quicTlsConfig,
|
TLSClientConfig: quicTlsConfig,
|
||||||
QUICConfig: quicConfig,
|
QUICConfig: quicConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !config.KeepAlive {
|
||||||
|
http3Transport.DisableCompression = true
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient.Transport = http3Transport
|
||||||
transportType = "HTTP/3"
|
transportType = "HTTP/3"
|
||||||
} else {
|
} else {
|
||||||
transportType = "HTTP/1.1"
|
transportType = "HTTP/1.1"
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("DoH client created: %s (%s, DNSSEC: %v)", rawURL, transportType, config.DNSSEC)
|
logger.Debug("DoH client created: %s (%s, DNSSEC: %v, KeepAlive: %v)", rawURL, transportType, config.DNSSEC, config.KeepAlive)
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
httpClient: httpClient,
|
httpClient: httpClient,
|
||||||
@@ -102,6 +130,8 @@ func (c *Client) Close() {
|
|||||||
logger.Debug("Closing DoH client")
|
logger.Debug("Closing DoH client")
|
||||||
if t, ok := c.httpClient.Transport.(*http.Transport); ok {
|
if t, ok := c.httpClient.Transport.(*http.Transport); ok {
|
||||||
t.CloseIdleConnections()
|
t.CloseIdleConnections()
|
||||||
|
} else if t2, ok := c.httpClient.Transport.(*http2.Transport); ok {
|
||||||
|
t2.CloseIdleConnections()
|
||||||
} else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok {
|
} else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok {
|
||||||
t3.CloseIdleConnections()
|
t3.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
@@ -132,6 +162,13 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
httpReq.Header.Set("User-Agent", "sdns-proxy")
|
httpReq.Header.Set("User-Agent", "sdns-proxy")
|
||||||
httpReq.Header.Set("Content-Type", dnsMessageContentType)
|
httpReq.Header.Set("Content-Type", dnsMessageContentType)
|
||||||
httpReq.Header.Set("Accept", dnsMessageContentType)
|
httpReq.Header.Set("Accept", dnsMessageContentType)
|
||||||
|
|
||||||
|
// Set Connection header based on KeepAlive setting
|
||||||
|
if c.config.KeepAlive {
|
||||||
|
httpReq.Header.Set("Connection", "keep-alive")
|
||||||
|
} else {
|
||||||
|
httpReq.Header.Set("Connection", "close")
|
||||||
|
}
|
||||||
|
|
||||||
httpResp, err := c.httpClient.Do(httpReq)
|
httpResp, err := c.httpClient.Do(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
||||||
@@ -16,10 +17,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Host string
|
Host string
|
||||||
Port string
|
Port string
|
||||||
Debug bool
|
Debug bool
|
||||||
DNSSEC bool
|
DNSSEC bool
|
||||||
|
KeepAlive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -30,10 +32,11 @@ type Client struct {
|
|||||||
quicTransport *quic.Transport
|
quicTransport *quic.Transport
|
||||||
quicConfig *quic.Config
|
quicConfig *quic.Config
|
||||||
config Config
|
config Config
|
||||||
|
connMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) (*Client, error) {
|
func New(config Config) (*Client, error) {
|
||||||
logger.Debug("Creating DoQ client: %s:%s", config.Host, config.Port)
|
logger.Debug("Creating DoQ client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
ServerName: config.Host,
|
ServerName: config.Host,
|
||||||
@@ -62,9 +65,7 @@ func New(config Config) (*Client, error) {
|
|||||||
MaxIdleTimeout: 30 * time.Second,
|
MaxIdleTimeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v)", config.Host, config.Port, config.DNSSEC)
|
client := &Client{
|
||||||
|
|
||||||
return &Client{
|
|
||||||
targetAddr: targetAddr,
|
targetAddr: targetAddr,
|
||||||
tlsConfig: tlsConfig,
|
tlsConfig: tlsConfig,
|
||||||
udpConn: udpConn,
|
udpConn: udpConn,
|
||||||
@@ -72,18 +73,52 @@ func New(config Config) (*Client, error) {
|
|||||||
quicTransport: &quicTransport,
|
quicTransport: &quicTransport,
|
||||||
quicConfig: &quicConfig,
|
quicConfig: &quicConfig,
|
||||||
config: config,
|
config: config,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
// If keep-alive is enabled, establish connection now
|
||||||
|
if config.KeepAlive {
|
||||||
|
if err := client.ensureConnection(); err != nil {
|
||||||
|
logger.Error("DoQ failed to establish initial connection: %v", err)
|
||||||
|
client.Close()
|
||||||
|
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v, KeepAlive: %v)", config.Host, config.Port, config.DNSSEC, config.KeepAlive)
|
||||||
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() {
|
func (c *Client) Close() {
|
||||||
logger.Debug("Closing DoQ client")
|
logger.Debug("Closing DoQ client")
|
||||||
|
c.connMutex.Lock()
|
||||||
|
defer c.connMutex.Unlock()
|
||||||
|
|
||||||
|
if c.quicConn != nil {
|
||||||
|
c.quicConn.CloseWithError(0, "client shutdown")
|
||||||
|
c.quicConn = nil
|
||||||
|
}
|
||||||
if c.udpConn != nil {
|
if c.udpConn != nil {
|
||||||
c.udpConn.Close()
|
c.udpConn.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) OpenConnection() error {
|
func (c *Client) ensureConnection() error {
|
||||||
logger.Debug("Opening DoQ connection to %s", c.targetAddr)
|
c.connMutex.Lock()
|
||||||
|
defer c.connMutex.Unlock()
|
||||||
|
|
||||||
|
// Check if existing connection is still valid
|
||||||
|
if c.quicConn != nil {
|
||||||
|
select {
|
||||||
|
case <-c.quicConn.Context().Done():
|
||||||
|
logger.Debug("DoQ connection closed, reconnecting")
|
||||||
|
c.quicConn = nil
|
||||||
|
default:
|
||||||
|
// Connection is still valid
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Establishing DoQ connection to %s", c.targetAddr)
|
||||||
quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig)
|
quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err)
|
logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err)
|
||||||
@@ -101,10 +136,19 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr)
|
logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.quicConn == nil {
|
// Ensure we have a connection (either persistent or new)
|
||||||
err := c.OpenConnection()
|
if c.config.KeepAlive {
|
||||||
if err != nil {
|
if err := c.ensureConnection(); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("doq: failed to ensure connection: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For non-keepalive mode, create a fresh connection for each query
|
||||||
|
c.connMutex.Lock()
|
||||||
|
c.quicConn = nil // Force new connection
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
|
||||||
|
if err := c.ensureConnection(); err != nil {
|
||||||
|
return nil, fmt.Errorf("doq: failed to create connection: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,18 +163,33 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
return nil, fmt.Errorf("doq: failed to pack message: %w", err)
|
return nil, fmt.Errorf("doq: failed to pack message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var quicStream quic.Stream
|
// Open a stream for this query
|
||||||
quicStream, err = c.quicConn.OpenStream()
|
c.connMutex.Lock()
|
||||||
|
quicConn := c.quicConn
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
|
||||||
|
quicStream, err := quicConn.OpenStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug("DoQ stream failed, reconnecting: %v", err)
|
logger.Error("DoQ failed to open stream: %v", err)
|
||||||
err = c.OpenConnection()
|
|
||||||
if err != nil {
|
// If keep-alive is enabled, try to reconnect once
|
||||||
return nil, err
|
if c.config.KeepAlive {
|
||||||
}
|
logger.Debug("DoQ stream failed with keep-alive, attempting reconnect")
|
||||||
quicStream, err = c.quicConn.OpenStream()
|
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("doq: failed to reconnect: %w", reconnectErr)
|
||||||
logger.Error("DoQ failed to open stream after reconnect: %v", err)
|
}
|
||||||
return nil, err
|
|
||||||
|
c.connMutex.Lock()
|
||||||
|
quicConn = c.quicConn
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
|
||||||
|
quicStream, err = quicConn.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("DoQ failed to open stream after reconnect: %v", err)
|
||||||
|
return nil, fmt.Errorf("doq: failed to open stream after reconnect: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("doq: failed to open stream: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -184,5 +243,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer))
|
logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the connection if not using keep-alive
|
||||||
|
if !c.config.KeepAlive {
|
||||||
|
c.connMutex.Lock()
|
||||||
|
if c.quicConn != nil {
|
||||||
|
c.quicConn.CloseWithError(0, "query complete")
|
||||||
|
c.quicConn = nil
|
||||||
|
}
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
return recvMsg, nil
|
return recvMsg, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
||||||
@@ -16,6 +17,7 @@ type Config struct {
|
|||||||
Host string
|
Host string
|
||||||
Port string
|
Port string
|
||||||
DNSSEC bool
|
DNSSEC bool
|
||||||
|
KeepAlive bool
|
||||||
WriteTimeout time.Duration
|
WriteTimeout time.Duration
|
||||||
ReadTimeout time.Duration
|
ReadTimeout time.Duration
|
||||||
Debug bool
|
Debug bool
|
||||||
@@ -25,10 +27,12 @@ type Client struct {
|
|||||||
hostAndPort string
|
hostAndPort string
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
config Config
|
config Config
|
||||||
|
conn *tls.Conn
|
||||||
|
connMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config Config) (*Client, error) {
|
func New(config Config) (*Client, error) {
|
||||||
logger.Debug("Creating DoT client: %s:%s", config.Host, config.Port)
|
logger.Debug("Creating DoT client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
|
||||||
|
|
||||||
if config.Host == "" {
|
if config.Host == "" {
|
||||||
logger.Error("DoT client creation failed: empty host")
|
logger.Error("DoT client creation failed: empty host")
|
||||||
@@ -47,33 +51,76 @@ func New(config Config) (*Client, error) {
|
|||||||
ServerName: config.Host,
|
ServerName: config.Host,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("DoT client created: %s (DNSSEC: %v)", hostAndPort, config.DNSSEC)
|
client := &Client{
|
||||||
|
|
||||||
return &Client{
|
|
||||||
hostAndPort: hostAndPort,
|
hostAndPort: hostAndPort,
|
||||||
tlsConfig: tlsConfig,
|
tlsConfig: tlsConfig,
|
||||||
config: config,
|
config: config,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
// If keep-alive is enabled, establish connection now
|
||||||
|
if config.KeepAlive {
|
||||||
|
if err := client.ensureConnection(); err != nil {
|
||||||
|
logger.Error("DoT failed to establish initial connection: %v", err)
|
||||||
|
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("DoT client created: %s (DNSSEC: %v, KeepAlive: %v)", hostAndPort, config.DNSSEC, config.KeepAlive)
|
||||||
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() {
|
func (c *Client) Close() {
|
||||||
logger.Debug("Closing DoT client")
|
logger.Debug("Closing DoT client")
|
||||||
|
c.connMutex.Lock()
|
||||||
|
defer c.connMutex.Unlock()
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) createConnection() (*tls.Conn, error) {
|
func (c *Client) ensureConnection() error {
|
||||||
|
c.connMutex.Lock()
|
||||||
|
defer c.connMutex.Unlock()
|
||||||
|
|
||||||
|
// Check if existing connection is still valid
|
||||||
|
if c.conn != nil {
|
||||||
|
// Test the connection with a very short deadline
|
||||||
|
if err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err == nil {
|
||||||
|
// Try to read one byte to test connection
|
||||||
|
var testBuf [1]byte
|
||||||
|
_, err := c.conn.Read(testBuf[:])
|
||||||
|
|
||||||
|
// Reset deadline
|
||||||
|
c.conn.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
|
// If we get a timeout error, connection is still good
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Any other error means connection is dead
|
||||||
|
logger.Debug("DoT connection test failed, reconnecting: %v", err)
|
||||||
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
|
||||||
dialer := &net.Dialer{
|
dialer := &net.Dialer{
|
||||||
Timeout: c.config.WriteTimeout,
|
Timeout: c.config.WriteTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
|
|
||||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig)
|
conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err)
|
logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err)
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.conn = conn
|
||||||
logger.Debug("DoT connection established to %s", c.hostAndPort)
|
logger.Debug("DoT connection established to %s", c.hostAndPort)
|
||||||
return conn, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||||
@@ -82,12 +129,24 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
|
logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create connection for this query
|
// Ensure we have a connection (either persistent or new)
|
||||||
conn, err := c.createConnection()
|
if c.config.KeepAlive {
|
||||||
if err != nil {
|
if err := c.ensureConnection(); err != nil {
|
||||||
return nil, fmt.Errorf("dot: failed to create connection: %w", err)
|
return nil, fmt.Errorf("dot: failed to ensure connection: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For non-keepalive mode, create a fresh connection for each query
|
||||||
|
c.connMutex.Lock()
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
|
||||||
|
if err := c.ensureConnection(); err != nil {
|
||||||
|
return nil, fmt.Errorf("dot: failed to create connection: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Prepare DNS message
|
// Prepare DNS message
|
||||||
if c.config.DNSSEC {
|
if c.config.DNSSEC {
|
||||||
@@ -104,6 +163,10 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
binary.BigEndian.PutUint16(length, uint16(len(packed)))
|
binary.BigEndian.PutUint16(length, uint16(len(packed)))
|
||||||
data := append(length, packed...)
|
data := append(length, packed...)
|
||||||
|
|
||||||
|
c.connMutex.Lock()
|
||||||
|
conn := c.conn
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
|
||||||
// Write query
|
// Write query
|
||||||
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||||
logger.Error("DoT failed to set write deadline: %v", err)
|
logger.Error("DoT failed to set write deadline: %v", err)
|
||||||
@@ -112,7 +175,28 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
|
|
||||||
if _, err := conn.Write(data); err != nil {
|
if _, err := conn.Write(data); err != nil {
|
||||||
logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err)
|
logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err)
|
||||||
return nil, fmt.Errorf("dot: failed to write message: %w", err)
|
|
||||||
|
// If keep-alive is enabled and write failed, try to reconnect once
|
||||||
|
if c.config.KeepAlive {
|
||||||
|
logger.Debug("DoT write failed with keep-alive, attempting reconnect")
|
||||||
|
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
|
||||||
|
return nil, fmt.Errorf("dot: failed to reconnect: %w", reconnectErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connMutex.Lock()
|
||||||
|
conn = c.conn
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
|
||||||
|
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||||
|
return nil, fmt.Errorf("dot: failed to set write deadline after reconnect: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(data); err != nil {
|
||||||
|
return nil, fmt.Errorf("dot: failed to write message after reconnect: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("dot: failed to write message: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read response
|
// Read response
|
||||||
@@ -152,5 +236,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|||||||
logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer))
|
logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the connection if not using keep-alive
|
||||||
|
if !c.config.KeepAlive {
|
||||||
|
c.connMutex.Lock()
|
||||||
|
if c.conn != nil {
|
||||||
|
c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
c.connMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|||||||
27
flake.lock
generated
Normal file
27
flake.lock
generated
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"nixpkgs": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1758916627,
|
||||||
|
"narHash": "sha256-fB2ISCc+xn+9hZ6gOsABxSBcsCgLCjbJ5bC6U9bPzQ4=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "53614373268559d054c080d070cfc732dbe68ac4",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "NixOS",
|
||||||
|
"ref": "nixpkgs-unstable",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
||||||
@@ -26,6 +26,11 @@
|
|||||||
go
|
go
|
||||||
gotools
|
gotools
|
||||||
golangci-lint
|
golangci-lint
|
||||||
|
(pkgs.python3.withPackages (python-pkgs: with python-pkgs; [
|
||||||
|
pandas
|
||||||
|
matplotlib
|
||||||
|
seaborn
|
||||||
|
]))
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type MeasurementConfig struct {
|
|||||||
OutputDir string
|
OutputDir string
|
||||||
QueryType string
|
QueryType string
|
||||||
DNSSEC bool
|
DNSSEC bool
|
||||||
|
KeepAlive bool
|
||||||
Interface string
|
Interface string
|
||||||
Servers []string
|
Servers []string
|
||||||
}
|
}
|
||||||
@@ -74,9 +75,14 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT
|
|||||||
defer dnsClient.Close()
|
defer dnsClient.Close()
|
||||||
|
|
||||||
// Setup output files
|
// Setup output files
|
||||||
jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC)
|
jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC, r.config.KeepAlive)
|
||||||
|
|
||||||
fmt.Printf(">>> Measuring %s (dnssec=%v) → %s\n", upstream, r.config.DNSSEC,
|
keepAliveStr := ""
|
||||||
|
if r.config.KeepAlive {
|
||||||
|
keepAliveStr = " (keep-alive)"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(">>> Measuring %s (dnssec=%v%s) → %s\n", upstream, r.config.DNSSEC, keepAliveStr,
|
||||||
strings.TrimSuffix(strings.TrimSuffix(jsonPath, ".jsonl"), r.config.OutputDir+"/"))
|
strings.TrimSuffix(strings.TrimSuffix(jsonPath, ".jsonl"), r.config.OutputDir+"/"))
|
||||||
|
|
||||||
// Setup packet capture
|
// Setup packet capture
|
||||||
@@ -98,7 +104,10 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) {
|
func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) {
|
||||||
opts := client.Options{DNSSEC: r.config.DNSSEC}
|
opts := client.Options{
|
||||||
|
DNSSEC: r.config.DNSSEC,
|
||||||
|
KeepAlive: r.config.KeepAlive,
|
||||||
|
}
|
||||||
return client.New(upstream, opts)
|
return client.New(upstream, opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +145,14 @@ func (r *MeasurementRunner) runQueries(dnsClient client.DNSClient, upstream stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.printQueryResult(metric)
|
r.printQueryResult(metric)
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
|
// For keep-alive connections, add smaller delays between queries
|
||||||
|
// to better utilize the persistent connection
|
||||||
|
if r.config.KeepAlive {
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
} else {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -154,6 +170,7 @@ func (r *MeasurementRunner) performQuery(dnsClient client.DNSClient, domain, ups
|
|||||||
QueryType: r.config.QueryType,
|
QueryType: r.config.QueryType,
|
||||||
Protocol: proto,
|
Protocol: proto,
|
||||||
DNSSEC: r.config.DNSSEC,
|
DNSSEC: r.config.DNSSEC,
|
||||||
|
KeepAlive: r.config.KeepAlive,
|
||||||
DNSServer: upstream,
|
DNSServer: upstream,
|
||||||
Timestamp: time.Now(),
|
Timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
@@ -199,8 +216,14 @@ func (r *MeasurementRunner) printQueryResult(metric results.DNSMetric) {
|
|||||||
if metric.ResponseCode == "ERROR" {
|
if metric.ResponseCode == "ERROR" {
|
||||||
statusIcon = "✗"
|
statusIcon = "✗"
|
||||||
}
|
}
|
||||||
fmt.Printf("%s %s [%s] %s %.2fms\n",
|
|
||||||
statusIcon, metric.Domain, metric.Protocol, metric.ResponseCode, metric.DurationMs)
|
keepAliveIndicator := ""
|
||||||
|
if metric.KeepAlive {
|
||||||
|
keepAliveIndicator = "⟷"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%s %s%s [%s] %s %.2fms\n",
|
||||||
|
statusIcon, metric.Domain, keepAliveIndicator, metric.Protocol, metric.ResponseCode, metric.DurationMs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MeasurementRunner) readDomainsFile() ([]string, error) {
|
func (r *MeasurementRunner) readDomainsFile() ([]string, error) {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type DNSMetric struct {
|
|||||||
QueryType string `json:"query_type"`
|
QueryType string `json:"query_type"`
|
||||||
Protocol string `json:"protocol"`
|
Protocol string `json:"protocol"`
|
||||||
DNSSEC bool `json:"dnssec"`
|
DNSSEC bool `json:"dnssec"`
|
||||||
|
KeepAlive bool `json:"keep_alive"`
|
||||||
DNSServer string `json:"dns_server"`
|
DNSServer string `json:"dns_server"`
|
||||||
Timestamp time.Time `json:"timestamp"`
|
Timestamp time.Time `json:"timestamp"`
|
||||||
Duration int64 `json:"duration_ns"`
|
Duration int64 `json:"duration_ns"`
|
||||||
@@ -22,6 +23,7 @@ type DNSMetric struct {
|
|||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rest stays exactly the same
|
||||||
type MetricsWriter struct {
|
type MetricsWriter struct {
|
||||||
encoder *json.Encoder
|
encoder *json.Encoder
|
||||||
file *os.File
|
file *os.File
|
||||||
|
|||||||
@@ -8,17 +8,18 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GenerateOutputPaths(outputDir, upstream string, dnssec bool) (jsonPath, pcapPath string) {
|
func GenerateOutputPaths(outputDir, upstream string, dnssec, keepAlive bool) (jsonPath, pcapPath string) {
|
||||||
proto := DetectProtocol(upstream)
|
proto := DetectProtocol(upstream)
|
||||||
serverName := ExtractServerName(upstream)
|
serverName := ExtractServerName(upstream)
|
||||||
ts := time.Now().Format("20060102_1504")
|
ts := time.Now().Format("20060102_1504")
|
||||||
dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec]
|
dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec]
|
||||||
|
keepAliveStr := map[bool]string{true: "on", false: "off"}[keepAlive]
|
||||||
|
|
||||||
base := fmt.Sprintf("%s_%s_dnssec_%s_%s",
|
base := fmt.Sprintf("%s_%s_dnssec_%s_keepalive_%s_%s",
|
||||||
proto, sanitize(serverName), dnssecStr, ts)
|
proto, sanitize(serverName), dnssecStr, keepAliveStr, ts)
|
||||||
|
|
||||||
return filepath.Join(outputDir, base+".jsonl"),
|
return filepath.Join(outputDir, base+".jsonl"),
|
||||||
filepath.Join(outputDir, base+".pcap")
|
filepath.Join(outputDir, base+".pcap")
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitize(s string) string {
|
func sanitize(s string) string {
|
||||||
|
|||||||
544
server/server.go
544
server/server.go
@@ -1,153 +1,15 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/afonsofrancof/sdns-proxy/client"
|
|
||||||
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Address string
|
Address string
|
||||||
Upstream string
|
Upstream string
|
||||||
Fallback string
|
Fallback string
|
||||||
Bootstrap string
|
Bootstrap string
|
||||||
DNSSEC bool
|
DNSSEC bool
|
||||||
|
KeepAlive bool
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
Verbose bool
|
Verbose bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type cacheKey struct {
|
// Update the initClients method:
|
||||||
domain string
|
|
||||||
qtype uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type cacheEntry struct {
|
|
||||||
records []dns.RR
|
|
||||||
expiresAt time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
config Config
|
|
||||||
upstreamClient client.DNSClient
|
|
||||||
fallbackClient client.DNSClient
|
|
||||||
bootstrapClient client.DNSClient
|
|
||||||
resolvedHosts map[string]string
|
|
||||||
queryCache map[cacheKey]*cacheEntry
|
|
||||||
hostsMutex sync.RWMutex
|
|
||||||
cacheMutex sync.RWMutex
|
|
||||||
dnsServer *dns.Server
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(config Config) (*Server, error) {
|
|
||||||
logger.Debug("Creating new server with config: %+v", config)
|
|
||||||
|
|
||||||
if config.Upstream == "" {
|
|
||||||
logger.Error("Upstream server is required")
|
|
||||||
return nil, fmt.Errorf("upstream server is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we need bootstrap server
|
|
||||||
needsBootstrap := containsHostname(config.Upstream)
|
|
||||||
if config.Fallback != "" {
|
|
||||||
needsBootstrap = needsBootstrap || containsHostname(config.Fallback)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Bootstrap needed: %v (upstream has hostname: %v, fallback has hostname: %v)",
|
|
||||||
needsBootstrap, containsHostname(config.Upstream),
|
|
||||||
config.Fallback != "" && containsHostname(config.Fallback))
|
|
||||||
|
|
||||||
if needsBootstrap && config.Bootstrap == "" {
|
|
||||||
logger.Error("Bootstrap server is required when upstream or fallback contains hostnames")
|
|
||||||
return nil, fmt.Errorf("bootstrap server is required when upstream or fallback contains hostnames")
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.Bootstrap != "" && containsHostname(config.Bootstrap) {
|
|
||||||
logger.Error("Bootstrap server cannot contain hostnames: %s", config.Bootstrap)
|
|
||||||
return nil, fmt.Errorf("bootstrap server cannot contain hostnames: %s", config.Bootstrap)
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Server{
|
|
||||||
config: config,
|
|
||||||
resolvedHosts: make(map[string]string),
|
|
||||||
queryCache: make(map[cacheKey]*cacheEntry),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create bootstrap client if needed
|
|
||||||
if config.Bootstrap != "" {
|
|
||||||
logger.Debug("Creating bootstrap client for %s", config.Bootstrap)
|
|
||||||
bootstrapClient, err := client.New(config.Bootstrap, client.Options{
|
|
||||||
DNSSEC: false,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to create bootstrap client: %v", err)
|
|
||||||
return nil, fmt.Errorf("failed to create bootstrap client: %w", err)
|
|
||||||
}
|
|
||||||
s.bootstrapClient = bootstrapClient
|
|
||||||
logger.Debug("Bootstrap client created successfully")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize upstream and fallback clients
|
|
||||||
if err := s.initClients(); err != nil {
|
|
||||||
logger.Error("Failed to initialize clients: %v", err)
|
|
||||||
return nil, fmt.Errorf("failed to initialize clients: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup DNS server
|
|
||||||
mux := dns.NewServeMux()
|
|
||||||
mux.HandleFunc(".", s.handleDNSRequest)
|
|
||||||
|
|
||||||
s.dnsServer = &dns.Server{
|
|
||||||
Addr: config.Address,
|
|
||||||
Net: "udp",
|
|
||||||
Handler: mux,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Server created successfully, listening on %s", config.Address)
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsHostname(serverAddr string) bool {
|
|
||||||
logger.Debug("Checking if %s contains hostname", serverAddr)
|
|
||||||
|
|
||||||
// Use the same parsing logic as the client package
|
|
||||||
parsedURL, err := url.Parse(serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug("URL parsing failed for %s, treating as plain address", serverAddr)
|
|
||||||
// If URL parsing fails, assume it's a plain address
|
|
||||||
host, _, err := net.SplitHostPort(serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
// Assume it's just a host
|
|
||||||
isHostname := net.ParseIP(serverAddr) == nil
|
|
||||||
logger.Debug("Address %s is hostname: %v", serverAddr, isHostname)
|
|
||||||
return isHostname
|
|
||||||
}
|
|
||||||
isHostname := net.ParseIP(host) == nil
|
|
||||||
logger.Debug("Host %s from %s is hostname: %v", host, serverAddr, isHostname)
|
|
||||||
return isHostname
|
|
||||||
}
|
|
||||||
|
|
||||||
host := parsedURL.Hostname()
|
|
||||||
if host == "" {
|
|
||||||
logger.Debug("No hostname found in URL %s", serverAddr)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
isHostname := net.ParseIP(host) == nil
|
|
||||||
logger.Debug("Host %s from URL %s is hostname: %v", host, serverAddr, isHostname)
|
|
||||||
return isHostname
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) initClients() error {
|
func (s *Server) initClients() error {
|
||||||
logger.Debug("Initializing DNS clients")
|
logger.Debug("Initializing DNS clients")
|
||||||
|
|
||||||
@@ -160,7 +22,8 @@ func (s *Server) initClients() error {
|
|||||||
|
|
||||||
logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream)
|
logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream)
|
||||||
upstreamClient, err := client.New(resolvedUpstream, client.Options{
|
upstreamClient, err := client.New(resolvedUpstream, client.Options{
|
||||||
DNSSEC: s.config.DNSSEC,
|
DNSSEC: s.config.DNSSEC,
|
||||||
|
KeepAlive: s.config.KeepAlive,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to create upstream client: %v", err)
|
logger.Error("Failed to create upstream client: %v", err)
|
||||||
@@ -169,7 +32,7 @@ func (s *Server) initClients() error {
|
|||||||
s.upstreamClient = upstreamClient
|
s.upstreamClient = upstreamClient
|
||||||
|
|
||||||
if s.config.Verbose {
|
if s.config.Verbose {
|
||||||
logger.Info("Initialized upstream client: %s -> %s", s.config.Upstream, resolvedUpstream)
|
logger.Info("Initialized upstream client: %s -> %s (KeepAlive: %v)", s.config.Upstream, resolvedUpstream, s.config.KeepAlive)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize fallback client if specified
|
// Initialize fallback client if specified
|
||||||
@@ -182,7 +45,8 @@ func (s *Server) initClients() error {
|
|||||||
|
|
||||||
logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback)
|
logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback)
|
||||||
fallbackClient, err := client.New(resolvedFallback, client.Options{
|
fallbackClient, err := client.New(resolvedFallback, client.Options{
|
||||||
DNSSEC: s.config.DNSSEC,
|
DNSSEC: s.config.DNSSEC,
|
||||||
|
KeepAlive: s.config.KeepAlive,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to create fallback client: %v", err)
|
logger.Error("Failed to create fallback client: %v", err)
|
||||||
@@ -191,402 +55,10 @@ func (s *Server) initClients() error {
|
|||||||
s.fallbackClient = fallbackClient
|
s.fallbackClient = fallbackClient
|
||||||
|
|
||||||
if s.config.Verbose {
|
if s.config.Verbose {
|
||||||
logger.Info("Initialized fallback client: %s -> %s", s.config.Fallback, resolvedFallback)
|
logger.Info("Initialized fallback client: %s -> %s (KeepAlive: %v)", s.config.Fallback, resolvedFallback, s.config.KeepAlive)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("All DNS clients initialized successfully")
|
logger.Debug("All DNS clients initialized successfully")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) resolveServerAddress(serverAddr string) (string, error) {
|
|
||||||
logger.Debug("Resolving server address: %s", serverAddr)
|
|
||||||
|
|
||||||
// If it doesn't contain hostnames, return as-is
|
|
||||||
if !containsHostname(serverAddr) {
|
|
||||||
logger.Debug("Address %s contains no hostnames, returning as-is", serverAddr)
|
|
||||||
return serverAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no bootstrap client, we can't resolve hostnames
|
|
||||||
if s.bootstrapClient == nil {
|
|
||||||
logger.Error("Cannot resolve hostname in %s: no bootstrap server configured", serverAddr)
|
|
||||||
return "", fmt.Errorf("cannot resolve hostname in %s: no bootstrap server configured", serverAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the same parsing logic as the client package
|
|
||||||
parsedURL, err := url.Parse(serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug("Parsing %s as plain host:port format", serverAddr)
|
|
||||||
// Handle plain host:port format
|
|
||||||
host, port, err := net.SplitHostPort(serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
// Assume it's just a hostname
|
|
||||||
resolvedIP, err := s.resolveHostname(serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
logger.Debug("Resolved %s to %s", serverAddr, resolvedIP)
|
|
||||||
return resolvedIP, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvedIP, err := s.resolveHostname(host)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
resolved := net.JoinHostPort(resolvedIP, port)
|
|
||||||
logger.Debug("Resolved %s to %s", serverAddr, resolved)
|
|
||||||
return resolved, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle URL format
|
|
||||||
hostname := parsedURL.Hostname()
|
|
||||||
if hostname == "" {
|
|
||||||
logger.Error("No hostname in URL: %s", serverAddr)
|
|
||||||
return "", fmt.Errorf("no hostname in URL: %s", serverAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
resolvedIP, err := s.resolveHostname(hostname)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Replace hostname with IP in the URL
|
|
||||||
port := parsedURL.Port()
|
|
||||||
if port == "" {
|
|
||||||
parsedURL.Host = resolvedIP
|
|
||||||
} else {
|
|
||||||
parsedURL.Host = net.JoinHostPort(resolvedIP, port)
|
|
||||||
}
|
|
||||||
|
|
||||||
resolved := parsedURL.String()
|
|
||||||
logger.Debug("Resolved URL %s to %s", serverAddr, resolved)
|
|
||||||
return resolved, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) resolveHostname(hostname string) (string, error) {
|
|
||||||
logger.Debug("Resolving hostname: %s", hostname)
|
|
||||||
|
|
||||||
// Check cache first
|
|
||||||
s.hostsMutex.RLock()
|
|
||||||
if ip, exists := s.resolvedHosts[hostname]; exists {
|
|
||||||
s.hostsMutex.RUnlock()
|
|
||||||
logger.Debug("Found cached resolution for %s: %s", hostname, ip)
|
|
||||||
return ip, nil
|
|
||||||
}
|
|
||||||
s.hostsMutex.RUnlock()
|
|
||||||
|
|
||||||
// Resolve using bootstrap
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Resolving hostname %s using bootstrap server", hostname)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.SetQuestion(dns.Fqdn(hostname), dns.TypeA)
|
|
||||||
msg.Id = dns.Id()
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
|
|
||||||
logger.Debug("Sending bootstrap query for %s (ID: %d)", hostname, msg.Id)
|
|
||||||
msg, err := s.bootstrapClient.Query(msg)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Bootstrap query failed for %s: %v", hostname, err)
|
|
||||||
return "", fmt.Errorf("failed to resolve %s via bootstrap: %w", hostname, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Bootstrap response for %s: %d answers", hostname, len(msg.Answer))
|
|
||||||
if len(msg.Answer) == 0 {
|
|
||||||
logger.Error("No A records found for %s", hostname)
|
|
||||||
return "", fmt.Errorf("no A records found for %s", hostname)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find first A record
|
|
||||||
for _, rr := range msg.Answer {
|
|
||||||
if a, ok := rr.(*dns.A); ok {
|
|
||||||
ip := a.A.String()
|
|
||||||
|
|
||||||
// Cache the result
|
|
||||||
s.hostsMutex.Lock()
|
|
||||||
s.resolvedHosts[hostname] = ip
|
|
||||||
s.hostsMutex.Unlock()
|
|
||||||
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Resolved %s to %s", hostname, ip)
|
|
||||||
}
|
|
||||||
logger.Debug("Cached resolution: %s -> %s", hostname, ip)
|
|
||||||
|
|
||||||
return ip, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Error("No valid A record found for %s", hostname)
|
|
||||||
return "", fmt.Errorf("no valid A record found for %s", hostname)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
if len(r.Question) == 0 {
|
|
||||||
logger.Debug("Received request with no questions from %s", w.RemoteAddr())
|
|
||||||
dns.HandleFailed(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
question := r.Question[0]
|
|
||||||
domain := strings.ToLower(question.Name)
|
|
||||||
qtype := question.Qtype
|
|
||||||
|
|
||||||
logger.Debug("Handling DNS request: %s %s from %s (ID: %d)",
|
|
||||||
question.Name, dns.TypeToString[qtype], w.RemoteAddr(), r.Id)
|
|
||||||
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Query: %s %s from %s",
|
|
||||||
question.Name,
|
|
||||||
dns.TypeToString[qtype],
|
|
||||||
w.RemoteAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check cache first
|
|
||||||
if cachedRecords := s.getCachedRecords(domain, qtype); cachedRecords != nil {
|
|
||||||
response := s.buildResponse(r, cachedRecords)
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Cache hit: %s %s -> %d records",
|
|
||||||
question.Name,
|
|
||||||
dns.TypeToString[qtype],
|
|
||||||
len(cachedRecords))
|
|
||||||
}
|
|
||||||
logger.Debug("Serving cached response for %s %s (%d records)",
|
|
||||||
question.Name, dns.TypeToString[qtype], len(cachedRecords))
|
|
||||||
w.WriteMsg(response)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Cache miss for %s %s, querying upstream", question.Name, dns.TypeToString[qtype])
|
|
||||||
|
|
||||||
// Try upstream first
|
|
||||||
response, err := s.queryUpstream(s.upstreamClient, question.Name, qtype)
|
|
||||||
if err != nil {
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Upstream query failed: %v", err)
|
|
||||||
}
|
|
||||||
logger.Debug("Upstream query failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err)
|
|
||||||
|
|
||||||
// Try fallback if available
|
|
||||||
if s.fallbackClient != nil {
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Trying fallback server")
|
|
||||||
}
|
|
||||||
logger.Debug("Attempting fallback query for %s %s", question.Name, dns.TypeToString[qtype])
|
|
||||||
|
|
||||||
response, err = s.queryUpstream(s.fallbackClient, question.Name, qtype)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Both upstream and fallback failed for %s %s: %v",
|
|
||||||
question.Name,
|
|
||||||
dns.TypeToString[qtype],
|
|
||||||
err)
|
|
||||||
} else {
|
|
||||||
logger.Debug("Fallback query succeeded for %s %s", question.Name, dns.TypeToString[qtype])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If still failed, return SERVFAIL
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("All servers failed for %s %s: %v",
|
|
||||||
question.Name,
|
|
||||||
dns.TypeToString[qtype],
|
|
||||||
err)
|
|
||||||
|
|
||||||
m := new(dns.Msg)
|
|
||||||
m.SetReply(r)
|
|
||||||
m.Rcode = dns.RcodeServerFailure
|
|
||||||
w.WriteMsg(m)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Debug("Upstream query succeeded for %s %s", question.Name, dns.TypeToString[qtype])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cache successful response
|
|
||||||
s.cacheResponse(domain, qtype, response)
|
|
||||||
|
|
||||||
// Copy request ID to response
|
|
||||||
response.Id = r.Id
|
|
||||||
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Response: %s %s -> %d answers",
|
|
||||||
question.Name,
|
|
||||||
dns.TypeToString[qtype],
|
|
||||||
len(response.Answer))
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Sending response for %s %s: %d answers, rcode: %s",
|
|
||||||
question.Name, dns.TypeToString[qtype], len(response.Answer), dns.RcodeToString[response.Rcode])
|
|
||||||
w.WriteMsg(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR {
|
|
||||||
key := cacheKey{domain: domain, qtype: qtype}
|
|
||||||
|
|
||||||
s.cacheMutex.RLock()
|
|
||||||
entry, exists := s.queryCache[key]
|
|
||||||
s.cacheMutex.RUnlock()
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
logger.Debug("No cache entry for %s %s", domain, dns.TypeToString[qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if expired and clean up on the spot
|
|
||||||
if time.Now().After(entry.expiresAt) {
|
|
||||||
logger.Debug("Cache entry expired for %s %s", domain, dns.TypeToString[qtype])
|
|
||||||
s.cacheMutex.Lock()
|
|
||||||
delete(s.queryCache, key)
|
|
||||||
s.cacheMutex.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("Cache hit for %s %s (%d records, expires in %v)",
|
|
||||||
domain, dns.TypeToString[qtype], len(entry.records), time.Until(entry.expiresAt))
|
|
||||||
|
|
||||||
// Return a copy of the cached records
|
|
||||||
records := make([]dns.RR, len(entry.records))
|
|
||||||
for i, rr := range entry.records {
|
|
||||||
records[i] = dns.Copy(rr)
|
|
||||||
}
|
|
||||||
return records
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) buildResponse(request *dns.Msg, records []dns.RR) *dns.Msg {
|
|
||||||
response := new(dns.Msg)
|
|
||||||
response.SetReply(request)
|
|
||||||
response.Answer = records
|
|
||||||
logger.Debug("Built response with %d records", len(records))
|
|
||||||
return response
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) {
|
|
||||||
if msg == nil || len(msg.Answer) == 0 {
|
|
||||||
logger.Debug("Not caching empty response for %s %s", domain, dns.TypeToString[qtype])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var validRecords []dns.RR
|
|
||||||
minTTL := uint32(3600)
|
|
||||||
|
|
||||||
// Find minimum TTL from answer records
|
|
||||||
for _, rr := range msg.Answer {
|
|
||||||
// Only cache records that match our query type or are CNAMEs
|
|
||||||
if rr.Header().Rrtype == qtype || rr.Header().Rrtype == dns.TypeCNAME {
|
|
||||||
validRecords = append(validRecords, dns.Copy(rr))
|
|
||||||
if rr.Header().Ttl < minTTL {
|
|
||||||
minTTL = rr.Header().Ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(validRecords) == 0 {
|
|
||||||
logger.Debug("No valid records to cache for %s %s", domain, dns.TypeToString[qtype])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Don't cache responses with very low TTL
|
|
||||||
if minTTL < 10 {
|
|
||||||
logger.Debug("TTL too low (%ds) for caching %s %s", minTTL, domain, dns.TypeToString[qtype])
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
key := cacheKey{domain: domain, qtype: qtype}
|
|
||||||
entry := &cacheEntry{
|
|
||||||
records: validRecords,
|
|
||||||
expiresAt: time.Now().Add(time.Duration(minTTL) * time.Second),
|
|
||||||
}
|
|
||||||
|
|
||||||
s.cacheMutex.Lock()
|
|
||||||
s.queryCache[key] = entry
|
|
||||||
s.cacheMutex.Unlock()
|
|
||||||
|
|
||||||
if s.config.Verbose {
|
|
||||||
logger.Info("Cached %d records for %s %s (TTL: %ds)",
|
|
||||||
len(validRecords), domain, dns.TypeToString[qtype], minTTL)
|
|
||||||
}
|
|
||||||
logger.Debug("Cached %d records for %s %s (TTL: %ds, expires: %v)",
|
|
||||||
len(validRecords), domain, dns.TypeToString[qtype], minTTL, entry.expiresAt)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, qtype uint16) (*dns.Msg, error) {
|
|
||||||
logger.Debug("Querying upstream for %s %s", domain, dns.TypeToString[qtype])
|
|
||||||
|
|
||||||
// Create context with timeout
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Channel to receive result
|
|
||||||
type result struct {
|
|
||||||
msg *dns.Msg
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
resultChan := make(chan result, 1)
|
|
||||||
|
|
||||||
// Query in goroutine to respect context timeout
|
|
||||||
go func() {
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.SetQuestion(dns.Fqdn(domain), qtype)
|
|
||||||
msg.Id = dns.Id()
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
|
|
||||||
logger.Debug("Sending upstream query: %s %s (ID: %d)", domain, dns.TypeToString[qtype], msg.Id)
|
|
||||||
recvMsg, err := upstreamClient.Query(msg)
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug("Upstream query error for %s %s: %v", domain, dns.TypeToString[qtype], err)
|
|
||||||
} else {
|
|
||||||
logger.Debug("Upstream query response for %s %s: %d answers, rcode: %s",
|
|
||||||
domain, dns.TypeToString[qtype], len(recvMsg.Answer), dns.RcodeToString[recvMsg.Rcode])
|
|
||||||
}
|
|
||||||
resultChan <- result{msg: recvMsg, err: err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case res := <-resultChan:
|
|
||||||
return res.msg, res.err
|
|
||||||
case <-ctx.Done():
|
|
||||||
logger.Debug("Upstream query timeout for %s %s after %v", domain, dns.TypeToString[qtype], s.config.Timeout)
|
|
||||||
return nil, fmt.Errorf("upstream query timeout")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Start() error {
|
|
||||||
go func() {
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
sig := <-sigChan
|
|
||||||
logger.Info("Received signal %v, shutting down DNS server...", sig)
|
|
||||||
s.Shutdown()
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Info("DNS proxy server listening on %s", s.config.Address)
|
|
||||||
logger.Debug("Server starting with timeout: %v, DNSSEC: %v", s.config.Timeout, s.config.DNSSEC)
|
|
||||||
return s.dnsServer.ListenAndServe()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Shutdown() {
|
|
||||||
logger.Debug("Shutting down server components")
|
|
||||||
|
|
||||||
if s.dnsServer != nil {
|
|
||||||
logger.Debug("Shutting down DNS server")
|
|
||||||
s.dnsServer.Shutdown()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.upstreamClient != nil {
|
|
||||||
logger.Debug("Closing upstream client")
|
|
||||||
s.upstreamClient.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.fallbackClient != nil {
|
|
||||||
logger.Debug("Closing fallback client")
|
|
||||||
s.fallbackClient.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.bootstrapClient != nil {
|
|
||||||
logger.Debug("Closing bootstrap client")
|
|
||||||
s.bootstrapClient.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Server shutdown complete")
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user