feat: add logging

This commit is contained in:
2025-09-08 19:06:21 +01:00
parent 234b1dcc86
commit c6e2b19a84
22 changed files with 429 additions and 1093 deletions

View File

@@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/afonsofrancof/sdns-proxy/common/dnssec" "github.com/afonsofrancof/sdns-proxy/common/dnssec"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/afonsofrancof/sdns-proxy/common/protocols/do53" "github.com/afonsofrancof/sdns-proxy/common/protocols/do53"
"github.com/afonsofrancof/sdns-proxy/common/protocols/doh" "github.com/afonsofrancof/sdns-proxy/common/protocols/doh"
"github.com/afonsofrancof/sdns-proxy/common/protocols/doq" "github.com/afonsofrancof/sdns-proxy/common/protocols/doq"
@@ -26,16 +27,19 @@ type ValidatingDNSClient struct {
} }
type Options struct { type Options struct {
DNSSEC bool DNSSEC bool
ValidateOnly bool ValidateOnly bool
StrictValidation bool StrictValidation bool
} }
// New creates a DNS client based on the upstream string // New creates a DNS client based on the upstream string
func New(upstream string, opts Options) (DNSClient, error) { func New(upstream string, opts Options) (DNSClient, error) {
logger.Debug("Creating DNS client for upstream: %s with options: %+v", upstream, opts)
// Try to parse as URL first // Try to parse as URL first
parsedURL, err := url.Parse(upstream) parsedURL, err := url.Parse(upstream)
if err != nil { if err != nil {
logger.Error("Invalid upstream format: %v", err)
return nil, fmt.Errorf("invalid upstream format: %w", err) return nil, fmt.Errorf("invalid upstream format: %w", err)
} }
@@ -43,30 +47,26 @@ func New(upstream string, opts Options) (DNSClient, error) {
// If it has a scheme, treat it as a full URL // If it has a scheme, treat it as a full URL
if parsedURL.Scheme != "" { if parsedURL.Scheme != "" {
logger.Debug("Parsing %s as URL with scheme %s", upstream, parsedURL.Scheme)
baseClient, err = createClientFromURL(parsedURL, opts) baseClient, err = createClientFromURL(parsedURL, opts)
} else { } else {
// No scheme - treat as plain DNS address // No scheme - treat as plain DNS address
logger.Debug("Parsing %s as plain DNS address", upstream)
baseClient, err = createClientFromPlainAddress(upstream, opts) baseClient, err = createClientFromPlainAddress(upstream, opts)
} }
if err != nil { if err != nil {
logger.Error("Failed to create base client: %v", err)
return nil, err return nil, err
} }
// If DNSSEC is not enabled, return the base client // If DNSSEC is not enabled, return the base client
if !opts.DNSSEC { if !opts.DNSSEC {
logger.Debug("DNSSEC disabled, returning base client")
return baseClient, nil return baseClient, nil
} }
// Wrap with DNSSEC validation logger.Debug("DNSSEC enabled, wrapping with validator")
// validator := dnssec.NewValidator(func(qname string, qtype uint16) (*dns.Msg, error) {
// msg := new(dns.Msg)
// msg.SetQuestion(dns.Fqdn(qname), qtype)
// msg.Id = dns.Id()
// msg.RecursionDesired = true
// msg.SetEdns0(4096, true) // Enable DNSSEC
// return baseClient.Query(msg)
// })
validator := dnssec.NewValidatorWithAuthoritativeQueries() validator := dnssec.NewValidatorWithAuthoritativeQueries()
return &ValidatingDNSClient{ return &ValidatingDNSClient{
@@ -77,9 +77,16 @@ func New(upstream string, opts Options) (DNSClient, error) {
} }
func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) { func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) {
if len(msg.Question) > 0 {
question := msg.Question[0]
logger.Debug("ValidatingDNSClient query: %s %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v)",
question.Name, dns.TypeToString[question.Qtype], v.options.DNSSEC, v.options.ValidateOnly, v.options.StrictValidation)
}
// Always query the upstream first // Always query the upstream first
response, err := v.client.Query(msg) response, err := v.client.Query(msg)
if err != nil { if err != nil {
logger.Debug("Base client query failed: %v", err)
return nil, err return nil, err
} }
@@ -90,6 +97,7 @@ func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) {
// Extract question details for validation // Extract question details for validation
if len(msg.Question) == 0 { if len(msg.Question) == 0 {
logger.Debug("No questions in message, skipping DNSSEC validation")
return response, nil return response, nil
} }
@@ -97,6 +105,8 @@ func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) {
qname := question.Name qname := question.Name
qtype := question.Qtype qtype := question.Qtype
logger.Debug("Starting DNSSEC validation for %s %s", qname, dns.TypeToString[qtype])
// Validate the response // Validate the response
validationErr := v.validator.ValidateResponse(response, qname, qtype) validationErr := v.validator.ValidateResponse(response, qname, qtype)
@@ -104,28 +114,35 @@ func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) {
if validationErr != nil { if validationErr != nil {
// Check if it's a "not signed" error // Check if it's a "not signed" error
if validationErr == dnssec.ErrResourceNotSigned { if validationErr == dnssec.ErrResourceNotSigned {
logger.Debug("Domain %s is not DNSSEC signed", qname)
if v.options.ValidateOnly { if v.options.ValidateOnly {
logger.Error("Domain %s is not DNSSEC signed (ValidateOnly mode)", qname)
return nil, fmt.Errorf("domain %s is not DNSSEC signed", qname) return nil, fmt.Errorf("domain %s is not DNSSEC signed", qname)
} }
// Return unsigned response if not in validate-only mode // Return unsigned response if not in validate-only mode
logger.Debug("Returning unsigned response for %s", qname)
return response, nil return response, nil
} }
// For other validation errors // For other validation errors
logger.Debug("DNSSEC validation failed for %s: %v", qname, validationErr)
if v.options.StrictValidation { if v.options.StrictValidation {
logger.Error("DNSSEC validation failed for %s (strict mode): %v", qname, validationErr)
return nil, fmt.Errorf("DNSSEC validation failed for %s: %w", qname, validationErr) return nil, fmt.Errorf("DNSSEC validation failed for %s: %w", qname, validationErr)
} }
// In non-strict mode, log the error but return the response // In non-strict mode, log the error but return the response
// (You might want to add logging here) logger.Debug("DNSSEC validation failed for %s (non-strict mode), returning response anyway: %v", qname, validationErr)
return response, nil return response, nil
} }
// Validation successful // Validation successful
logger.Debug("DNSSEC validation successful for %s %s", qname, dns.TypeToString[qtype])
return response, nil return response, nil
} }
func (v *ValidatingDNSClient) Close() { func (v *ValidatingDNSClient) Close() {
logger.Debug("Closing ValidatingDNSClient")
if v.client != nil { if v.client != nil {
v.client.Close() v.client.Close()
} }
@@ -135,6 +152,7 @@ func createClientFromURL(parsedURL *url.URL, opts Options) (DNSClient, error) {
scheme := strings.ToLower(parsedURL.Scheme) scheme := strings.ToLower(parsedURL.Scheme)
host := parsedURL.Hostname() host := parsedURL.Hostname()
if host == "" { if host == "" {
logger.Error("Missing host in upstream URL: %s", parsedURL.String())
return nil, fmt.Errorf("missing host in upstream URL") return nil, fmt.Errorf("missing host in upstream URL")
} }
@@ -148,6 +166,7 @@ func createClientFromURL(parsedURL *url.URL, opts Options) (DNSClient, error) {
path = getDefaultPath(scheme) path = getDefaultPath(scheme)
} }
logger.Debug("Creating client from URL: scheme=%s, host=%s, port=%s, path=%s", scheme, host, port, path)
return createClient(scheme, host, port, path, opts) return createClient(scheme, host, port, path, opts)
} }
@@ -162,41 +181,49 @@ func createClientFromPlainAddress(address string, opts Options) (DNSClient, erro
} }
if host == "" { if host == "" {
logger.Error("Empty host in address: %s", address)
return nil, fmt.Errorf("empty host in address: %s", address) return nil, fmt.Errorf("empty host in address: %s", address)
} }
logger.Debug("Creating client from plain address: host=%s, port=%s", host, port)
return createClient("", host, port, "", opts) return createClient("", host, port, "", opts)
} }
func getDefaultPort(scheme string) string { func getDefaultPort(scheme string) string {
port := "53"
switch scheme { switch scheme {
case "https", "doh", "doh3": case "https", "doh", "doh3":
return "443" port = "443"
case "tls", "dot": case "tls", "dot":
return "853" port = "853"
case "quic", "doq": case "quic", "doq":
return "853" port = "853"
default:
return "53"
} }
logger.Debug("Default port for scheme %s: %s", scheme, port)
return port
} }
func getDefaultPath(scheme string) string { func getDefaultPath(scheme string) string {
path := ""
switch scheme { switch scheme {
case "https", "doh", "doh3": case "https", "doh", "doh3":
return "/dns-query" path = "/dns-query"
default:
return ""
} }
logger.Debug("Default path for scheme %s: %s", scheme, path)
return path
} }
func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) { func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) {
logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v",
scheme, host, port, path, opts.DNSSEC)
switch scheme { switch scheme {
case "udp", "tcp", "do53", "": case "udp", "tcp", "do53", "":
config := do53.Config{ config := do53.Config{
HostAndPort: net.JoinHostPort(host, port), HostAndPort: net.JoinHostPort(host, port),
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
} }
logger.Debug("Creating DO53 client with config: %+v", config)
return do53.New(config) return do53.New(config)
case "http", "doh": case "http", "doh":
@@ -207,6 +234,7 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
HTTP3: false, HTTP3: false,
} }
logger.Debug("Creating DoH client with config: %+v", config)
return doh.New(config) return doh.New(config)
case "https", "doh3": case "https", "doh3":
@@ -217,6 +245,7 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
HTTP3: true, HTTP3: true,
} }
logger.Debug("Creating DoH3 client with config: %+v", config)
return doh.New(config) return doh.New(config)
case "tls", "dot": case "tls", "dot":
@@ -225,6 +254,7 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err
Port: port, Port: port,
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
} }
logger.Debug("Creating DoT client with config: %+v", config)
return dot.New(config) return dot.New(config)
case "doq": // DNS over QUIC case "doq": // DNS over QUIC
@@ -233,9 +263,11 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err
Port: port, Port: port,
DNSSEC: opts.DNSSEC, DNSSEC: opts.DNSSEC,
} }
logger.Debug("Creating DoQ client with config: %+v", config)
return doq.New(config) return doq.New(config)
default: default:
logger.Error("Unsupported scheme: %s", scheme)
return nil, fmt.Errorf("unsupported scheme: %s", scheme) return nil, fmt.Errorf("unsupported scheme: %s", scheme)
} }
} }

View File

@@ -1,138 +0,0 @@
package main
import (
"fmt"
"log"
"os"
"strings"
"time"
"github.com/afonsofrancof/sdns-perf/internal/client"
"github.com/alecthomas/kong"
"github.com/miekg/dns"
)
var cli struct {
// Global flags
Verbose bool `help:"Enable verbose logging." short:"v"`
Query QueryCmd `cmd:"" help:"Perform a DNS query (client mode)."`
Listen ListenCmd `cmd:"" help:"Run as a DNS listener/resolver (server mode)."`
}
type QueryCmd struct {
DomainName string `help:"Domain name to resolve." arg:"" required:""`
Server string `help:"Upstream server address (e.g., https://1.1.1.1/dns-query, tls://1.1.1.1, 8.8.8.8)." short:"s" required:""`
QueryType string `help:"Query type (A, AAAA, MX, TXT, etc.)." short:"t" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"`
DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"`
Timeout time.Duration `help:"Timeout for the query operation." default:"10s"` // Default might be higher now
KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"`
}
func (q *QueryCmd) Run() error {
log.Printf("Querying %s for %s type %s (DNSSEC: %v, Timeout: %v)\n",
q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.Timeout)
opts := client.Options{
Timeout: q.Timeout,
DNSSEC: q.DNSSEC,
KeyLogPath: q.KeyLogFile,
}
dnsClient, err := client.New(q.Server, opts)
if err != nil {
return err
}
defer dnsClient.Close()
qTypeUint, ok := dns.StringToType[strings.ToUpper(q.QueryType)]
if !ok {
return fmt.Errorf("invalid query type: %s", q.QueryType)
}
dnsMsg, err := dnsClient.Query(q.DomainName, qTypeUint)
if err != nil {
return fmt.Errorf("query failed: %w ", err)
}
printResponse(q.DomainName, q.QueryType, dnsMsg)
return nil
}
type ListenCmd struct {
Address string `help:"Address to listen on (e.g., :53, :8053)." default:":53"`
// Add other server-specific flags: default upstream, TLS cert/key paths etc.
}
func (l *ListenCmd) Run() error {
return fmt.Errorf("server/listen mode not yet implemented")
}
func printResponse(domain, qtype string, msg *dns.Msg) {
fmt.Println(";; QUESTION SECTION:")
fmt.Printf(";%s.\tIN\t%s\n", dns.Fqdn(domain), strings.ToUpper(qtype))
fmt.Println("\n;; ANSWER SECTION:")
if len(msg.Answer) > 0 {
for _, rr := range msg.Answer {
fmt.Println(rr.String())
}
} else {
fmt.Println(";; No records found in answer section.")
}
if len(msg.Ns) > 0 {
fmt.Println("\n;; AUTHORITY SECTION:")
for _, rr := range msg.Ns {
fmt.Println(rr.String())
}
}
if len(msg.Extra) > 0 {
hasRealExtra := false
for _, rr := range msg.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
hasRealExtra = true
break
}
}
if hasRealExtra {
fmt.Println("\n;; ADDITIONAL SECTION:")
for _, rr := range msg.Extra {
if rr.Header().Rrtype != dns.TypeOPT {
fmt.Println(rr.String())
}
}
}
}
fmt.Printf("\n;; RCODE: %s, ID: %d", dns.RcodeToString[msg.Rcode], msg.Id)
opt := msg.IsEdns0()
if opt != nil {
fmt.Printf(", EDNS: version: %d; flags:", opt.Version())
if opt.Do() {
fmt.Printf(" do;")
} else {
fmt.Printf(";")
}
fmt.Printf(" udp: %d", opt.UDPSize())
}
fmt.Println()
}
func main() {
log.SetOutput(os.Stderr)
log.SetFlags(log.Ltime | log.Lshortfile)
kongCtx := kong.Parse(&cli,
kong.Name("sdns-perf"),
kong.Description("A DNS client/server tool supporting multiple protocols."),
kong.UsageOnError(),
kong.ConfigureHelp(kong.HelpOptions{Compact: true, Summary: true}),
)
err := kongCtx.Run()
kongCtx.FatalIfErrorf(err)
}

