feat(keep-alive): add keep-alive option
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
||||
@@ -16,6 +17,7 @@ type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
DNSSEC bool
|
||||
KeepAlive bool
|
||||
WriteTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
Debug bool
|
||||
@@ -25,10 +27,12 @@ type Client struct {
|
||||
hostAndPort string
|
||||
tlsConfig *tls.Config
|
||||
config Config
|
||||
conn *tls.Conn
|
||||
connMutex sync.Mutex
|
||||
}
|
||||
|
||||
func New(config Config) (*Client, error) {
|
||||
logger.Debug("Creating DoT client: %s:%s", config.Host, config.Port)
|
||||
logger.Debug("Creating DoT client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
|
||||
|
||||
if config.Host == "" {
|
||||
logger.Error("DoT client creation failed: empty host")
|
||||
@@ -47,33 +51,76 @@ func New(config Config) (*Client, error) {
|
||||
ServerName: config.Host,
|
||||
}
|
||||
|
||||
logger.Debug("DoT client created: %s (DNSSEC: %v)", hostAndPort, config.DNSSEC)
|
||||
|
||||
return &Client{
|
||||
client := &Client{
|
||||
hostAndPort: hostAndPort,
|
||||
tlsConfig: tlsConfig,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// If keep-alive is enabled, establish connection now
|
||||
if config.KeepAlive {
|
||||
if err := client.ensureConnection(); err != nil {
|
||||
logger.Error("DoT failed to establish initial connection: %v", err)
|
||||
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("DoT client created: %s (DNSSEC: %v, KeepAlive: %v)", hostAndPort, config.DNSSEC, config.KeepAlive)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
logger.Debug("Closing DoT client")
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) createConnection() (*tls.Conn, error) {
|
||||
func (c *Client) ensureConnection() error {
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
// Check if existing connection is still valid
|
||||
if c.conn != nil {
|
||||
// Test the connection with a very short deadline
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err == nil {
|
||||
// Try to read one byte to test connection
|
||||
var testBuf [1]byte
|
||||
_, err := c.conn.Read(testBuf[:])
|
||||
|
||||
// Reset deadline
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// If we get a timeout error, connection is still good
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Any other error means connection is dead
|
||||
logger.Debug("DoT connection test failed, reconnecting: %v", err)
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: c.config.WriteTimeout,
|
||||
}
|
||||
|
||||
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig)
|
||||
if err != nil {
|
||||
logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
logger.Debug("DoT connection established to %s", c.hostAndPort)
|
||||
return conn, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
@@ -82,12 +129,24 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
|
||||
}
|
||||
|
||||
// Create connection for this query
|
||||
conn, err := c.createConnection()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to create connection: %w", err)
|
||||
// Ensure we have a connection (either persistent or new)
|
||||
if c.config.KeepAlive {
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to ensure connection: %w", err)
|
||||
}
|
||||
} else {
|
||||
// For non-keepalive mode, create a fresh connection for each query
|
||||
c.connMutex.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMutex.Unlock()
|
||||
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to create connection: %w", err)
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Prepare DNS message
|
||||
if c.config.DNSSEC {
|
||||
@@ -104,6 +163,10 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
binary.BigEndian.PutUint16(length, uint16(len(packed)))
|
||||
data := append(length, packed...)
|
||||
|
||||
c.connMutex.Lock()
|
||||
conn := c.conn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
// Write query
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||
logger.Error("DoT failed to set write deadline: %v", err)
|
||||
@@ -112,7 +175,28 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err)
|
||||
return nil, fmt.Errorf("dot: failed to write message: %w", err)
|
||||
|
||||
// If keep-alive is enabled and write failed, try to reconnect once
|
||||
if c.config.KeepAlive {
|
||||
logger.Debug("DoT write failed with keep-alive, attempting reconnect")
|
||||
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
|
||||
return nil, fmt.Errorf("dot: failed to reconnect: %w", reconnectErr)
|
||||
}
|
||||
|
||||
c.connMutex.Lock()
|
||||
conn = c.conn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to set write deadline after reconnect: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to write message after reconnect: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("dot: failed to write message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read response
|
||||
@@ -152,5 +236,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer))
|
||||
}
|
||||
|
||||
// Close the connection if not using keep-alive
|
||||
if !c.config.KeepAlive {
|
||||
c.connMutex.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMutex.Unlock()
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user