223 lines
5.8 KiB
Go
223 lines
5.8 KiB
Go
package dotcp
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/afonsofrancof/sdns-proxy/common/logger"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
type Config struct {
|
|
HostAndPort string
|
|
DNSSEC bool
|
|
KeepAlive bool
|
|
WriteTimeout time.Duration
|
|
ReadTimeout time.Duration
|
|
}
|
|
|
|
type Client struct {
|
|
hostAndPort string
|
|
config Config
|
|
conn net.Conn
|
|
connMutex sync.Mutex
|
|
}
|
|
|
|
func New(config Config) (*Client, error) {
|
|
logger.Debug("Creating DoTCP client: %s (KeepAlive: %v)", config.HostAndPort, config.KeepAlive)
|
|
|
|
if config.HostAndPort == "" {
|
|
logger.Error("DoTCP client creation failed: empty HostAndPort")
|
|
return nil, fmt.Errorf("dotcp: HostAndPort cannot be empty")
|
|
}
|
|
if config.WriteTimeout <= 0 {
|
|
config.WriteTimeout = 2 * time.Second
|
|
}
|
|
if config.ReadTimeout <= 0 {
|
|
config.ReadTimeout = 5 * time.Second
|
|
}
|
|
|
|
client := &Client{
|
|
hostAndPort: config.HostAndPort,
|
|
config: config,
|
|
}
|
|
|
|
if config.KeepAlive {
|
|
if err := client.ensureConnection(); err != nil {
|
|
logger.Error("DoTCP failed to establish initial connection: %v", err)
|
|
return nil, fmt.Errorf("failed to establish initial connection: %w", err)
|
|
}
|
|
}
|
|
|
|
logger.Debug("DoTCP client created: %s (DNSSEC: %v, KeepAlive: %v)", config.HostAndPort, config.DNSSEC, config.KeepAlive)
|
|
return client, nil
|
|
}
|
|
|
|
func (c *Client) Close() {
|
|
logger.Debug("Closing DoTCP client")
|
|
c.connMutex.Lock()
|
|
defer c.connMutex.Unlock()
|
|
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
}
|
|
|
|
func (c *Client) ensureConnection() error {
|
|
c.connMutex.Lock()
|
|
defer c.connMutex.Unlock()
|
|
|
|
if c.conn != nil {
|
|
if err := c.conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err == nil {
|
|
var testBuf [1]byte
|
|
_, err := c.conn.Read(testBuf[:])
|
|
c.conn.SetReadDeadline(time.Time{})
|
|
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
return nil
|
|
}
|
|
|
|
logger.Debug("DoTCP connection test failed, reconnecting: %v", err)
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
}
|
|
|
|
logger.Debug("Establishing DoTCP connection to %s", c.hostAndPort)
|
|
dialer := &net.Dialer{
|
|
Timeout: c.config.WriteTimeout,
|
|
}
|
|
|
|
conn, err := dialer.Dial("tcp", c.hostAndPort)
|
|
if err != nil {
|
|
logger.Error("DoTCP connection failed to %s: %v", c.hostAndPort, err)
|
|
return err
|
|
}
|
|
|
|
c.conn = conn
|
|
logger.Debug("DoTCP connection established to %s", c.hostAndPort)
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
|
|
if len(msg.Question) > 0 {
|
|
question := msg.Question[0]
|
|
logger.Debug("DoTCP query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
|
|
}
|
|
|
|
if c.config.KeepAlive {
|
|
if err := c.ensureConnection(); err != nil {
|
|
return nil, fmt.Errorf("dotcp: failed to ensure connection: %w", err)
|
|
}
|
|
} else {
|
|
c.connMutex.Lock()
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
c.connMutex.Unlock()
|
|
|
|
if err := c.ensureConnection(); err != nil {
|
|
return nil, fmt.Errorf("dotcp: failed to create connection: %w", err)
|
|
}
|
|
}
|
|
|
|
if c.config.DNSSEC {
|
|
msg.SetEdns0(4096, true)
|
|
}
|
|
|
|
packed, err := msg.Pack()
|
|
if err != nil {
|
|
logger.Error("DoTCP failed to pack message: %v", err)
|
|
return nil, fmt.Errorf("dotcp: failed to pack message: %w", err)
|
|
}
|
|
|
|
// DNS over TCP uses 2-byte length prefix
|
|
length := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(length, uint16(len(packed)))
|
|
data := append(length, packed...)
|
|
|
|
c.connMutex.Lock()
|
|
conn := c.conn
|
|
c.connMutex.Unlock()
|
|
|
|
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
|
logger.Error("DoTCP failed to set write deadline: %v", err)
|
|
return nil, fmt.Errorf("dotcp: failed to set write deadline: %w", err)
|
|
}
|
|
|
|
if _, err := conn.Write(data); err != nil {
|
|
logger.Error("DoTCP failed to write message to %s: %v", c.hostAndPort, err)
|
|
|
|
if c.config.KeepAlive {
|
|
logger.Debug("DoTCP write failed with keep-alive, attempting reconnect")
|
|
if reconnectErr := c.ensureConnection(); reconnectErr != nil {
|
|
return nil, fmt.Errorf("dotcp: failed to reconnect: %w", reconnectErr)
|
|
}
|
|
|
|
c.connMutex.Lock()
|
|
conn = c.conn
|
|
c.connMutex.Unlock()
|
|
|
|
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
|
|
return nil, fmt.Errorf("dotcp: failed to set write deadline after reconnect: %w", err)
|
|
}
|
|
|
|
if _, err := conn.Write(data); err != nil {
|
|
return nil, fmt.Errorf("dotcp: failed to write message after reconnect: %w", err)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("dotcp: failed to write message: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil {
|
|
logger.Error("DoTCP failed to set read deadline: %v", err)
|
|
return nil, fmt.Errorf("dotcp: failed to set read deadline: %w", err)
|
|
}
|
|
|
|
lengthBuf := make([]byte, 2)
|
|
if _, err := io.ReadFull(conn, lengthBuf); err != nil {
|
|
logger.Error("DoTCP failed to read response length from %s: %v", c.hostAndPort, err)
|
|
return nil, fmt.Errorf("dotcp: failed to read response length: %w", err)
|
|
}
|
|
|
|
msgLen := binary.BigEndian.Uint16(lengthBuf)
|
|
if msgLen > dns.MaxMsgSize {
|
|
logger.Error("DoTCP response too large from %s: %d bytes", c.hostAndPort, msgLen)
|
|
return nil, fmt.Errorf("dotcp: response message too large: %d", msgLen)
|
|
}
|
|
|
|
buffer := make([]byte, msgLen)
|
|
if _, err := io.ReadFull(conn, buffer); err != nil {
|
|
logger.Error("DoTCP failed to read response from %s: %v", c.hostAndPort, err)
|
|
return nil, fmt.Errorf("dotcp: failed to read response: %w", err)
|
|
}
|
|
|
|
response := new(dns.Msg)
|
|
if err := response.Unpack(buffer); err != nil {
|
|
logger.Error("DoTCP failed to unpack response from %s: %v", c.hostAndPort, err)
|
|
return nil, fmt.Errorf("dotcp: failed to unpack response: %w", err)
|
|
}
|
|
|
|
if len(response.Answer) > 0 {
|
|
logger.Debug("DoTCP response from %s: %d answers", c.hostAndPort, len(response.Answer))
|
|
}
|
|
|
|
if !c.config.KeepAlive {
|
|
c.connMutex.Lock()
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
c.connMutex.Unlock()
|
|
}
|
|
|
|
return response, nil
|
|
}
|