View File

@@ -20,9 +20,9 @@ package dnssec
import ( import (
"fmt" "fmt"
"log"
"strings" "strings"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -57,13 +57,13 @@ func (ac *AuthenticationChain) Populate(domainName string, queryFunc func(string
zones = append(zones, zone) zones = append(zones, zone)
} }
log.Printf("Building DNSSEC chain for zones: %v", zones) logger.Debug("Building DNSSEC chain for zones: %v", zones)
ac.DelegationChain = make([]SignedZone, 0, len(zones)) ac.DelegationChain = make([]SignedZone, 0, len(zones))
// Query each zone from root down // Query each zone from root down
for i, zoneName := range zones { for i, zoneName := range zones {
log.Printf("Querying zone: %s", zoneName) logger.Debug("Querying zone: %s", zoneName)
delegation, err := ac.queryDelegation(zoneName, queryFunc) delegation, err := ac.queryDelegation(zoneName, queryFunc)
if err != nil { if err != nil {
@@ -91,13 +91,13 @@ func (ac *AuthenticationChain) queryDelegation(domainName string, queryFunc func
} }
signedZone.DNSKey = dnskeyRRset signedZone.DNSKey = dnskeyRRset
log.Printf("Found %d DNSKEY records for %s", len(dnskeyRRset.RRs), domainName) logger.Debug("Found %d DNSKEY records for %s", len(dnskeyRRset.RRs), domainName)
// Populate public key lookup // Populate public key lookup
for _, rr := range signedZone.DNSKey.RRs { for _, rr := range signedZone.DNSKey.RRs {
if dnskey, ok := rr.(*dns.DNSKEY); ok { if dnskey, ok := rr.(*dns.DNSKEY); ok {
signedZone.AddPubKey(dnskey) signedZone.AddPubKey(dnskey)
log.Printf("Added DNSKEY for %s: keytag=%d, flags=%d, algorithm=%d", domainName, dnskey.KeyTag(), dnskey.Flags, dnskey.Algorithm) logger.Debug("Added DNSKEY for %s: keytag=%d, flags=%d, algorithm=%d", domainName, dnskey.KeyTag(), dnskey.Flags, dnskey.Algorithm)
} }
} }
@@ -106,17 +106,17 @@ func (ac *AuthenticationChain) queryDelegation(domainName string, queryFunc func
dsRRset, _ := ac.queryRRset(domainName, dns.TypeDS, queryFunc) dsRRset, _ := ac.queryRRset(domainName, dns.TypeDS, queryFunc)
signedZone.DS = dsRRset signedZone.DS = dsRRset
if dsRRset != nil && len(dsRRset.RRs) > 0 { if dsRRset != nil && len(dsRRset.RRs) > 0 {
log.Printf("Found %d DS records for %s", len(dsRRset.RRs), domainName) logger.Debug("Found %d DS records for %s", len(dsRRset.RRs), domainName)
for _, rr := range dsRRset.RRs { for _, rr := range dsRRset.RRs {
if ds, ok := rr.(*dns.DS); ok { if ds, ok := rr.(*dns.DS); ok {
log.Printf("DS record for %s: keytag=%d", domainName, ds.KeyTag) logger.Debug("DS record for %s: keytag=%d", domainName, ds.KeyTag)
} }
} }
} }
} else { } else {
// Root zone has no DS records - trusted by default // Root zone has no DS records - trusted by default
signedZone.DS = NewRRSet() signedZone.DS = NewRRSet()
log.Printf("Root zone - no DS records, trusted by default") logger.Debug("Root zone - no DS records, trusted by default")
} }
return signedZone, nil return signedZone, nil
@@ -125,12 +125,12 @@ func (ac *AuthenticationChain) queryDelegation(domainName string, queryFunc func
func (ac *AuthenticationChain) queryRRset(qname string, qtype uint16, queryFunc func(string, uint16) (*dns.Msg, error)) (*RRSet, error) { func (ac *AuthenticationChain) queryRRset(qname string, qtype uint16, queryFunc func(string, uint16) (*dns.Msg, error)) (*RRSet, error) {
r, err := queryFunc(qname, qtype) r, err := queryFunc(qname, qtype)
if err != nil { if err != nil {
log.Printf("cannot lookup %v", err) logger.Debug("cannot lookup %v", err)
return NewRRSet(), nil // Return empty RRSet instead of nil return NewRRSet(), nil // Return empty RRSet instead of nil
} }
if r.Rcode == dns.RcodeNameError { if r.Rcode == dns.RcodeNameError {
log.Printf("no such domain %s", qname) logger.Debug("no such domain %s", qname)
return NewRRSet(), nil // Return empty RRSet instead of nil return NewRRSet(), nil // Return empty RRSet instead of nil
} }
@@ -167,60 +167,60 @@ func (ac *AuthenticationChain) Verify(answerRRset *RRSet) error {
// Verify the answer RRset against target zone's keys // Verify the answer RRset against target zone's keys
err := targetZone.VerifyRRSIG(answerRRset) err := targetZone.VerifyRRSIG(answerRRset)
if err != nil { if err != nil {
log.Printf("Answer RRSIG verification failed: %v", err) logger.Debug("Answer RRSIG verification failed: %v", err)
return ErrInvalidRRsig return ErrInvalidRRsig
} }
// Validate the chain from root down // Validate the chain from root down
for _, zone := range ac.DelegationChain { for _, zone := range ac.DelegationChain {
log.Printf("Validating zone: %s", zone.Zone) logger.Debug("Validating zone: %s", zone.Zone)
// Verify DNSKEY RRset signature // Verify DNSKEY RRset signature
if !zone.HasDNSKeys() { if !zone.HasDNSKeys() {
log.Printf("No DNSKEYs for zone %s", zone.Zone) logger.Debug("No DNSKEYs for zone %s", zone.Zone)
return ErrDnskeyNotAvailable return ErrDnskeyNotAvailable
} }
err := zone.VerifyRRSIG(zone.DNSKey) err := zone.VerifyRRSIG(zone.DNSKey)
if err != nil { if err != nil {
log.Printf("DNSKEY validation failed for %s: %v", zone.Zone, err) logger.Debug("DNSKEY validation failed for %s: %v", zone.Zone, err)
return ErrRrsigValidationError return ErrRrsigValidationError
} }
// Skip ALL validation for root - just trust it // Skip ALL validation for root - just trust it
if zone.Zone == "." { if zone.Zone == "." {
log.Printf("Root zone - trusted by default, no validation performed") logger.Debug("Root zone - trusted by default, no validation performed")
continue continue
} }
// For non-root zones, validate DS records against parent zone // For non-root zones, validate DS records against parent zone
if zone.ParentZone == nil { if zone.ParentZone == nil {
log.Printf("Non-root zone %s has no parent", zone.Zone) logger.Debug("Non-root zone %s has no parent", zone.Zone)
return fmt.Errorf("non-root zone %s has no parent", zone.Zone) return fmt.Errorf("non-root zone %s has no parent", zone.Zone)
} }
if zone.DS == nil || zone.DS.IsEmpty() { if zone.DS == nil || zone.DS.IsEmpty() {
log.Printf("No DS records for zone %s", zone.Zone) logger.Debug("No DS records for zone %s", zone.Zone)
return ErrDsNotAvailable return ErrDsNotAvailable
} }
// Verify DS signature using parent's key // Verify DS signature using parent's key
err = zone.ParentZone.VerifyRRSIG(zone.DS) err = zone.ParentZone.VerifyRRSIG(zone.DS)
if err != nil { if err != nil {
log.Printf("DS signature validation failed for %s: %v", zone.Zone, err) logger.Debug("DS signature validation failed for %s: %v", zone.Zone, err)
return ErrRrsigValidationError return ErrRrsigValidationError
} }
// Verify DS matches this zone's DNSKEY // Verify DS matches this zone's DNSKEY
err = zone.VerifyDS(zone.DS.RRs) err = zone.VerifyDS(zone.DS.RRs)
if err != nil { if err != nil {
log.Printf("DS-DNSKEY validation failed for %s: %v", zone.Zone, err) logger.Debug("DS-DNSKEY validation failed for %s: %v", zone.Zone, err)
return ErrDsInvalid return ErrDsInvalid
} }
log.Printf("Zone %s validated successfully", zone.Zone) logger.Debug("Zone %s validated successfully", zone.Zone)
} }
log.Printf("DNSSEC validation successful for entire chain!") logger.Debug("DNSSEC validation successful for entire chain!")
return nil return nil
} }

View File

@@ -2,11 +2,11 @@ package dnssec
import ( import (
"fmt" "fmt"
"log"
"net" "net"
"strings" "strings"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -28,13 +28,13 @@ func NewAuthoritativeQuerier() *AuthoritativeQuerier {
} }
func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) (*dns.Msg, error) { func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) (*dns.Msg, error) {
log.Printf("Querying authoritative servers for %s type %d", qname, qtype) logger.Debug("Querying authoritative servers for %s type %d", qname, qtype)
var zone string var zone string
if qtype == dns.TypeDS { if qtype == dns.TypeDS {
zone = aq.getParentZone(qname) zone = aq.getParentZone(qname)
if zone == "" { if zone == "" {
log.Printf("No parent zone for %s - returning NXDOMAIN for DS query", qname) logger.Debug("No parent zone for %s - returning NXDOMAIN for DS query", qname)
msg := &dns.Msg{} msg := &dns.Msg{}
msg.SetRcode(&dns.Msg{}, dns.RcodeNameError) msg.SetRcode(&dns.Msg{}, dns.RcodeNameError)
return msg, nil return msg, nil
@@ -43,7 +43,7 @@ func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) (
zone = aq.findZone(qname) zone = aq.findZone(qname)
} }
log.Printf("Determined zone: %s for query %s type %d", zone, qname, qtype) logger.Debug("Determined zone: %s for query %s type %d", zone, qname, qtype)
// Get NS names (not IPs yet) // Get NS names (not IPs yet)
nsNames, err := aq.findAuthoritativeNSNames(zone) nsNames, err := aq.findAuthoritativeNSNames(zone)
@@ -59,15 +59,15 @@ func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) (
continue continue
} }
log.Printf("Trying server: %s (%s)", server, nsName) logger.Debug("Trying server: %s (%s)", server, nsName)
msg, err := aq.queryServer(server, qname, qtype) msg, err := aq.queryServer(server, qname, qtype)
if err != nil { if err != nil {
log.Printf("Server %s failed: %v", server, err) logger.Debug("Server %s failed: %v", server, err)
lastErr = err lastErr = err
continue continue
} }
log.Printf("Server %s responded, authoritative: %v, rcode: %d, answers: %d", server, msg.Authoritative, msg.Rcode, len(msg.Answer)) logger.Debug("Server %s responded, authoritative: %v, rcode: %d, answers: %d", server, msg.Authoritative, msg.Rcode, len(msg.Answer))
if (msg.Rcode == dns.RcodeSuccess && len(msg.Answer) > 0) || msg.Rcode == dns.RcodeNameError { if (msg.Rcode == dns.RcodeSuccess && len(msg.Answer) > 0) || msg.Rcode == dns.RcodeNameError {
return msg, nil return msg, nil
} }
@@ -82,11 +82,11 @@ func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) (
func (aq *AuthoritativeQuerier) findAuthoritativeNSNames(zone string) ([]string, error) { func (aq *AuthoritativeQuerier) findAuthoritativeNSNames(zone string) ([]string, error) {
if nsNames, exists := aq.nsCache[zone]; exists { if nsNames, exists := aq.nsCache[zone]; exists {
log.Printf("Using cached NS names for %s: %v", zone, nsNames) logger.Debug("Using cached NS names for %s: %v", zone, nsNames)
return nsNames, nil return nsNames, nil
} }
log.Printf("Looking for NS records for zone: %s", zone) logger.Debug("Looking for NS records for zone: %s", zone)
// Use a public resolver to find the NS records // Use a public resolver to find the NS records
resolver := &dns.Client{Timeout: 5 * time.Second} resolver := &dns.Client{Timeout: 5 * time.Second}
@@ -125,35 +125,35 @@ func (aq *AuthoritativeQuerier) findAuthoritativeNSNames(zone string) ([]string,
return nil, fmt.Errorf("no NS servers found for %s", zone) return nil, fmt.Errorf("no NS servers found for %s", zone)
} }
log.Printf("Found NS names for %s: %v", zone, nsNames) logger.Debug("Found NS names for %s: %v", zone, nsNames)
aq.nsCache[zone] = nsNames aq.nsCache[zone] = nsNames
log.Printf("Cached NS names for %s: %v", zone, nsNames) logger.Debug("Cached NS names for %s: %v", zone, nsNames)
return nsNames, nil return nsNames, nil
} }
func (aq *AuthoritativeQuerier) getParentZone(qname string) string { func (aq *AuthoritativeQuerier) getParentZone(qname string) string {
log.Printf("Getting parent zone for: %s", qname) logger.Debug("Getting parent zone for: %s", qname)
// Clean the qname // Clean the qname
qname = strings.TrimSuffix(qname, ".") qname = strings.TrimSuffix(qname, ".")
// Root zone has no parent // Root zone has no parent
if qname == "" || qname == "." { if qname == "" || qname == "." {
log.Printf("Root zone has no parent") logger.Debug("Root zone has no parent")
return "" return ""
} }
labels := dns.SplitDomainName(qname) labels := dns.SplitDomainName(qname)
log.Printf("Labels for %s: %v", qname, labels) logger.Debug("Labels for %s: %v", qname, labels)
if len(labels) <= 1 { if len(labels) <= 1 {
log.Printf("Parent of TLD %s is root", qname) logger.Debug("Parent of TLD %s is root", qname)
return "." // Parent of TLD is root return "." // Parent of TLD is root
} }
parentLabels := labels[1:] parentLabels := labels[1:]
parent := dns.Fqdn(strings.Join(parentLabels, ".")) parent := dns.Fqdn(strings.Join(parentLabels, "."))
log.Printf("Parent zone of %s is %s", qname, parent) logger.Debug("Parent zone of %s is %s", qname, parent)
return parent return parent
} }
@@ -167,26 +167,25 @@ func (aq *AuthoritativeQuerier) findZone(qname string) string {
return qname return qname
} }
func (aq *AuthoritativeQuerier) resolveNSToIP(nsName string) string { func (aq *AuthoritativeQuerier) resolveNSToIP(nsName string) string {
if ip, exists := aq.ipCache[nsName]; exists { if ip, exists := aq.ipCache[nsName]; exists {
log.Printf("Using cached IP for %s: %s", nsName, ip) logger.Debug("Using cached IP for %s: %s", nsName, ip)
return ip return ip
} }
nsName = strings.TrimSuffix(nsName, ".") nsName = strings.TrimSuffix(nsName, ".")
log.Printf("Resolving NS %s to IP", nsName) logger.Debug("Resolving NS %s to IP", nsName)
ips, err := net.LookupIP(nsName) ips, err := net.LookupIP(nsName)
if err != nil { if err != nil {
log.Printf("Failed to resolve %s: %v", nsName, err) logger.Debug("Failed to resolve %s: %v", nsName, err)
return "" return ""
} }
for _, ip := range ips { for _, ip := range ips {
if ip.To4() != nil { // Prefer IPv4 if ip.To4() != nil { // Prefer IPv4
result := ip.String() + ":53" result := ip.String() + ":53"
log.Printf("Resolved %s to %s", nsName, result) logger.Debug("Resolved %s to %s", nsName, result)
// Cache the result before returning // Cache the result before returning
aq.ipCache[nsName] = result aq.ipCache[nsName] = result
@@ -202,12 +201,12 @@ func (aq *AuthoritativeQuerier) queryServer(server, qname string, qtype uint16)
m.SetQuestion(dns.Fqdn(qname), qtype) m.SetQuestion(dns.Fqdn(qname), qtype)
m.SetEdns0(4096, true) // Enable DNSSEC m.SetEdns0(4096, true) // Enable DNSSEC
log.Printf("Querying %s for %s type %d", server, qname, qtype) logger.Debug("Querying %s for %s type %d", server, qname, qtype)
msg, _, err := aq.client.Exchange(m, server) msg, _, err := aq.client.Exchange(m, server)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Printf("Response from %s: rcode=%d, answers=%d", server, msg.Rcode, len(msg.Answer)) logger.Debug("Response from %s: rcode=%d, answers=%d", server, msg.Rcode, len(msg.Answer))
return msg, err return msg, err
} }

View File

@@ -19,9 +19,9 @@ package dnssec
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import ( import (
"log"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -65,7 +65,7 @@ func (r *RRSet) ValidateSignature(key *dns.DNSKEY) error {
err := r.RRSig.Verify(key, r.RRs) err := r.RRSig.Verify(key, r.RRs)
if err != nil { if err != nil {
log.Printf("RRSIG verification failed: %v", err) logger.Debug("RRSIG verification failed: %v", err)
return ErrRrsigValidationError return ErrRrsigValidationError
} }

View File

@@ -19,9 +19,9 @@ package dnssec
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import ( import (
"log"
"strings" "strings"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -61,7 +61,7 @@ func (z *SignedZone) VerifyRRSIG(signedRRset *RRSet) error {
key := z.LookupPubKey(signedRRset.RRSig.KeyTag) key := z.LookupPubKey(signedRRset.RRSig.KeyTag)
if key == nil { if key == nil {
log.Printf("DNSKEY keytag %d not found in zone %s", signedRRset.RRSig.KeyTag, z.Zone) logger.Debug("DNSKEY keytag %d not found in zone %s", signedRRset.RRSig.KeyTag, z.Zone)
return ErrDnskeyNotAvailable return ErrDnskeyNotAvailable
} }
@@ -69,36 +69,36 @@ func (z *SignedZone) VerifyRRSIG(signedRRset *RRSet) error {
} }
func (z *SignedZone) VerifyDS(dsRRset []dns.RR) error { func (z *SignedZone) VerifyDS(dsRRset []dns.RR) error {
log.Printf("Verifying DS for zone %s with %d DS records", z.Zone, len(dsRRset)) logger.Debug("Verifying DS for zone %s with %d DS records", z.Zone, len(dsRRset))
for _, rr := range dsRRset { for _, rr := range dsRRset {
ds, ok := rr.(*dns.DS) ds, ok := rr.(*dns.DS)
if !ok { if !ok {
continue continue
} }
log.Printf("Checking DS keytag %d, digestType %d", ds.KeyTag, ds.DigestType) logger.Debug("Checking DS keytag %d, digestType %d", ds.KeyTag, ds.DigestType)
if ds.DigestType != dns.SHA256 { if ds.DigestType != dns.SHA256 {
log.Printf("Unknown digest type (%d) on DS RR", ds.DigestType) logger.Debug("Unknown digest type (%d) on DS RR", ds.DigestType)
continue continue
} }
parentDsDigest := strings.ToUpper(ds.Digest) parentDsDigest := strings.ToUpper(ds.Digest)
key := z.LookupPubKey(ds.KeyTag) key := z.LookupPubKey(ds.KeyTag)
if key == nil { if key == nil {
log.Printf("DNSKEY keytag %d not found in zone %s", ds.KeyTag, z.Zone) logger.Debug("DNSKEY keytag %d not found in zone %s", ds.KeyTag, z.Zone)
return ErrDnskeyNotAvailable return ErrDnskeyNotAvailable
} }
dsDigest := strings.ToUpper(key.ToDS(ds.DigestType).Digest) dsDigest := strings.ToUpper(key.ToDS(ds.DigestType).Digest)
log.Printf("Parent DS digest: %s, Computed digest: %s", parentDsDigest, dsDigest) logger.Debug("Parent DS digest: %s, Computed digest: %s", parentDsDigest, dsDigest)
if parentDsDigest == dsDigest { if parentDsDigest == dsDigest {
log.Printf("DS validation successful for keytag %d", ds.KeyTag) logger.Debug("DS validation successful for keytag %d", ds.KeyTag)
return nil return nil
} }
log.Printf("DS does not match DNSKEY for keytag %d", ds.KeyTag) logger.Debug("DS does not match DNSKEY for keytag %d", ds.KeyTag)
} }
log.Printf("No matching DS found") logger.Debug("No matching DS found")
return ErrDsInvalid return ErrDsInvalid
} }

View File

@@ -17,10 +17,10 @@ package dnssec
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// ./common/dnssec/validator.go
import ( import (
"log" "github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -35,7 +35,10 @@ func NewValidator(queryFunc func(string, uint16) (*dns.Msg, error)) *Validator {
} }
func (v *Validator) ValidateResponse(msg *dns.Msg, qname string, qtype uint16) error { func (v *Validator) ValidateResponse(msg *dns.Msg, qname string, qtype uint16) error {
logger.Debug("Starting DNSSEC validation for %s %s", qname, dns.TypeToString[qtype])
if msg == nil || len(msg.Answer) == 0 { if msg == nil || len(msg.Answer) == 0 {
logger.Debug("No result for %s %s", qname, dns.TypeToString[qtype])
return ErrNoResult return ErrNoResult
} }
@@ -46,41 +49,48 @@ func (v *Validator) ValidateResponse(msg *dns.Msg, qname string, qtype uint16) e
case *dns.RRSIG: case *dns.RRSIG:
if t.TypeCovered == qtype { if t.TypeCovered == qtype {
rrset.RRSig = t rrset.RRSig = t
logger.Debug("Found RRSIG for %s %s (keytag: %d)", qname, dns.TypeToString[qtype], t.KeyTag)
} }
default: default:
if rr.Header().Rrtype == qtype { if rr.Header().Rrtype == qtype {
rrset.RRs = append(rrset.RRs, rr) rrset.RRs = append(rrset.RRs, rr)
logger.Debug("Found RR for %s %s: %s", qname, dns.TypeToString[qtype], rr.String())
} }
} }
} }
if rrset.IsEmpty() { if rrset.IsEmpty() {
logger.Debug("Empty RRSet for %s %s", qname, dns.TypeToString[qtype])
return ErrNoResult return ErrNoResult
} }
if !rrset.IsSigned() { if !rrset.IsSigned() {
logger.Debug("RRSet for %s %s is not signed", qname, dns.TypeToString[qtype])
return ErrResourceNotSigned return ErrResourceNotSigned
} }
// Check header integrity // Check header integrity
if err := rrset.CheckHeaderIntegrity(qname); err != nil { if err := rrset.CheckHeaderIntegrity(qname); err != nil {
logger.Debug("Header integrity check failed for %s %s: %v", qname, dns.TypeToString[qtype], err)
return err return err
} }
// Build and verify authentication chain // Build and verify authentication chain
signerName := rrset.SignerName() signerName := rrset.SignerName()
logger.Debug("Building authentication chain for signer: %s", signerName)
authChain := NewAuthenticationChain() authChain := NewAuthenticationChain()
if err := authChain.Populate(signerName, v.queryFunc); err != nil { if err := authChain.Populate(signerName, v.queryFunc); err != nil {
log.Printf("Cannot populate authentication chain: %s", err) logger.Debug("Cannot populate authentication chain for %s: %v", signerName, err)
return err return err
} }
if err := authChain.Verify(rrset); err != nil { if err := authChain.Verify(rrset); err != nil {
log.Printf("DNSSEC validation failed: %s", err) logger.Debug("DNSSEC validation failed for %s %s: %v", qname, dns.TypeToString[qtype], err)
return err return err
} }
logger.Debug("DNSSEC validation successful for %s %s", qname, dns.TypeToString[qtype])
return nil return nil
} }

41
common/logger/logger.go Normal file
View File

@@ -0,0 +1,41 @@
package logger
import (
"log"
"os"
)
var (
debugEnabled bool
infoLogger *log.Logger
debugLogger *log.Logger
errorLogger *log.Logger
)
func init() {
infoLogger = log.New(os.Stderr, "[INFO] ", log.Ltime|log.Lshortfile)
debugLogger = log.New(os.Stderr, "[DEBUG] ", log.Ltime|log.Lshortfile)
errorLogger = log.New(os.Stderr, "[ERROR] ", log.Ltime|log.Lshortfile)
}
func SetDebug(enabled bool) {
debugEnabled = enabled
}
func IsDebugEnabled() bool {
return debugEnabled
}
func Info(format string, args ...any) {
infoLogger.Printf(format, args...)
}
func Debug(format string, args ...any) {
if debugEnabled {
debugLogger.Printf(format, args...)
}
}
func Error(format string, args ...any) {
errorLogger.Printf(format, args...)
}

View File

@@ -5,6 +5,7 @@ import (
"net" "net"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -21,7 +22,10 @@ type Client struct {
} }
func New(config Config) (*Client, error) { func New(config Config) (*Client, error) {
logger.Debug("Creating DO53 client: %s", config.HostAndPort)
if config.HostAndPort == "" { if config.HostAndPort == "" {
logger.Error("DO53 client creation failed: empty HostAndPort")
return nil, fmt.Errorf("do53: HostAndPort cannot be empty") return nil, fmt.Errorf("do53: HostAndPort cannot be empty")
} }
if config.WriteTimeout <= 0 { if config.WriteTimeout <= 0 {
@@ -31,6 +35,8 @@ func New(config Config) (*Client, error) {
config.ReadTimeout = 5 * time.Second config.ReadTimeout = 5 * time.Second
} }
logger.Debug("DO53 client created: %s (DNSSEC: %v)", config.HostAndPort, config.DNSSEC)
return &Client{ return &Client{
hostAndPort: config.HostAndPort, hostAndPort: config.HostAndPort,
config: config, config: config,
@@ -38,18 +44,32 @@ func New(config Config) (*Client, error) {
} }
func (c *Client) Close() { func (c *Client) Close() {
logger.Debug("Closing DO53 client")
} }
func (c *Client) createConnection() (*net.UDPConn, error) { func (c *Client) createConnection() (*net.UDPConn, error) {
udpAddr, err := net.ResolveUDPAddr("udp", c.hostAndPort) udpAddr, err := net.ResolveUDPAddr("udp", c.hostAndPort)
if err != nil { if err != nil {
logger.Error("DO53 failed to resolve address %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("failed to resolve UDP address: %w", err) return nil, fmt.Errorf("failed to resolve UDP address: %w", err)
} }
return net.DialUDP("udp", nil, udpAddr) conn, err := net.DialUDP("udp", nil, udpAddr)
if err != nil {
logger.Error("DO53 failed to connect to %s: %v", c.hostAndPort, err)
return nil, err
}
logger.Debug("DO53 connection established to %s", c.hostAndPort)
return conn, nil
} }
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
if len(msg.Question) > 0 {
question := msg.Question[0]
logger.Debug("DO53 query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
}
// Create connection for this query // Create connection for this query
conn, err := c.createConnection() conn, err := c.createConnection()
if err != nil { if err != nil {
@@ -62,36 +82,45 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
} }
packedMsg, err := msg.Pack() packedMsg, err := msg.Pack()
if err != nil { if err != nil {
logger.Error("DO53 failed to pack message: %v", err)
return nil, fmt.Errorf("do53: failed to pack DNS message: %w", err) return nil, fmt.Errorf("do53: failed to pack DNS message: %w", err)
} }
// Send query // Send query
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
logger.Error("DO53 failed to set write deadline: %v", err)
return nil, fmt.Errorf("do53: failed to set write deadline: %w", err) return nil, fmt.Errorf("do53: failed to set write deadline: %w", err)
} }
if _, err := conn.Write(packedMsg); err != nil { if _, err := conn.Write(packedMsg); err != nil {
logger.Error("DO53 failed to send query to %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("do53: failed to send DNS query: %w", err) return nil, fmt.Errorf("do53: failed to send DNS query: %w", err)
} }
// Read response // Read response
if err := conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil {
logger.Error("DO53 failed to set read deadline: %v", err)
return nil, fmt.Errorf("do53: failed to set read deadline: %w", err) return nil, fmt.Errorf("do53: failed to set read deadline: %w", err)
} }
buffer := make([]byte, dns.MaxMsgSize) buffer := make([]byte, dns.MaxMsgSize)
n, err := conn.Read(buffer) n, err := conn.Read(buffer)
if err != nil { if err != nil {
logger.Error("DO53 failed to read response from %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("do53: failed to read DNS response: %w", err) return nil, fmt.Errorf("do53: failed to read DNS response: %w", err)
} }
// Parse response // Parse response
response := new(dns.Msg) response := new(dns.Msg)
if err := response.Unpack(buffer[:n]); err != nil { if err := response.Unpack(buffer[:n]); err != nil {
logger.Error("DO53 failed to unpack response from %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("do53: failed to unpack DNS response: %w", err) return nil, fmt.Errorf("do53: failed to unpack DNS response: %w", err)
} }
if len(response.Answer) > 0 {
logger.Debug("DO53 response from %s: %d answers", c.hostAndPort, len(response.Answer))
}
return response, nil return response, nil
} }

View File

@@ -12,6 +12,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3" "github.com/quic-go/quic-go/http3"
@@ -36,8 +37,10 @@ type Client struct {
} }
func New(config Config) (*Client, error) { func New(config Config) (*Client, error) {
logger.Debug("Creating DoH client: %s:%s%s", config.Host, config.Port, config.Path)
if config.Host == "" || config.Port == "" || config.Path == "" { if config.Host == "" || config.Port == "" || config.Path == "" {
fmt.Printf("%v,%v,%v", config.Host, config.Port, config.Path) logger.Error("DoH client creation failed: missing required fields")
return nil, errors.New("doh: host, port, and path must not be empty") return nil, errors.New("doh: host, port, and path must not be empty")
} }
@@ -48,6 +51,7 @@ func New(config Config) (*Client, error) {
parsedURL, err := url.Parse(rawURL) parsedURL, err := url.Parse(rawURL)
if err != nil { if err != nil {
logger.Error("Failed to parse DoH URL %s: %v", rawURL, err)
return nil, fmt.Errorf("doh: failed to parse constructed URL %q: %w", rawURL, err) return nil, fmt.Errorf("doh: failed to parse constructed URL %q: %w", rawURL, err)
} }
@@ -67,21 +71,26 @@ func New(config Config) (*Client, error) {
Transport: transport, Transport: transport,
} }
var transportType string
if config.HTTP2 { if config.HTTP2 {
httpClient.Transport = &http2.Transport{ httpClient.Transport = &http2.Transport{
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
AllowHTTP: true, AllowHTTP: true,
} }
} transportType = "HTTP/2"
} else if config.HTTP3 {
if config.HTTP3 {
quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig) quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig)
httpClient.Transport = &http3.Transport{ httpClient.Transport = &http3.Transport{
TLSClientConfig: quicTlsConfig, TLSClientConfig: quicTlsConfig,
QUICConfig: quicConfig, QUICConfig: quicConfig,
} }
transportType = "HTTP/3"
} else {
transportType = "HTTP/1.1"
} }
logger.Debug("DoH client created: %s (%s, DNSSEC: %v)", rawURL, transportType, config.DNSSEC)
return &Client{ return &Client{
httpClient: httpClient, httpClient: httpClient,
upstreamURL: parsedURL, upstreamURL: parsedURL,
@@ -90,6 +99,7 @@ func New(config Config) (*Client, error) {
} }
func (c *Client) Close() { func (c *Client) Close() {
logger.Debug("Closing DoH client")
if t, ok := c.httpClient.Transport.(*http.Transport); ok { if t, ok := c.httpClient.Transport.(*http.Transport); ok {
t.CloseIdleConnections() t.CloseIdleConnections()
} else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok { } else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok {
@@ -98,16 +108,24 @@ func (c *Client) Close() {
} }
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
if len(msg.Question) > 0 {
question := msg.Question[0]
logger.Debug("DoH query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.upstreamURL.Host)
}
if c.config.DNSSEC { if c.config.DNSSEC {
msg.SetEdns0(4096, true) msg.SetEdns0(4096, true)
} }
packedMsg, err := msg.Pack() packedMsg, err := msg.Pack()
if err != nil { if err != nil {
logger.Error("DoH failed to pack DNS message: %v", err)
return nil, fmt.Errorf("doh: failed to pack DNS message: %w", err) return nil, fmt.Errorf("doh: failed to pack DNS message: %w", err)
} }
httpReq, err := http.NewRequest(http.MethodPost, c.upstreamURL.String(), bytes.NewReader(packedMsg)) httpReq, err := http.NewRequest(http.MethodPost, c.upstreamURL.String(), bytes.NewReader(packedMsg))
if err != nil { if err != nil {
logger.Error("DoH failed to create HTTP request: %v", err)
return nil, fmt.Errorf("doh: failed to create HTTP request object: %w", err) return nil, fmt.Errorf("doh: failed to create HTTP request object: %w", err)
} }
@@ -117,29 +135,37 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
httpResp, err := c.httpClient.Do(httpReq) httpResp, err := c.httpClient.Do(httpReq)
if err != nil { if err != nil {
logger.Error("DoH request failed to %s: %v", c.upstreamURL.Host, err)
return nil, fmt.Errorf("doh: failed executing HTTP request to %s: %w", c.upstreamURL.Host, err) return nil, fmt.Errorf("doh: failed executing HTTP request to %s: %w", c.upstreamURL.Host, err)
} }
defer httpResp.Body.Close() defer httpResp.Body.Close()
if httpResp.StatusCode != http.StatusOK { if httpResp.StatusCode != http.StatusOK {
logger.Error("DoH received non-200 status from %s: %s", c.upstreamURL.Host, httpResp.Status)
return nil, fmt.Errorf("doh: received non-200 HTTP status from %s: %s", c.upstreamURL.Host, httpResp.Status) return nil, fmt.Errorf("doh: received non-200 HTTP status from %s: %s", c.upstreamURL.Host, httpResp.Status)
} }
if ct := httpResp.Header.Get("Content-Type"); ct != dnsMessageContentType { if ct := httpResp.Header.Get("Content-Type"); ct != dnsMessageContentType {
logger.Error("DoH unexpected Content-Type from %s: %s", c.upstreamURL.Host, ct)
return nil, fmt.Errorf("doh: unexpected Content-Type from %s: got %q, want %q", c.upstreamURL.Host, ct, dnsMessageContentType) return nil, fmt.Errorf("doh: unexpected Content-Type from %s: got %q, want %q", c.upstreamURL.Host, ct, dnsMessageContentType)
} }
responseBody, err := io.ReadAll(httpResp.Body) responseBody, err := io.ReadAll(httpResp.Body)
if err != nil { if err != nil {
logger.Error("DoH failed reading response from %s: %v", c.upstreamURL.Host, err)
return nil, fmt.Errorf("doh: failed reading response body from %s: %w", c.upstreamURL.Host, err) return nil, fmt.Errorf("doh: failed reading response body from %s: %w", c.upstreamURL.Host, err)
} }
// Unpack the DNS message
recvMsg := new(dns.Msg) recvMsg := new(dns.Msg)
err = recvMsg.Unpack(responseBody) err = recvMsg.Unpack(responseBody)
if err != nil { if err != nil {
logger.Error("DoH failed to unpack response from %s: %v", c.upstreamURL.Host, err)
return nil, fmt.Errorf("doh: failed to unpack DNS response from %s: %w", c.upstreamURL.Host, err) return nil, fmt.Errorf("doh: failed to unpack DNS response from %s: %w", c.upstreamURL.Host, err)
} }
if len(recvMsg.Answer) > 0 {
logger.Debug("DoH response from %s: %d answers", c.upstreamURL.Host, len(recvMsg.Answer))
}
return recvMsg, nil return recvMsg, nil
} }

View File

@@ -10,6 +10,7 @@ import (
"net" "net"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
) )
@@ -32,6 +33,7 @@ type Client struct {
} }
func New(config Config) (*Client, error) { func New(config Config) (*Client, error) {
logger.Debug("Creating DoQ client: %s:%s", config.Host, config.Port)
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
ServerName: config.Host, ServerName: config.Host,
@@ -42,10 +44,13 @@ func New(config Config) (*Client, error) {
targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(config.Host, config.Port)) targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(config.Host, config.Port))
if err != nil { if err != nil {
logger.Error("DoQ failed to resolve address %s:%s: %v", config.Host, config.Port, err)
return nil, err return nil, err
} }
udpConn, err := net.ListenUDP("udp", nil) udpConn, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
logger.Error("DoQ failed to create UDP connection: %v", err)
return nil, fmt.Errorf("failed to connect to target address: %w", err) return nil, fmt.Errorf("failed to connect to target address: %w", err)
} }
@@ -57,6 +62,8 @@ func New(config Config) (*Client, error) {
MaxIdleTimeout: 30 * time.Second, MaxIdleTimeout: 30 * time.Second,
} }
logger.Debug("DoQ client created: %s:%s (DNSSEC: %v)", config.Host, config.Port, config.DNSSEC)
return &Client{ return &Client{
targetAddr: targetAddr, targetAddr: targetAddr,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
@@ -69,22 +76,30 @@ func New(config Config) (*Client, error) {
} }
func (c *Client) Close() { func (c *Client) Close() {
logger.Debug("Closing DoQ client")
if c.udpConn != nil { if c.udpConn != nil {
c.udpConn.Close() c.udpConn.Close()
} }
} }
func (c *Client) OpenConnection() error { func (c *Client) OpenConnection() error {
logger.Debug("Opening DoQ connection to %s", c.targetAddr)
quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig) quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig)
if err != nil { if err != nil {
logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err)
return err return err
} }
c.quicConn = quicConn c.quicConn = quicConn
logger.Debug("DoQ connection established to %s", c.targetAddr)
return nil return nil
} }
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
if len(msg.Question) > 0 {
question := msg.Question[0]
logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr)
}
if c.quicConn == nil { if c.quicConn == nil {
err := c.OpenConnection() err := c.OpenConnection()
@@ -100,18 +115,21 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
} }
packed, err := msg.Pack() packed, err := msg.Pack()
if err != nil { if err != nil {
logger.Error("DoQ failed to pack message: %v", err)
return nil, fmt.Errorf("doq: failed to pack message: %w", err) return nil, fmt.Errorf("doq: failed to pack message: %w", err)
} }
var quicStream quic.Stream var quicStream quic.Stream
quicStream, err = c.quicConn.OpenStream() quicStream, err = c.quicConn.OpenStream()
if err != nil { if err != nil {
logger.Debug("DoQ stream failed, reconnecting: %v", err)
err = c.OpenConnection() err = c.OpenConnection()
if err != nil { if err != nil {
return nil, err return nil, err
} }
quicStream, err = c.quicConn.OpenStream() quicStream, err = c.quicConn.OpenStream()
if err != nil { if err != nil {
logger.Error("DoQ failed to open stream after reconnect: %v", err)
return nil, err return nil, err
} }
} }
@@ -119,42 +137,52 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
var lengthPrefixedMessage bytes.Buffer var lengthPrefixedMessage bytes.Buffer
err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(packed))) err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(packed)))
if err != nil { if err != nil {
logger.Error("DoQ failed to write message length: %v", err)
return nil, fmt.Errorf("failed to write message length: %w", err) return nil, fmt.Errorf("failed to write message length: %w", err)
} }
_, err = lengthPrefixedMessage.Write(packed) _, err = lengthPrefixedMessage.Write(packed)
if err != nil { if err != nil {
logger.Error("DoQ failed to write DNS message: %v", err)
return nil, fmt.Errorf("failed to write DNS message: %w", err) return nil, fmt.Errorf("failed to write DNS message: %w", err)
} }
_, err = quicStream.Write(lengthPrefixedMessage.Bytes()) _, err = quicStream.Write(lengthPrefixedMessage.Bytes())
if err != nil { if err != nil {
logger.Error("DoQ failed to write to stream: %v", err)
return nil, fmt.Errorf("failed writing to QUIC stream: %w", err) return nil, fmt.Errorf("failed writing to QUIC stream: %w", err)
} }
// Indicate that no further data will be written from this side
quicStream.Close() quicStream.Close()
lengthBuf := make([]byte, 2) lengthBuf := make([]byte, 2)
_, err = io.ReadFull(quicStream, lengthBuf) _, err = io.ReadFull(quicStream, lengthBuf)
if err != nil { if err != nil {
logger.Error("DoQ failed to read response length: %v", err)
return nil, fmt.Errorf("failed reading response length: %w", err) return nil, fmt.Errorf("failed reading response length: %w", err)
} }
messageLength := binary.BigEndian.Uint16(lengthBuf) messageLength := binary.BigEndian.Uint16(lengthBuf)
if messageLength == 0 { if messageLength == 0 {
logger.Error("DoQ received zero-length message")
return nil, fmt.Errorf("received zero-length message") return nil, fmt.Errorf("received zero-length message")
} }
responseBuf := make([]byte, messageLength) responseBuf := make([]byte, messageLength)
_, err = io.ReadFull(quicStream, responseBuf) _, err = io.ReadFull(quicStream, responseBuf)
if err != nil { if err != nil {
logger.Error("DoQ failed to read response data: %v", err)
return nil, fmt.Errorf("failed reading response data: %w", err) return nil, fmt.Errorf("failed reading response data: %w", err)
} }
recvMsg := new(dns.Msg) recvMsg := new(dns.Msg)
err = recvMsg.Unpack(responseBuf) err = recvMsg.Unpack(responseBuf)
if err != nil { if err != nil {
logger.Error("DoQ failed to parse response: %v", err)
return nil, fmt.Errorf("failed to parse DNS response: %w", err) return nil, fmt.Errorf("failed to parse DNS response: %w", err)
} }
if len(recvMsg.Answer) > 0 {
logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer))
}
return recvMsg, nil return recvMsg, nil
} }

