feat(keep-alive): add keep-alive option
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
||||
@@ -16,10 +17,11 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
Debug bool
|
||||
DNSSEC bool
|
||||
Host string
|
||||
Port string
|
||||
Debug bool
|
||||
DNSSEC bool
|
||||
KeepAlive bool
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
@@ -30,10 +32,11 @@ type Client struct {
|
||||
quicTransport *quic.Transport
|
||||
quicConfig *quic.Config
|
||||
config Config
|
||||
connMutex sync.Mutex
|
||||
}
|
||||
|
||||
func New(config Config) (*Client, error) {
|
||||
logger.Debug("Creating DoQ client: %s:%s", config.Host, config.Port)
|
||||
logger.Debug("Creating DoQ client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: config.Host,
|
||||
@@ -62,9 +65,7 @@ func New(config Config) (*Client, error) {
|
||||
MaxIdleTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v)", config.Host, config.Port, config.DNSSEC)
|
||||
|
||||
return &Client{
|
||||
client := &Client{
|
||||
targetAddr: targetAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
udpConn: udpConn,
|
||||
@@ -72,18 +73,52 @@ func New(config Config) (*Client, error) {
|
||||
quicTransport: &quicTransport,
|
||||
quicConfig: &quicConfig,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// If keep-alive is enabled, establish connection now
|
||||
if config.KeepAlive {
|
||||
if err := client.ensureConnection(); err != nil {
|
||||
logger.Error("DoQ failed to establish initial connection: %v", err)
|
||||
client.Close()
|
||||
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v, KeepAlive: %v)", config.Host, config.Port, config.DNSSEC, config.KeepAlive)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
logger.Debug("Closing DoQ client")
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
if c.quicConn != nil {
|
||||
c.quicConn.CloseWithError(0, "client shutdown")
|
||||
c.quicConn = nil
|
||||
}
|
||||
if c.udpConn != nil {
|
||||
c.udpConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) OpenConnection() error {
|
||||
logger.Debug("Opening DoQ connection to %s", c.targetAddr)
|
||||
func (c *Client) ensureConnection() error {
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
// Check if existing connection is still valid
|
||||
if c.quicConn != nil {
|
||||
select {
|
||||
case <-c.quicConn.Context().Done():
|
||||
logger.Debug("DoQ connection closed, reconnecting")
|
||||
c.quicConn = nil
|
||||
default:
|
||||
// Connection is still valid
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Establishing DoQ connection to %s", c.targetAddr)
|
||||
quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig)
|
||||
if err != nil {
|
||||
logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err)
|
||||
@@ -101,10 +136,19 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr)
|
||||
}
|
||||
|
||||
if c.quicConn == nil {
|
||||
err := c.OpenConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Ensure we have a connection (either persistent or new)
|
||||
if c.config.KeepAlive {
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("doq: failed to ensure connection: %w", err)
|
||||
}
|
||||
} else {
|
||||
// For non-keepalive mode, create a fresh connection for each query
|
||||
c.connMutex.Lock()
|
||||
c.quicConn = nil // Force new connection
|
||||
c.connMutex.Unlock()
|
||||
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("doq: failed to create connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,18 +163,33 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
return nil, fmt.Errorf("doq: failed to pack message: %w", err)
|
||||
}
|
||||
|
||||
var quicStream quic.Stream
|
||||
quicStream, err = c.quicConn.OpenStream()
|
||||
// Open a stream for this query
|
||||
c.connMutex.Lock()
|
||||
quicConn := c.quicConn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
quicStream, err := quicConn.OpenStream()
|
||||
if err != nil {
|
||||
logger.Debug("DoQ stream failed, reconnecting: %v", err)
|
||||
err = c.OpenConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quicStream, err = c.quicConn.OpenStream()
|
||||
if err != nil {
|
||||
logger.Error("DoQ failed to open stream after reconnect: %v", err)
|
||||
return nil, err
|
||||
logger.Error("DoQ failed to open stream: %v", err)
|
||||
|
||||
// If keep-alive is enabled, try to reconnect once
|
||||
if c.config.KeepAlive {
|
||||
logger.Debug("DoQ stream failed with keep-alive, attempting reconnect")
|
||||
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
|
||||
return nil, fmt.Errorf("doq: failed to reconnect: %w", reconnectErr)
|
||||
}
|
||||
|
||||
c.connMutex.Lock()
|
||||
quicConn = c.quicConn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
quicStream, err = quicConn.OpenStream()
|
||||
if err != nil {
|
||||
logger.Error("DoQ failed to open stream after reconnect: %v", err)
|
||||
return nil, fmt.Errorf("doq: failed to open stream after reconnect: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("doq: failed to open stream: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,5 +243,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer))
|
||||
}
|
||||
|
||||
// Close the connection if not using keep-alive
|
||||
if !c.config.KeepAlive {
|
||||
c.connMutex.Lock()
|
||||
if c.quicConn != nil {
|
||||
c.quicConn.CloseWithError(0, "query complete")
|
||||
c.quicConn = nil
|
||||
}
|
||||
c.connMutex.Unlock()
|
||||
}
|
||||
|
||||
return recvMsg, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user