feat(keep-alive): add keep-alive option
This commit is contained in:
@@ -22,12 +22,13 @@ import (
|
||||
const dnsMessageContentType = "application/dns-message"
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
Path string
|
||||
DNSSEC bool
|
||||
HTTP3 bool
|
||||
HTTP2 bool
|
||||
Host string
|
||||
Port string
|
||||
Path string
|
||||
DNSSEC bool
|
||||
HTTP3 bool
|
||||
HTTP2 bool
|
||||
KeepAlive bool
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
@@ -37,7 +38,7 @@ type Client struct {
|
||||
}
|
||||
|
||||
func New(config Config) (*Client, error) {
|
||||
logger.Debug("Creating DoH client: %s:%s%s", config.Host, config.Port, config.Path)
|
||||
logger.Debug("Creating DoH client: %s:%s%s (KeepAlive: %v)", config.Host, config.Port, config.Path, config.KeepAlive)
|
||||
|
||||
if config.Host == "" || config.Port == "" || config.Path == "" {
|
||||
logger.Error("DoH client creation failed: missing required fields")
|
||||
@@ -65,31 +66,58 @@ func New(config Config) (*Client, error) {
|
||||
DisablePathMTUDiscovery: true,
|
||||
}
|
||||
|
||||
transport := http.DefaultTransport.(*http.Transport)
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
// Configure connection pooling based on KeepAlive setting
|
||||
if config.KeepAlive {
|
||||
transport.MaxIdleConnsPerHost = 10
|
||||
transport.MaxConnsPerHost = 10
|
||||
} else {
|
||||
transport.MaxIdleConns = 0
|
||||
transport.MaxIdleConnsPerHost = 0
|
||||
transport.DisableKeepAlives = true
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
var transportType string
|
||||
if config.HTTP2 {
|
||||
httpClient.Transport = &http2.Transport{
|
||||
http2Transport := &http2.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
AllowHTTP: true,
|
||||
}
|
||||
|
||||
if !config.KeepAlive {
|
||||
http2Transport.DisableCompression = true
|
||||
}
|
||||
|
||||
httpClient.Transport = http2Transport
|
||||
transportType = "HTTP/2"
|
||||
} else if config.HTTP3 {
|
||||
quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig)
|
||||
httpClient.Transport = &http3.Transport{
|
||||
http3Transport := &http3.Transport{
|
||||
TLSClientConfig: quicTlsConfig,
|
||||
QUICConfig: quicConfig,
|
||||
}
|
||||
|
||||
if !config.KeepAlive {
|
||||
http3Transport.DisableCompression = true
|
||||
}
|
||||
|
||||
httpClient.Transport = http3Transport
|
||||
transportType = "HTTP/3"
|
||||
} else {
|
||||
transportType = "HTTP/1.1"
|
||||
}
|
||||
|
||||
logger.Debug("DoH client created: %s (%s, DNSSEC: %v)", rawURL, transportType, config.DNSSEC)
|
||||
logger.Debug("DoH client created: %s (%s, DNSSEC: %v, KeepAlive: %v)", rawURL, transportType, config.DNSSEC, config.KeepAlive)
|
||||
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
@@ -102,6 +130,8 @@ func (c *Client) Close() {
|
||||
logger.Debug("Closing DoH client")
|
||||
if t, ok := c.httpClient.Transport.(*http.Transport); ok {
|
||||
t.CloseIdleConnections()
|
||||
} else if t2, ok := c.httpClient.Transport.(*http2.Transport); ok {
|
||||
t2.CloseIdleConnections()
|
||||
} else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok {
|
||||
t3.CloseIdleConnections()
|
||||
}
|
||||
@@ -132,6 +162,13 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
httpReq.Header.Set("User-Agent", "sdns-proxy")
|
||||
httpReq.Header.Set("Content-Type", dnsMessageContentType)
|
||||
httpReq.Header.Set("Accept", dnsMessageContentType)
|
||||
|
||||
// Set Connection header based on KeepAlive setting
|
||||
if c.config.KeepAlive {
|
||||
httpReq.Header.Set("Connection", "keep-alive")
|
||||
} else {
|
||||
httpReq.Header.Set("Connection", "close")
|
||||
}
|
||||
|
||||
httpResp, err := c.httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
||||
@@ -16,10 +17,11 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
Debug bool
|
||||
DNSSEC bool
|
||||
Host string
|
||||
Port string
|
||||
Debug bool
|
||||
DNSSEC bool
|
||||
KeepAlive bool
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
@@ -30,10 +32,11 @@ type Client struct {
|
||||
quicTransport *quic.Transport
|
||||
quicConfig *quic.Config
|
||||
config Config
|
||||
connMutex sync.Mutex
|
||||
}
|
||||
|
||||
func New(config Config) (*Client, error) {
|
||||
logger.Debug("Creating DoQ client: %s:%s", config.Host, config.Port)
|
||||
logger.Debug("Creating DoQ client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: config.Host,
|
||||
@@ -62,9 +65,7 @@ func New(config Config) (*Client, error) {
|
||||
MaxIdleTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v)", config.Host, config.Port, config.DNSSEC)
|
||||
|
||||
return &Client{
|
||||
client := &Client{
|
||||
targetAddr: targetAddr,
|
||||
tlsConfig: tlsConfig,
|
||||
udpConn: udpConn,
|
||||
@@ -72,18 +73,52 @@ func New(config Config) (*Client, error) {
|
||||
quicTransport: &quicTransport,
|
||||
quicConfig: &quicConfig,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// If keep-alive is enabled, establish connection now
|
||||
if config.KeepAlive {
|
||||
if err := client.ensureConnection(); err != nil {
|
||||
logger.Error("DoQ failed to establish initial connection: %v", err)
|
||||
client.Close()
|
||||
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v, KeepAlive: %v)", config.Host, config.Port, config.DNSSEC, config.KeepAlive)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
logger.Debug("Closing DoQ client")
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
if c.quicConn != nil {
|
||||
c.quicConn.CloseWithError(0, "client shutdown")
|
||||
c.quicConn = nil
|
||||
}
|
||||
if c.udpConn != nil {
|
||||
c.udpConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) OpenConnection() error {
|
||||
logger.Debug("Opening DoQ connection to %s", c.targetAddr)
|
||||
func (c *Client) ensureConnection() error {
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
// Check if existing connection is still valid
|
||||
if c.quicConn != nil {
|
||||
select {
|
||||
case <-c.quicConn.Context().Done():
|
||||
logger.Debug("DoQ connection closed, reconnecting")
|
||||
c.quicConn = nil
|
||||
default:
|
||||
// Connection is still valid
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Establishing DoQ connection to %s", c.targetAddr)
|
||||
quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig)
|
||||
if err != nil {
|
||||
logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err)
|
||||
@@ -101,10 +136,19 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr)
|
||||
}
|
||||
|
||||
if c.quicConn == nil {
|
||||
err := c.OpenConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Ensure we have a connection (either persistent or new)
|
||||
if c.config.KeepAlive {
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("doq: failed to ensure connection: %w", err)
|
||||
}
|
||||
} else {
|
||||
// For non-keepalive mode, create a fresh connection for each query
|
||||
c.connMutex.Lock()
|
||||
c.quicConn = nil // Force new connection
|
||||
c.connMutex.Unlock()
|
||||
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("doq: failed to create connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,18 +163,33 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
return nil, fmt.Errorf("doq: failed to pack message: %w", err)
|
||||
}
|
||||
|
||||
var quicStream quic.Stream
|
||||
quicStream, err = c.quicConn.OpenStream()
|
||||
// Open a stream for this query
|
||||
c.connMutex.Lock()
|
||||
quicConn := c.quicConn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
quicStream, err := quicConn.OpenStream()
|
||||
if err != nil {
|
||||
logger.Debug("DoQ stream failed, reconnecting: %v", err)
|
||||
err = c.OpenConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
quicStream, err = c.quicConn.OpenStream()
|
||||
if err != nil {
|
||||
logger.Error("DoQ failed to open stream after reconnect: %v", err)
|
||||
return nil, err
|
||||
logger.Error("DoQ failed to open stream: %v", err)
|
||||
|
||||
// If keep-alive is enabled, try to reconnect once
|
||||
if c.config.KeepAlive {
|
||||
logger.Debug("DoQ stream failed with keep-alive, attempting reconnect")
|
||||
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
|
||||
return nil, fmt.Errorf("doq: failed to reconnect: %w", reconnectErr)
|
||||
}
|
||||
|
||||
c.connMutex.Lock()
|
||||
quicConn = c.quicConn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
quicStream, err = quicConn.OpenStream()
|
||||
if err != nil {
|
||||
logger.Error("DoQ failed to open stream after reconnect: %v", err)
|
||||
return nil, fmt.Errorf("doq: failed to open stream after reconnect: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("doq: failed to open stream: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,5 +243,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer))
|
||||
}
|
||||
|
||||
// Close the connection if not using keep-alive
|
||||
if !c.config.KeepAlive {
|
||||
c.connMutex.Lock()
|
||||
if c.quicConn != nil {
|
||||
c.quicConn.CloseWithError(0, "query complete")
|
||||
c.quicConn = nil
|
||||
}
|
||||
c.connMutex.Unlock()
|
||||
}
|
||||
|
||||
return recvMsg, nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
||||
@@ -16,6 +17,7 @@ type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
DNSSEC bool
|
||||
KeepAlive bool
|
||||
WriteTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
Debug bool
|
||||
@@ -25,10 +27,12 @@ type Client struct {
|
||||
hostAndPort string
|
||||
tlsConfig *tls.Config
|
||||
config Config
|
||||
conn *tls.Conn
|
||||
connMutex sync.Mutex
|
||||
}
|
||||
|
||||
func New(config Config) (*Client, error) {
|
||||
logger.Debug("Creating DoT client: %s:%s", config.Host, config.Port)
|
||||
logger.Debug("Creating DoT client: %s:%s (KeepAlive: %v)", config.Host, config.Port, config.KeepAlive)
|
||||
|
||||
if config.Host == "" {
|
||||
logger.Error("DoT client creation failed: empty host")
|
||||
@@ -47,33 +51,76 @@ func New(config Config) (*Client, error) {
|
||||
ServerName: config.Host,
|
||||
}
|
||||
|
||||
logger.Debug("DoT client created: %s (DNSSEC: %v)", hostAndPort, config.DNSSEC)
|
||||
|
||||
return &Client{
|
||||
client := &Client{
|
||||
hostAndPort: hostAndPort,
|
||||
tlsConfig: tlsConfig,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// If keep-alive is enabled, establish connection now
|
||||
if config.KeepAlive {
|
||||
if err := client.ensureConnection(); err != nil {
|
||||
logger.Error("DoT failed to establish initial connection: %v", err)
|
||||
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("DoT client created: %s (DNSSEC: %v, KeepAlive: %v)", hostAndPort, config.DNSSEC, config.KeepAlive)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() {
|
||||
logger.Debug("Closing DoT client")
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) createConnection() (*tls.Conn, error) {
|
||||
func (c *Client) ensureConnection() error {
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
// Check if existing connection is still valid
|
||||
if c.conn != nil {
|
||||
// Test the connection with a very short deadline
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err == nil {
|
||||
// Try to read one byte to test connection
|
||||
var testBuf [1]byte
|
||||
_, err := c.conn.Read(testBuf[:])
|
||||
|
||||
// Reset deadline
|
||||
c.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// If we get a timeout error, connection is still good
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Any other error means connection is dead
|
||||
logger.Debug("DoT connection test failed, reconnecting: %v", err)
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: c.config.WriteTimeout,
|
||||
}
|
||||
|
||||
logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig)
|
||||
if err != nil {
|
||||
logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
logger.Debug("DoT connection established to %s", c.hostAndPort)
|
||||
return conn, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
@@ -82,12 +129,24 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
|
||||
}
|
||||
|
||||
// Create connection for this query
|
||||
conn, err := c.createConnection()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to create connection: %w", err)
|
||||
// Ensure we have a connection (either persistent or new)
|
||||
if c.config.KeepAlive {
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to ensure connection: %w", err)
|
||||
}
|
||||
} else {
|
||||
// For non-keepalive mode, create a fresh connection for each query
|
||||
c.connMutex.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMutex.Unlock()
|
||||
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to create connection: %w", err)
|
||||
}
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Prepare DNS message
|
||||
if c.config.DNSSEC {
|
||||
@@ -104,6 +163,10 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
binary.BigEndian.PutUint16(length, uint16(len(packed)))
|
||||
data := append(length, packed...)
|
||||
|
||||
c.connMutex.Lock()
|
||||
conn := c.conn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
// Write query
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||
logger.Error("DoT failed to set write deadline: %v", err)
|
||||
@@ -112,7 +175,28 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err)
|
||||
return nil, fmt.Errorf("dot: failed to write message: %w", err)
|
||||
|
||||
// If keep-alive is enabled and write failed, try to reconnect once
|
||||
if c.config.KeepAlive {
|
||||
logger.Debug("DoT write failed with keep-alive, attempting reconnect")
|
||||
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
|
||||
return nil, fmt.Errorf("dot: failed to reconnect: %w", reconnectErr)
|
||||
}
|
||||
|
||||
c.connMutex.Lock()
|
||||
conn = c.conn
|
||||
c.connMutex.Unlock()
|
||||
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to set write deadline after reconnect: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("dot: failed to write message after reconnect: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("dot: failed to write message: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Read response
|
||||
@@ -152,5 +236,15 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer))
|
||||
}
|
||||
|
||||
// Close the connection if not using keep-alive
|
||||
if !c.config.KeepAlive {
|
||||
c.connMutex.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.connMutex.Unlock()
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user