View File

@@ -8,6 +8,7 @@ import (
"net" "net"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -27,7 +28,10 @@ type Client struct {
} }
func New(config Config) (*Client, error) { func New(config Config) (*Client, error) {
logger.Debug("Creating DoT client: %s:%s", config.Host, config.Port)
if config.Host == "" { if config.Host == "" {
logger.Error("DoT client creation failed: empty host")
return nil, fmt.Errorf("dot: Host cannot be empty") return nil, fmt.Errorf("dot: Host cannot be empty")
} }
if config.WriteTimeout <= 0 { if config.WriteTimeout <= 0 {
@@ -43,6 +47,8 @@ func New(config Config) (*Client, error) {
ServerName: config.Host, ServerName: config.Host,
} }
logger.Debug("DoT client created: %s (DNSSEC: %v)", hostAndPort, config.DNSSEC)
return &Client{ return &Client{
hostAndPort: hostAndPort, hostAndPort: hostAndPort,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
@@ -51,6 +57,7 @@ func New(config Config) (*Client, error) {
} }
func (c *Client) Close() { func (c *Client) Close() {
logger.Debug("Closing DoT client")
} }
func (c *Client) createConnection() (*tls.Conn, error) { func (c *Client) createConnection() (*tls.Conn, error) {
@@ -58,10 +65,23 @@ func (c *Client) createConnection() (*tls.Conn, error) {
Timeout: c.config.WriteTimeout, Timeout: c.config.WriteTimeout,
} }
return tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig) logger.Debug("Establishing DoT connection to %s", c.hostAndPort)
conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig)
if err != nil {
logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err)
return nil, err
}
logger.Debug("DoT connection established to %s", c.hostAndPort)
return conn, nil
} }
func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
if len(msg.Question) > 0 {
question := msg.Question[0]
logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort)
}
// Create connection for this query // Create connection for this query
conn, err := c.createConnection() conn, err := c.createConnection()
if err != nil { if err != nil {
@@ -75,6 +95,7 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
} }
packed, err := msg.Pack() packed, err := msg.Pack()
if err != nil { if err != nil {
logger.Error("DoT failed to pack message: %v", err)
return nil, fmt.Errorf("dot: failed to pack message: %w", err) return nil, fmt.Errorf("dot: failed to pack message: %w", err)
} }
@@ -85,40 +106,51 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) {
// Write query // Write query
if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil {
logger.Error("DoT failed to set write deadline: %v", err)
return nil, fmt.Errorf("dot: failed to set write deadline: %w", err) return nil, fmt.Errorf("dot: failed to set write deadline: %w", err)
} }
if _, err := conn.Write(data); err != nil { if _, err := conn.Write(data); err != nil {
logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("dot: failed to write message: %w", err) return nil, fmt.Errorf("dot: failed to write message: %w", err)
} }
// Read response // Read response
if err := conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil {
logger.Error("DoT failed to set read deadline: %v", err)
return nil, fmt.Errorf("dot: failed to set read deadline: %w", err) return nil, fmt.Errorf("dot: failed to set read deadline: %w", err)
} }
// Read message length // Read message length
lengthBuf := make([]byte, 2) lengthBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, lengthBuf); err != nil { if _, err := io.ReadFull(conn, lengthBuf); err != nil {
logger.Error("DoT failed to read response length from %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("dot: failed to read response length: %w", err) return nil, fmt.Errorf("dot: failed to read response length: %w", err)
} }
msgLen := binary.BigEndian.Uint16(lengthBuf) msgLen := binary.BigEndian.Uint16(lengthBuf)
if msgLen > dns.MaxMsgSize { if msgLen > dns.MaxMsgSize {
logger.Error("DoT response too large from %s: %d bytes", c.hostAndPort, msgLen)
return nil, fmt.Errorf("dot: response message too large: %d", msgLen) return nil, fmt.Errorf("dot: response message too large: %d", msgLen)
} }
// Read message body // Read message body
buffer := make([]byte, msgLen) buffer := make([]byte, msgLen)
if _, err := io.ReadFull(conn, buffer); err != nil { if _, err := io.ReadFull(conn, buffer); err != nil {
logger.Error("DoT failed to read response from %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("dot: failed to read response: %w", err) return nil, fmt.Errorf("dot: failed to read response: %w", err)
} }
// Parse response // Parse response
response := new(dns.Msg) response := new(dns.Msg)
if err := response.Unpack(buffer); err != nil { if err := response.Unpack(buffer); err != nil {
logger.Error("DoT failed to unpack response from %s: %v", c.hostAndPort, err)
return nil, fmt.Errorf("dot: failed to unpack response: %w", err) return nil, fmt.Errorf("dot: failed to unpack response: %w", err)
} }
if len(response.Answer) > 0 {
logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer))
}
return response, nil return response, nil
} }

View File

@@ -1,195 +0,0 @@
// internal/client/client.go
package client
import (
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
"time"
"github.com/afonsofrancof/sdns-perf/internal/protocols/do53"
"github.com/afonsofrancof/sdns-perf/internal/protocols/doh"
// "github.com/afonsofrancof/sdns-perf/internal/protocols/doq"
// "github.com/afonsofrancof/sdns-perf/internal/protocols/dot"
"github.com/miekg/dns"
)
// DNSClient defines the interface that all specific protocol clients must implement.
type DNSClient interface {
Query(domain string, queryType uint16) (*dns.Msg, error)
Close()
}
// Options holds common configuration options for creating any DNS client.
type Options struct {
Timeout time.Duration
DNSSEC bool
KeyLogPath string // Path for TLS key logging
}
type protocolType int
const (
protoUnknown protocolType = iota
protoDo53
protoDoT
protoDoH
protoDoH3
protoDoQ
)
// config holds the parsed details of an upstream server string.
// This is internal to the client package.
type config struct {
original string
protocol protocolType
host string
port string
path string
}
// parseUpstream takes a user-provided upstream string and attempts to determine
// the protocol, host, port, and path. (Internal helper)
func parseUpstream(upstreamStr string) (config, error) {
cfg := config{original: upstreamStr, protocol: protoUnknown}
// Try parsing as a full URL first
parsedURL, err := url.Parse(upstreamStr)
if err == nil && parsedURL.Scheme != "" && parsedURL.Host != "" {
cfg.host = parsedURL.Hostname()
cfg.port = parsedURL.Port()
cfg.path = parsedURL.Path
if cfg.path == "" {
cfg.path = "/" // Default path
}
switch strings.ToLower(parsedURL.Scheme) {
case "https", "doh":
cfg.protocol = protoDoH
if cfg.port == "" {
cfg.port = "443"
}
case "h3", "doh3":
cfg.protocol = protoDoH3
if cfg.port == "" {
cfg.port = "443"
}
case "tls", "dot":
cfg.protocol = protoDoT
if cfg.port == "" {
cfg.port = "853"
}
case "quic", "doq":
cfg.protocol = protoDoQ
if cfg.port == "" {
cfg.port = "853"
}
case "udp", "do53":
cfg.protocol = protoDo53
if cfg.port == "" {
cfg.port = "53"
}
default:
return cfg, fmt.Errorf("unsupported URL scheme: %q", parsedURL.Scheme)
}
return cfg, nil
}
// If not a valid URL or no scheme, assume plain DNS (Do53 UDP)
cfg.protocol = protoDo53
host, port, err := net.SplitHostPort(upstreamStr)
if err == nil {
cfg.host = host
cfg.port = port
if _, pErr := strconv.Atoi(port); pErr != nil {
return cfg, fmt.Errorf("invalid port %q in upstream %q: %w", port, upstreamStr, pErr)
}
} else {
cfg.host = upstreamStr
cfg.port = "53"
// Basic check for likely IPv6 without brackets and port
if strings.Contains(cfg.host, ":") && !strings.Contains(cfg.host, "[") {
_, resolveErr := net.ResolveUDPAddr("udp", net.JoinHostPort(cfg.host, cfg.port))
if resolveErr != nil {
return cfg, fmt.Errorf("invalid upstream format; could not parse %q as host:port or resolve as host with default port 53: %w", upstreamStr, err)
}
}
}
if cfg.host == "" {
return cfg, fmt.Errorf("could not extract host from upstream: %q", upstreamStr)
}
return cfg, nil
}
// New creates the appropriate DNS client based on the upstream string format.
// It returns an uninitialized client (connections are lazy).
func New(upstreamStr string, opts Options) (DNSClient, error) {
cfg, err := parseUpstream(upstreamStr)
if err != nil {
return nil, fmt.Errorf("client: failed to parse upstream %q: %w", upstreamStr, err)
}
var client DNSClient
var clientErr error
switch cfg.protocol {
case protoDo53:
// Ensure do53.New matches this signature
config := do53.Config{HostAndPort: net.JoinHostPort(cfg.host, cfg.port), DNSSEC: false}
client, clientErr = do53.New(config)
case protoDoH:
// Ensure doh.New matches this signature
config := doh.Config{Host: cfg.host, Port: cfg.port, Path: cfg.path, DNSSEC: false}
client, clientErr = doh.New(config)
case protoDoT:
// Ensure dot.New matches this signature
// client, clientErr = dot.New(cfg.hostPort(), opts.Timeout, opts.DNSSEC, opts.KeyLogPath)
// if clientErr == nil && client == nil {
// clientErr = fmt.Errorf("client: DoT package returned nil client without error")
// }
case protoDoQ:
// Ensure doq.New matches this signature
// client, clientErr = doq.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath)
// if clientErr == nil && client == nil {
// clientErr = fmt.Errorf("client: DoQ package returned nil client without error")
// }
case protoDoH3:
// Decide on DoH3 handling (fallback or error)
// Fallback example:
// fmt.Fprintf(os.Stderr, "Warning: DoH3 protocol (h3://) detected for %s. Attempting connection using standard DoH (HTTPS).\n", cfg.original)
// client, clientErr = doh.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath)
// Error example:
// clientErr = fmt.Errorf("client: DoH3 protocol (h3://) is not yet supported")
default:
clientErr = fmt.Errorf("client: unknown or unsupported protocol detected for upstream: %s", upstreamStr)
}
if clientErr != nil {
return nil, fmt.Errorf("client: failed to create client for %s: %w", upstreamStr, clientErr)
}
if client == nil {
// Should be caught by clientErr checks above, but as a safeguard
return nil, fmt.Errorf("client: internal error - nil client returned for %s", upstreamStr)
}
return client, nil
}
// Helper function to close key log writer if needed (can be used by specific clients)
func CloseKeyLogWriter(w io.WriteCloser) error {
if w != nil {
return w.Close()
}
return nil
}

View File

@@ -1,3 +0,0 @@
package dnscrypt
// DNSCrypt resolver implementation

View File

@@ -1,3 +0,0 @@
package dnssec
// DNSSEC resolver implementation

View File

@@ -1,129 +0,0 @@
package do53
import (
"fmt"
"log"
"net"
"sync"
"github.com/miekg/dns"
)
type Config struct {
HostAndPort string
DNSSEC bool
}
type Client struct {
udpAddr *net.UDPAddr
conn *net.UDPConn
responseChannels map[uint16]chan *dns.Msg
responseMutex *sync.Mutex
config Config
}
func New(config Config) (*Client, error) {
udpAddr, err := net.ResolveUDPAddr("udp", config.HostAndPort)
if err != nil {
return nil, fmt.Errorf("do53: failed to resolve UDP address %q: %w", config.HostAndPort, err)
}
conn, err := net.DialUDP("udp", nil, udpAddr)
if err != nil {
return nil, fmt.Errorf("do53: failed to dial UDP connection to %s: %w", config.HostAndPort, err)
}
responseChannels := map[uint16]chan *dns.Msg{}
rcMutex := new(sync.Mutex)
client := &Client{
udpAddr: udpAddr,
conn: conn,
responseChannels: responseChannels,
responseMutex: rcMutex,
config: config,
}
go client.receiveLoop()
return client, nil
}
func (c *Client) Close() {
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
}
func (c *Client) receiveLoop() {
buffer := make([]byte, dns.MaxMsgSize)
for {
// Reads one UDP Datagram
n, err := c.conn.Read(buffer)
if err != nil {
log.Printf("do53: failed to read DNS response: %s", err.Error())
}
recvMsg := new(dns.Msg)
err = recvMsg.Unpack(buffer[:n])
if err != nil {
log.Printf("do53: failed to unpack DNS response: %s", err.Error())
continue
}
c.responseMutex.Lock()
respChan, ok := c.responseChannels[recvMsg.Id]
delete(c.responseChannels, recvMsg.Id)
c.responseMutex.Unlock()
if ok {
respChan <- recvMsg
} else {
log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id)
}
}
}
func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), queryType)
msg.Id = dns.Id()
msg.RecursionDesired = true
if c.config.DNSSEC {
msg.SetEdns0(4096, true)
}
respChan := make(chan *dns.Msg)
c.responseMutex.Lock()
c.responseChannels[msg.Id] = respChan
c.responseMutex.Unlock()
packedMsg, err := msg.Pack()
if err != nil {
c.responseMutex.Lock()
delete(c.responseChannels, msg.Id)
c.responseMutex.Unlock()
return nil, fmt.Errorf("do53: failed to pack DNS message: %w", err)
}
_, err = c.conn.Write(packedMsg)
if err != nil {
c.responseMutex.Lock()
delete(c.responseChannels, msg.Id)
c.responseMutex.Unlock()
return nil, fmt.Errorf("do53: failed to send DNS query: %w", err)
}
recvMsg := <-respChan
return recvMsg, nil
}

View File

@@ -1,40 +0,0 @@
package do53
import (
"github.com/miekg/dns"
)
func NewDNSMessage(domain string, queryType string) ([]byte, error) {
// TODO: Move this somewhere else and receive the type already parsed
var queryTypeValue uint16
switch queryType {
case "A":
queryTypeValue = dns.TypeA
case "AAAA":
queryTypeValue = dns.TypeAAAA
case "MX":
queryTypeValue = dns.TypeMX
case "CNAME":
queryTypeValue = dns.TypeCNAME
case "TXT":
queryTypeValue = dns.TypeTXT
default:
queryTypeValue = dns.TypeA
}
message := new(dns.Msg)
message.Id = dns.Id()
message.Response = false
message.Opcode = dns.OpcodeQuery
message.Question = make([]dns.Question, 1)
message.Question[0] = dns.Question{Name: domain, Qtype: uint16(queryTypeValue), Qclass: dns.ClassINET}
message.Compress = true
wireMsg, err := message.Pack()
if err != nil {
return nil, err
}
return wireMsg, nil
}

View File

@@ -1,125 +0,0 @@
package doh
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"github.com/miekg/dns"
)
const dnsMessageContentType = "application/dns-message"
type Config struct {
Host string
Port string
Path string
DNSSEC bool
}
type Client struct {
httpClient *http.Client
upstreamURL *url.URL
config Config
}
func New(config Config) (*Client, error) {
if config.Host == "" || config.Port == "" || config.Path == "" {
fmt.Printf("%v,%v,%v", config.Host,config.Port,config.Path)
return nil, errors.New("doh: host, port, and path must not be empty")
}
if !strings.HasPrefix(config.Path, "/") {
config.Path = "/" + config.Path
}
rawURL := "https://" + net.JoinHostPort(config.Host, config.Port) + config.Path
parsedURL, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("doh: failed to parse constructed URL %q: %w", rawURL, err)
}
tlsConfig := &tls.Config{
ServerName: config.Host,
MinVersion: tls.VersionTLS12,
}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
ForceAttemptHTTP2: true,
}
httpClient := &http.Client{
Transport: transport,
}
return &Client{
httpClient: httpClient,
upstreamURL: parsedURL,
config: config,
}, nil
}
// Close cleans up idle connections held by the underlying HTTP transport.
func (c *Client) Close() {
c.httpClient.CloseIdleConnections()
}
func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) {
msg := new(dns.Msg)
msg.SetQuestion(dns.Fqdn(domain), queryType)
msg.Id = dns.Id()
msg.RecursionDesired = true
if c.config.DNSSEC {
msg.SetEdns0(4096, true)
}
packedMsg, err := msg.Pack()
if err != nil {
return nil, fmt.Errorf("doh: failed to pack DNS message: %w", err)
}
httpReq, err := http.NewRequest("POST", c.upstreamURL.String(), bytes.NewReader(packedMsg))
if err != nil {
return nil, fmt.Errorf("doh: failed to create HTTP request object: %w", err)
}
httpReq.Header.Set("User-Agent", "sdns-perf")
httpReq.Header.Set("Content-Type", dnsMessageContentType)
httpReq.Header.Set("Accept", dnsMessageContentType)
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("doh: failed executing HTTP request to %s: %w", c.upstreamURL.Host, err)
}
defer httpResp.Body.Close()
if httpResp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("doh: received non-200 HTTP status from %s: %s", c.upstreamURL.Host, httpResp.Status)
}
if ct := httpResp.Header.Get("Content-Type"); ct != dnsMessageContentType {
return nil, fmt.Errorf("doh: unexpected Content-Type from %s: got %q, want %q", c.upstreamURL.Host, ct, dnsMessageContentType)
}
responseBody, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, fmt.Errorf("doh: failed reading response body from %s: %w", c.upstreamURL.Host, err)
}
// Unpack the DNS message
recvMsg := new(dns.Msg)
err = recvMsg.Unpack(responseBody)
if err != nil {
return nil, fmt.Errorf("doh: failed to unpack DNS response from %s: %w", c.upstreamURL.Host, err)
}
return recvMsg, nil
}

View File

@@ -1,175 +0,0 @@
package doq
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"time"
"github.com/afonsofrancof/sdns-perf/internal/protocols/do53"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
)
type Client struct {
targetAddr *net.UDPAddr
keyLogFile *os.File
tlsConfig *tls.Config
udpConn *net.UDPConn
quicConn quic.Connection
quicTransport *quic.Transport
quicConfig *quic.Config
}
func New(target string) (*Client, error) {
keyLogFile, err := os.OpenFile(
"tls-key-log.txt",
os.O_APPEND|os.O_CREATE|os.O_WRONLY,
0600,
)
if err != nil {
return nil, fmt.Errorf("failed opening key log file: %w", err)
}
tlsConfig := &tls.Config{
// FIX: Actually check the domain name
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS13,
ClientSessionCache: tls.NewLRUClientSessionCache(100),
KeyLogWriter: keyLogFile,
NextProtos: []string{"doq"},
}
udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:6000")
if err != nil {
return nil, fmt.Errorf("failed to resolve target address: %w", err)
}
targetAddr, err := net.ResolveUDPAddr("udp", target)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to target address: %w", err)
}
quicTransport := quic.Transport{
Conn: udpConn,
}
quicConfig := quic.Config{
// Use the default value of 30 seconds
MaxIdleTimeout: 30 * time.Second,
}
return &Client{
targetAddr: targetAddr,
keyLogFile: keyLogFile,
tlsConfig: tlsConfig,
udpConn: udpConn,
quicConn: nil,
quicTransport: &quicTransport,
quicConfig: &quicConfig,
}, nil
}
func (c *Client) Close() {
if c.keyLogFile != nil {
c.keyLogFile.Close()
}
if c.udpConn != nil {
c.udpConn.Close()
}
}
func (c *Client) OpenConnection() error {
quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig)
if err != nil {
return err
}
c.quicConn = quicConn
return nil
}
func (c *Client) Query(domain, queryType string, dnssec bool) error {
if c.quicConn == nil {
err := c.OpenConnection()
if err != nil {
return err
}
}
DNSMessage, err := do53.NewDNSMessage(domain, queryType)
if err != nil {
return err
}
var quicStream quic.Stream
quicStream, err = c.quicConn.OpenStream()
if err != nil {
err = c.OpenConnection()
if err != nil {
return err
}
quicStream, err = c.quicConn.OpenStream()
if err != nil {
return err
}
}
var lengthPrefixedMessage bytes.Buffer
err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage)))
if err != nil {
return fmt.Errorf("failed to write message length: %w", err)
}
_, err = lengthPrefixedMessage.Write(DNSMessage)
if err != nil {
return fmt.Errorf("failed to write DNS message: %w", err)
}
_, err = quicStream.Write(lengthPrefixedMessage.Bytes())
if err != nil {
return fmt.Errorf("failed writing to QUIC stream: %w", err)
}
// Indicate that no further data will be written from this side
quicStream.Close()
lengthBuf := make([]byte, 2)
_, err = io.ReadFull(quicStream, lengthBuf)
if err != nil {
return fmt.Errorf("failed reading response length: %w", err)
}
messageLength := binary.BigEndian.Uint16(lengthBuf)
if messageLength == 0 {
return fmt.Errorf("received zero-length message")
}
responseBuf := make([]byte, messageLength)
_, err = io.ReadFull(quicStream, responseBuf)
if err != nil {
return fmt.Errorf("failed reading response data: %w", err)
}
recvMsg := new(dns.Msg)
err = recvMsg.Unpack(responseBuf)
if err != nil {
return fmt.Errorf("failed to parse DNS response: %w", err)
}
// TODO: Check if the response had no errors or TD bit set
fmt.Println(c.quicConn.ConnectionState().Used0RTT)
for _, answer := range recvMsg.Answer {
fmt.Println(answer.String())
}
return nil
}

View File

@@ -1,161 +0,0 @@
package dot
import (
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"log"
"net"
"os"
"sync"
"time"
"github.com/miekg/dns"
)
type Config struct {
Host string
Port string
DNSSEC bool
Debug bool
}
type Client struct {
config Config
serverAddr *net.TCPAddr
tcpConn *net.TCPConn
tlsConn *tls.Conn
tlsConfig *tls.Config
keyLogFile *os.File
sendChannel chan *dns.Msg
responseChannels map[uint16]chan *dns.Msg
responseMutex *sync.Mutex
}
func New(config Config) (*Client, error) {
serverAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(config.Host, config.Port))
if err != nil {
return nil, fmt.Errorf("dot: failed to resolve TCP address %q: %w", config.Host, err)
}
var keyLogFile *os.File
if config.Debug {
keyLogFile, err = os.OpenFile(
"tls-key-log.txt",
os.O_APPEND|os.O_CREATE|os.O_WRONLY,
0600,
)
if err != nil {
log.Printf("dot: failed opening TLS key log file: %v", err)
keyLogFile = nil
}
}
tlsConfig := &tls.Config{
ServerName: serverAddr.IP.String(),
MinVersion: tls.VersionTLS12,
KeyLogWriter: keyLogFile,
ClientSessionCache: tls.NewLRUClientSessionCache(100),
}
client := &Client{
config: config,
serverAddr: serverAddr,
tlsConfig: tlsConfig,
keyLogFile: keyLogFile,
}
go client.receiveLoop()
return client, nil
}
func (c *Client) Close() {
if c.tlsConn != nil {
c.tlsConn.Close()
c.tlsConn = nil
}
if c.tcpConn != nil {
c.tcpConn.Close()
c.tcpConn = nil
}
if c.keyLogFile != nil {
c.keyLogFile.Close()
c.keyLogFile = nil
}
}
func (c *Client) receiveLoop() {
lengthBuffer := make([]byte, 2)
buffer := make([]byte, dns.MaxMsgSize)
for {
msgSize, err := io.ReadFull(c.tlsConn, lengthBuffer)
if err != nil {
log.Printf("doh: failed to read the DNS message's size: %s", err.Error())
// FIX: HANDLE RECONNECTION
}
n, err := io.ReadFull(c.tlsConn, buffer[:msgSize])
if err != nil {
log.Printf("doh: failed to read the DNS message: %s", err.Error())
// FIX: HANDLE RECONNECTION
}
recvMsg := new(dns.Msg)
err = recvMsg.Unpack(buffer[:n])
if err != nil {
log.Printf("do53: failed to unpack DNS response: %s", err.Error())
continue
}
c.responseMutex.Lock()
respChan, ok := c.responseChannels[recvMsg.Id]
delete(c.responseChannels, recvMsg.Id)
c.responseMutex.Unlock()
if ok {
respChan <- recvMsg
} else {
log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id)
}
}
}
func (c *Client) connect(ctx context.Context) error {
tcpConn, err := net.DialTCP("tcp", nil, c.serverAddr)
if err != nil {
return fmt.Errorf("dot: failed to establish TCP connection: %w", err)
}
c.tcpConn.SetKeepAlive(true)
c.tcpConn.SetKeepAlivePeriod(1 * time.Minute)
tlsConn := tls.Client(c.tcpConn, c.tlsConfig)
err = tlsConn.HandshakeContext(ctx)
if err != nil {
c.tcpConn.Close()
c.tcpConn = nil
return fmt.Errorf("dot: failed to execute the TLS handshake: %w", err)
}
c.tlsConn = tlsConn
log.Println("dot: TCP/TLS connection established successfully.")
return nil
}
func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) {
//TODO
}

30
main.go
View File

@@ -2,12 +2,11 @@ package main
import ( import (
"fmt" "fmt"
"log"
"os"
"strings" "strings"
"time" "time"
"github.com/afonsofrancof/sdns-proxy/client" "github.com/afonsofrancof/sdns-proxy/client"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/afonsofrancof/sdns-proxy/server" "github.com/afonsofrancof/sdns-proxy/server"
"github.com/alecthomas/kong" "github.com/alecthomas/kong"
@@ -15,6 +14,7 @@ import (
) )
var cli struct { var cli struct {
Debug bool `help:"Enable debug logging globally." short:"D" env:"DEBUG"`
Query QueryCmd `cmd:"" help:"Perform a DNS query (client mode)."` Query QueryCmd `cmd:"" help:"Perform a DNS query (client mode)."`
Listen ListenCmd `cmd:"" help:"Run as a DNS listener/resolver (server mode)."` Listen ListenCmd `cmd:"" help:"Run as a DNS listener/resolver (server mode)."`
} }
@@ -41,7 +41,7 @@ type ListenCmd struct {
} }
func (q *QueryCmd) Run() error { func (q *QueryCmd) Run() error {
log.Printf("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, Timeout: %v)\n", logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, Timeout: %v)",
q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.Timeout) q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.Timeout)
opts := client.Options{ opts := client.Options{
@@ -50,14 +50,17 @@ func (q *QueryCmd) Run() error {
StrictValidation: q.StrictValidation, StrictValidation: q.StrictValidation,
} }
logger.Debug("Creating DNS client with options: %+v", opts)
dnsClient, err := client.New(q.Server, opts) dnsClient, err := client.New(q.Server, opts)
if err != nil { if err != nil {
logger.Error("Failed to create DNS client: %v", err)
return err return err
} }
defer dnsClient.Close() defer dnsClient.Close()
qTypeUint, ok := dns.StringToType[strings.ToUpper(q.QueryType)] qTypeUint, ok := dns.StringToType[strings.ToUpper(q.QueryType)]
if !ok { if !ok {
logger.Error("Invalid query type: %s", q.QueryType)
return fmt.Errorf("invalid query type: %s", q.QueryType) return fmt.Errorf("invalid query type: %s", q.QueryType)
} }
@@ -67,11 +70,15 @@ func (q *QueryCmd) Run() error {
msg.Id = dns.Id() msg.Id = dns.Id()
msg.RecursionDesired = true msg.RecursionDesired = true
logger.Debug("Sending DNS query: ID=%d, Question=%s %s", msg.Id, q.DomainName, q.QueryType)
recvMsg, err := dnsClient.Query(msg) recvMsg, err := dnsClient.Query(msg)
if err != nil { if err != nil {
logger.Error("DNS query failed: %v", err)
return err return err
} }
logger.Debug("Received DNS response: ID=%d, Rcode=%s, Answers=%d",
recvMsg.Id, dns.RcodeToString[recvMsg.Rcode], len(recvMsg.Answer))
printResponse(recvMsg.Question[0].Name, q.QueryType, recvMsg) printResponse(recvMsg.Question[0].Name, q.QueryType, recvMsg)
return nil return nil
} }
@@ -87,15 +94,17 @@ func (l *ListenCmd) Run() error {
Verbose: l.Verbose, Verbose: l.Verbose,
} }
logger.Debug("Server config: %+v", config)
srv, err := server.New(config) srv, err := server.New(config)
if err != nil { if err != nil {
logger.Error("Failed to create server: %v", err)
return fmt.Errorf("failed to create server: %w", err) return fmt.Errorf("failed to create server: %w", err)
} }
log.Printf("Starting DNS proxy server on %s", l.Address) logger.Info("Starting DNS proxy server on %s", l.Address)
log.Printf("Upstream server: %v", l.Upstream) logger.Info("Upstream server: %v", l.Upstream)
log.Printf("Fallback server: %v", l.Fallback) logger.Info("Fallback server: %v", l.Fallback)
log.Printf("Bootstrap server: %v", l.Bootstrap) logger.Info("Bootstrap server: %v", l.Bootstrap)
return srv.Start() return srv.Start()
} }
@@ -153,9 +162,6 @@ func printResponse(domain, qtype string, msg *dns.Msg) {
} }
func main() { func main() {
log.SetOutput(os.Stderr)
log.SetFlags(log.Ltime | log.Lshortfile)
kongCtx := kong.Parse(&cli, kongCtx := kong.Parse(&cli,
kong.Name("sdns-proxy"), kong.Name("sdns-proxy"),
kong.Description("A DNS client/server tool supporting multiple protocols."), kong.Description("A DNS client/server tool supporting multiple protocols."),
@@ -163,6 +169,10 @@ func main() {
kong.ConfigureHelp(kong.HelpOptions{Compact: true, Summary: true}), kong.ConfigureHelp(kong.HelpOptions{Compact: true, Summary: true}),
) )
// Set global debug flag
logger.SetDebug(cli.Debug)
logger.Debug("Debug logging enabled")
err := kongCtx.Run() err := kongCtx.Run()
kongCtx.FatalIfErrorf(err) kongCtx.FatalIfErrorf(err)
} }

View File

@@ -3,7 +3,6 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"net" "net"
"net/url" "net/url"
"os" "os"
@@ -14,6 +13,7 @@ import (
"time" "time"
"github.com/afonsofrancof/sdns-proxy/client" "github.com/afonsofrancof/sdns-proxy/client"
"github.com/afonsofrancof/sdns-proxy/common/logger"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@@ -50,7 +50,10 @@ type Server struct {
} }
func New(config Config) (*Server, error) { func New(config Config) (*Server, error) {
logger.Debug("Creating new server with config: %+v", config)
if config.Upstream == "" { if config.Upstream == "" {
logger.Error("Upstream server is required")
return nil, fmt.Errorf("upstream server is required") return nil, fmt.Errorf("upstream server is required")
} }
@@ -60,11 +63,17 @@ func New(config Config) (*Server, error) {
needsBootstrap = needsBootstrap || containsHostname(config.Fallback) needsBootstrap = needsBootstrap || containsHostname(config.Fallback)
} }
logger.Debug("Bootstrap needed: %v (upstream has hostname: %v, fallback has hostname: %v)",
needsBootstrap, containsHostname(config.Upstream),
config.Fallback != "" && containsHostname(config.Fallback))
if needsBootstrap && config.Bootstrap == "" { if needsBootstrap && config.Bootstrap == "" {
logger.Error("Bootstrap server is required when upstream or fallback contains hostnames")
return nil, fmt.Errorf("bootstrap server is required when upstream or fallback contains hostnames") return nil, fmt.Errorf("bootstrap server is required when upstream or fallback contains hostnames")
} }
if config.Bootstrap != "" && containsHostname(config.Bootstrap) { if config.Bootstrap != "" && containsHostname(config.Bootstrap) {
logger.Error("Bootstrap server cannot contain hostnames: %s", config.Bootstrap)
return nil, fmt.Errorf("bootstrap server cannot contain hostnames: %s", config.Bootstrap) return nil, fmt.Errorf("bootstrap server cannot contain hostnames: %s", config.Bootstrap)
} }
@@ -76,17 +85,21 @@ func New(config Config) (*Server, error) {
// Create bootstrap client if needed // Create bootstrap client if needed
if config.Bootstrap != "" { if config.Bootstrap != "" {
logger.Debug("Creating bootstrap client for %s", config.Bootstrap)
bootstrapClient, err := client.New(config.Bootstrap, client.Options{ bootstrapClient, err := client.New(config.Bootstrap, client.Options{
DNSSEC: false, // For now DNSSEC: false,
}) })
if err != nil { if err != nil {
logger.Error("Failed to create bootstrap client: %v", err)
return nil, fmt.Errorf("failed to create bootstrap client: %w", err) return nil, fmt.Errorf("failed to create bootstrap client: %w", err)
} }
s.bootstrapClient = bootstrapClient s.bootstrapClient = bootstrapClient
logger.Debug("Bootstrap client created successfully")
} }
// Initialize upstream and fallback clients // Initialize upstream and fallback clients
if err := s.initClients(); err != nil { if err := s.initClients(); err != nil {
logger.Error("Failed to initialize clients: %v", err)
return nil, fmt.Errorf("failed to initialize clients: %w", err) return nil, fmt.Errorf("failed to initialize clients: %w", err)
} }
@@ -100,86 +113,111 @@ func New(config Config) (*Server, error) {
Handler: mux, Handler: mux,
} }
logger.Debug("Server created successfully, listening on %s", config.Address)
return s, nil return s, nil
} }
func containsHostname(serverAddr string) bool { func containsHostname(serverAddr string) bool {
logger.Debug("Checking if %s contains hostname", serverAddr)
// Use the same parsing logic as the client package // Use the same parsing logic as the client package
parsedURL, err := url.Parse(serverAddr) parsedURL, err := url.Parse(serverAddr)
if err != nil { if err != nil {
logger.Debug("URL parsing failed for %s, treating as plain address", serverAddr)
// If URL parsing fails, assume it's a plain address // If URL parsing fails, assume it's a plain address
host, _, err := net.SplitHostPort(serverAddr) host, _, err := net.SplitHostPort(serverAddr)
if err != nil { if err != nil {
// Assume it's just a host // Assume it's just a host
return net.ParseIP(serverAddr) == nil isHostname := net.ParseIP(serverAddr) == nil
logger.Debug("Address %s is hostname: %v", serverAddr, isHostname)
return isHostname
} }
return net.ParseIP(host) == nil isHostname := net.ParseIP(host) == nil
logger.Debug("Host %s from %s is hostname: %v", host, serverAddr, isHostname)
return isHostname
} }
host := parsedURL.Hostname() host := parsedURL.Hostname()
if host == "" { if host == "" {
logger.Debug("No hostname found in URL %s", serverAddr)
return false return false
} }
return net.ParseIP(host) == nil isHostname := net.ParseIP(host) == nil
logger.Debug("Host %s from URL %s is hostname: %v", host, serverAddr, isHostname)
return isHostname
} }
func (s *Server) initClients() error { func (s *Server) initClients() error {
logger.Debug("Initializing DNS clients")
// Initialize upstream client // Initialize upstream client
resolvedUpstream, err := s.resolveServerAddress(s.config.Upstream) resolvedUpstream, err := s.resolveServerAddress(s.config.Upstream)
if err != nil { if err != nil {
logger.Error("Failed to resolve upstream %s: %v", s.config.Upstream, err)
return fmt.Errorf("failed to resolve upstream %s: %w", s.config.Upstream, err) return fmt.Errorf("failed to resolve upstream %s: %w", s.config.Upstream, err)
} }
logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream)
upstreamClient, err := client.New(resolvedUpstream, client.Options{ upstreamClient, err := client.New(resolvedUpstream, client.Options{
DNSSEC: s.config.DNSSEC, DNSSEC: s.config.DNSSEC,
}) })
if err != nil { if err != nil {
logger.Error("Failed to create upstream client: %v", err)
return fmt.Errorf("failed to create upstream client: %w", err) return fmt.Errorf("failed to create upstream client: %w", err)
} }
s.upstreamClient = upstreamClient s.upstreamClient = upstreamClient
if s.config.Verbose { if s.config.Verbose {
log.Printf("Initialized upstream client: %s -> %s", s.config.Upstream, resolvedUpstream) logger.Info("Initialized upstream client: %s -> %s", s.config.Upstream, resolvedUpstream)
} }
// Initialize fallback client if specified // Initialize fallback client if specified
if s.config.Fallback != "" { if s.config.Fallback != "" {
resolvedFallback, err := s.resolveServerAddress(s.config.Fallback) resolvedFallback, err := s.resolveServerAddress(s.config.Fallback)
if err != nil { if err != nil {
logger.Error("Failed to resolve fallback %s: %v", s.config.Fallback, err)
return fmt.Errorf("failed to resolve fallback %s: %w", s.config.Fallback, err) return fmt.Errorf("failed to resolve fallback %s: %w", s.config.Fallback, err)
} }
logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback)
fallbackClient, err := client.New(resolvedFallback, client.Options{ fallbackClient, err := client.New(resolvedFallback, client.Options{
DNSSEC: s.config.DNSSEC, DNSSEC: s.config.DNSSEC,
}) })
if err != nil { if err != nil {
logger.Error("Failed to create fallback client: %v", err)
return fmt.Errorf("failed to create fallback client: %w", err) return fmt.Errorf("failed to create fallback client: %w", err)
} }
s.fallbackClient = fallbackClient s.fallbackClient = fallbackClient
if s.config.Verbose { if s.config.Verbose {
log.Printf("Initialized fallback client: %s -> %s", s.config.Fallback, resolvedFallback) logger.Info("Initialized fallback client: %s -> %s", s.config.Fallback, resolvedFallback)
} }
} }
logger.Debug("All DNS clients initialized successfully")
return nil return nil
} }
func (s *Server) resolveServerAddress(serverAddr string) (string, error) { func (s *Server) resolveServerAddress(serverAddr string) (string, error) {
logger.Debug("Resolving server address: %s", serverAddr)
// If it doesn't contain hostnames, return as-is // If it doesn't contain hostnames, return as-is
if !containsHostname(serverAddr) { if !containsHostname(serverAddr) {
logger.Debug("Address %s contains no hostnames, returning as-is", serverAddr)
return serverAddr, nil return serverAddr, nil
} }
// If no bootstrap client, we can't resolve hostnames // If no bootstrap client, we can't resolve hostnames
if s.bootstrapClient == nil { if s.bootstrapClient == nil {
logger.Error("Cannot resolve hostname in %s: no bootstrap server configured", serverAddr)
return "", fmt.Errorf("cannot resolve hostname in %s: no bootstrap server configured", serverAddr) return "", fmt.Errorf("cannot resolve hostname in %s: no bootstrap server configured", serverAddr)
} }
// Use the same parsing logic as the client package // Use the same parsing logic as the client package
parsedURL, err := url.Parse(serverAddr) parsedURL, err := url.Parse(serverAddr)
if err != nil { if err != nil {
logger.Debug("Parsing %s as plain host:port format", serverAddr)
// Handle plain host:port format // Handle plain host:port format
host, port, err := net.SplitHostPort(serverAddr) host, port, err := net.SplitHostPort(serverAddr)
if err != nil { if err != nil {
@@ -188,6 +226,7 @@ func (s *Server) resolveServerAddress(serverAddr string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
logger.Debug("Resolved %s to %s", serverAddr, resolvedIP)
return resolvedIP, nil return resolvedIP, nil
} }
@@ -195,12 +234,15 @@ func (s *Server) resolveServerAddress(serverAddr string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return net.JoinHostPort(resolvedIP, port), nil resolved := net.JoinHostPort(resolvedIP, port)
logger.Debug("Resolved %s to %s", serverAddr, resolved)
return resolved, nil
} }
// Handle URL format // Handle URL format
hostname := parsedURL.Hostname() hostname := parsedURL.Hostname()
if hostname == "" { if hostname == "" {
logger.Error("No hostname in URL: %s", serverAddr)
return "", fmt.Errorf("no hostname in URL: %s", serverAddr) return "", fmt.Errorf("no hostname in URL: %s", serverAddr)
} }
@@ -217,21 +259,26 @@ func (s *Server) resolveServerAddress(serverAddr string) (string, error) {
parsedURL.Host = net.JoinHostPort(resolvedIP, port) parsedURL.Host = net.JoinHostPort(resolvedIP, port)
} }
return parsedURL.String(), nil resolved := parsedURL.String()
logger.Debug("Resolved URL %s to %s", serverAddr, resolved)
return resolved, nil
} }
func (s *Server) resolveHostname(hostname string) (string, error) { func (s *Server) resolveHostname(hostname string) (string, error) {
logger.Debug("Resolving hostname: %s", hostname)
// Check cache first // Check cache first
s.hostsMutex.RLock() s.hostsMutex.RLock()
if ip, exists := s.resolvedHosts[hostname]; exists { if ip, exists := s.resolvedHosts[hostname]; exists {
s.hostsMutex.RUnlock() s.hostsMutex.RUnlock()
logger.Debug("Found cached resolution for %s: %s", hostname, ip)
return ip, nil return ip, nil
} }
s.hostsMutex.RUnlock() s.hostsMutex.RUnlock()
// Resolve using bootstrap // Resolve using bootstrap
if s.config.Verbose { if s.config.Verbose {
log.Printf("Resolving hostname %s using bootstrap server", hostname) logger.Info("Resolving hostname %s using bootstrap server", hostname)
} }
msg := new(dns.Msg) msg := new(dns.Msg)
@@ -239,12 +286,16 @@ func (s *Server) resolveHostname(hostname string) (string, error) {
msg.Id = dns.Id() msg.Id = dns.Id()
msg.RecursionDesired = true msg.RecursionDesired = true
logger.Debug("Sending bootstrap query for %s (ID: %d)", hostname, msg.Id)
msg, err := s.bootstrapClient.Query(msg) msg, err := s.bootstrapClient.Query(msg)
if err != nil { if err != nil {
logger.Error("Bootstrap query failed for %s: %v", hostname, err)
return "", fmt.Errorf("failed to resolve %s via bootstrap: %w", hostname, err) return "", fmt.Errorf("failed to resolve %s via bootstrap: %w", hostname, err)
} }
logger.Debug("Bootstrap response for %s: %d answers", hostname, len(msg.Answer))
if len(msg.Answer) == 0 { if len(msg.Answer) == 0 {
logger.Error("No A records found for %s", hostname)
return "", fmt.Errorf("no A records found for %s", hostname) return "", fmt.Errorf("no A records found for %s", hostname)
} }
@@ -259,18 +310,21 @@ func (s *Server) resolveHostname(hostname string) (string, error) {
s.hostsMutex.Unlock() s.hostsMutex.Unlock()
if s.config.Verbose { if s.config.Verbose {
log.Printf("Resolved %s to %s", hostname, ip) logger.Info("Resolved %s to %s", hostname, ip)
} }
logger.Debug("Cached resolution: %s -> %s", hostname, ip)
return ip, nil return ip, nil
} }
} }
logger.Error("No valid A record found for %s", hostname)
return "", fmt.Errorf("no valid A record found for %s", hostname) return "", fmt.Errorf("no valid A record found for %s", hostname)
} }
func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
logger.Debug("Received request with no questions from %s", w.RemoteAddr())
dns.HandleFailed(w, r) dns.HandleFailed(w, r)
return return
} }
@@ -279,8 +333,11 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
domain := strings.ToLower(question.Name) domain := strings.ToLower(question.Name)
qtype := question.Qtype qtype := question.Qtype
logger.Debug("Handling DNS request: %s %s from %s (ID: %d)",
question.Name, dns.TypeToString[qtype], w.RemoteAddr(), r.Id)
if s.config.Verbose { if s.config.Verbose {
log.Printf("Query: %s %s from %s", logger.Info("Query: %s %s from %s",
question.Name, question.Name,
dns.TypeToString[qtype], dns.TypeToString[qtype],
w.RemoteAddr()) w.RemoteAddr())
@@ -290,40 +347,48 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
if cachedRecords := s.getCachedRecords(domain, qtype); cachedRecords != nil { if cachedRecords := s.getCachedRecords(domain, qtype); cachedRecords != nil {
response := s.buildResponse(r, cachedRecords) response := s.buildResponse(r, cachedRecords)
if s.config.Verbose { if s.config.Verbose {
log.Printf("Cache hit: %s %s -> %d records", logger.Info("Cache hit: %s %s -> %d records",
question.Name, question.Name,
dns.TypeToString[qtype], dns.TypeToString[qtype],
len(cachedRecords)) len(cachedRecords))
} }
logger.Debug("Serving cached response for %s %s (%d records)",
question.Name, dns.TypeToString[qtype], len(cachedRecords))
w.WriteMsg(response) w.WriteMsg(response)
return return
} }
logger.Debug("Cache miss for %s %s, querying upstream", question.Name, dns.TypeToString[qtype])
// Try upstream first // Try upstream first
response, err := s.queryUpstream(s.upstreamClient, question.Name, qtype) response, err := s.queryUpstream(s.upstreamClient, question.Name, qtype)
if err != nil { if err != nil {
if s.config.Verbose { if s.config.Verbose {
log.Printf("Upstream query failed: %v", err) logger.Info("Upstream query failed: %v", err)
} }
logger.Debug("Upstream query failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err)
// Try fallback if available // Try fallback if available
if s.fallbackClient != nil { if s.fallbackClient != nil {
if s.config.Verbose { if s.config.Verbose {
log.Printf("Trying fallback server") logger.Info("Trying fallback server")
} }
logger.Debug("Attempting fallback query for %s %s", question.Name, dns.TypeToString[qtype])
response, err = s.queryUpstream(s.fallbackClient, question.Name, qtype) response, err = s.queryUpstream(s.fallbackClient, question.Name, qtype)
if err != nil { if err != nil {
log.Printf("Both upstream and fallback failed for %s %s: %v", logger.Error("Both upstream and fallback failed for %s %s: %v",
question.Name, question.Name,
dns.TypeToString[qtype], dns.TypeToString[qtype],
err) err)
} else {
logger.Debug("Fallback query succeeded for %s %s", question.Name, dns.TypeToString[qtype])
} }
} }
// If still failed, return SERVFAIL // If still failed, return SERVFAIL
if err != nil { if err != nil {
log.Printf("All servers failed for %s %s: %v", logger.Error("All servers failed for %s %s: %v",
question.Name, question.Name,
dns.TypeToString[qtype], dns.TypeToString[qtype],
err) err)
@@ -334,6 +399,8 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
return return
} }
} else {
logger.Debug("Upstream query succeeded for %s %s", question.Name, dns.TypeToString[qtype])
} }
// Cache successful response // Cache successful response
@@ -343,12 +410,14 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
response.Id = r.Id response.Id = r.Id
if s.config.Verbose { if s.config.Verbose {
log.Printf("Response: %s %s -> %d answers", logger.Info("Response: %s %s -> %d answers",
question.Name, question.Name,
dns.TypeToString[qtype], dns.TypeToString[qtype],
len(response.Answer)) len(response.Answer))
} }
logger.Debug("Sending response for %s %s: %d answers, rcode: %s",
question.Name, dns.TypeToString[qtype], len(response.Answer), dns.RcodeToString[response.Rcode])
w.WriteMsg(response) w.WriteMsg(response)
} }
@@ -360,17 +429,22 @@ func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR {
s.cacheMutex.RUnlock() s.cacheMutex.RUnlock()
if !exists { if !exists {
logger.Debug("No cache entry for %s %s", domain, dns.TypeToString[qtype])
return nil return nil
} }
// Check if expired and clean up on the spot // Check if expired and clean up on the spot
if time.Now().After(entry.expiresAt) { if time.Now().After(entry.expiresAt) {
logger.Debug("Cache entry expired for %s %s", domain, dns.TypeToString[qtype])
s.cacheMutex.Lock() s.cacheMutex.Lock()
delete(s.queryCache, key) delete(s.queryCache, key)
s.cacheMutex.Unlock() s.cacheMutex.Unlock()
return nil return nil
} }
logger.Debug("Cache hit for %s %s (%d records, expires in %v)",
domain, dns.TypeToString[qtype], len(entry.records), time.Until(entry.expiresAt))
// Return a copy of the cached records // Return a copy of the cached records
records := make([]dns.RR, len(entry.records)) records := make([]dns.RR, len(entry.records))
for i, rr := range entry.records { for i, rr := range entry.records {
@@ -382,14 +456,14 @@ func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR {
func (s *Server) buildResponse(request *dns.Msg, records []dns.RR) *dns.Msg { func (s *Server) buildResponse(request *dns.Msg, records []dns.RR) *dns.Msg {
response := new(dns.Msg) response := new(dns.Msg)
response.SetReply(request) response.SetReply(request)
response.Answer = records response.Answer = records
logger.Debug("Built response with %d records", len(records))
return response return response
} }
func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) { func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) {
if msg == nil || len(msg.Answer) == 0 { if msg == nil || len(msg.Answer) == 0 {
logger.Debug("Not caching empty response for %s %s", domain, dns.TypeToString[qtype])
return return
} }
@@ -408,11 +482,13 @@ func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) {
} }
if len(validRecords) == 0 { if len(validRecords) == 0 {
logger.Debug("No valid records to cache for %s %s", domain, dns.TypeToString[qtype])
return return
} }
// Don't cache responses with very low TTL // Don't cache responses with very low TTL
if minTTL < 10 { if minTTL < 10 {
logger.Debug("TTL too low (%ds) for caching %s %s", minTTL, domain, dns.TypeToString[qtype])
return return
} }
@@ -427,12 +503,16 @@ func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) {
s.cacheMutex.Unlock() s.cacheMutex.Unlock()
if s.config.Verbose { if s.config.Verbose {
log.Printf("Cached %d records for %s %s (TTL: %ds)", logger.Info("Cached %d records for %s %s (TTL: %ds)",
len(validRecords), domain, dns.TypeToString[qtype], minTTL) len(validRecords), domain, dns.TypeToString[qtype], minTTL)
} }
logger.Debug("Cached %d records for %s %s (TTL: %ds, expires: %v)",
len(validRecords), domain, dns.TypeToString[qtype], minTTL, entry.expiresAt)
} }
func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, qtype uint16) (*dns.Msg, error) { func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, qtype uint16) (*dns.Msg, error) {
logger.Debug("Querying upstream for %s %s", domain, dns.TypeToString[qtype])
// Create context with timeout // Create context with timeout
ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout) ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout)
defer cancel() defer cancel()
@@ -450,7 +530,15 @@ func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, q
msg.SetQuestion(dns.Fqdn(domain), qtype) msg.SetQuestion(dns.Fqdn(domain), qtype)
msg.Id = dns.Id() msg.Id = dns.Id()
msg.RecursionDesired = true msg.RecursionDesired = true
logger.Debug("Sending upstream query: %s %s (ID: %d)", domain, dns.TypeToString[qtype], msg.Id)
recvMsg, err := upstreamClient.Query(msg) recvMsg, err := upstreamClient.Query(msg)
if err != nil {
logger.Debug("Upstream query error for %s %s: %v", domain, dns.TypeToString[qtype], err)
} else {
logger.Debug("Upstream query response for %s %s: %d answers, rcode: %s",
domain, dns.TypeToString[qtype], len(recvMsg.Answer), dns.RcodeToString[recvMsg.Rcode])
}
resultChan <- result{msg: recvMsg, err: err} resultChan <- result{msg: recvMsg, err: err}
}() }()
@@ -458,6 +546,7 @@ func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, q
case res := <-resultChan: case res := <-resultChan:
return res.msg, res.err return res.msg, res.err
case <-ctx.Done(): case <-ctx.Done():
logger.Debug("Upstream query timeout for %s %s after %v", domain, dns.TypeToString[qtype], s.config.Timeout)
return nil, fmt.Errorf("upstream query timeout") return nil, fmt.Errorf("upstream query timeout")
} }
} }
@@ -466,29 +555,38 @@ func (s *Server) Start() error {
go func() { go func() {
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan sig := <-sigChan
log.Println("Shutting down DNS server...") logger.Info("Received signal %v, shutting down DNS server...", sig)
s.Shutdown() s.Shutdown()
}() }()
log.Printf("DNS proxy server listening on %s", s.config.Address) logger.Info("DNS proxy server listening on %s", s.config.Address)
logger.Debug("Server starting with timeout: %v, DNSSEC: %v", s.config.Timeout, s.config.DNSSEC)
return s.dnsServer.ListenAndServe() return s.dnsServer.ListenAndServe()
} }
func (s *Server) Shutdown() { func (s *Server) Shutdown() {
logger.Debug("Shutting down server components")
if s.dnsServer != nil { if s.dnsServer != nil {
logger.Debug("Shutting down DNS server")
s.dnsServer.Shutdown() s.dnsServer.Shutdown()
} }
if s.upstreamClient != nil { if s.upstreamClient != nil {
logger.Debug("Closing upstream client")
s.upstreamClient.Close() s.upstreamClient.Close()
} }
if s.fallbackClient != nil { if s.fallbackClient != nil {
logger.Debug("Closing fallback client")
s.fallbackClient.Close() s.fallbackClient.Close()
} }
if s.bootstrapClient != nil { if s.bootstrapClient != nil {
logger.Debug("Closing bootstrap client")
s.bootstrapClient.Close() s.bootstrapClient.Close()
} }
logger.Info("Server shutdown complete")
} }