diff --git a/client/client.go b/client/client.go index f1d4e2f..7a8d533 100644 --- a/client/client.go +++ b/client/client.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/afonsofrancof/sdns-proxy/common/dnssec" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/afonsofrancof/sdns-proxy/common/protocols/do53" "github.com/afonsofrancof/sdns-proxy/common/protocols/doh" "github.com/afonsofrancof/sdns-proxy/common/protocols/doq" @@ -26,16 +27,19 @@ type ValidatingDNSClient struct { } type Options struct { - DNSSEC bool - ValidateOnly bool + DNSSEC bool + ValidateOnly bool StrictValidation bool } // New creates a DNS client based on the upstream string func New(upstream string, opts Options) (DNSClient, error) { + logger.Debug("Creating DNS client for upstream: %s with options: %+v", upstream, opts) + // Try to parse as URL first parsedURL, err := url.Parse(upstream) if err != nil { + logger.Error("Invalid upstream format: %v", err) return nil, fmt.Errorf("invalid upstream format: %w", err) } @@ -43,30 +47,26 @@ func New(upstream string, opts Options) (DNSClient, error) { // If it has a scheme, treat it as a full URL if parsedURL.Scheme != "" { + logger.Debug("Parsing %s as URL with scheme %s", upstream, parsedURL.Scheme) baseClient, err = createClientFromURL(parsedURL, opts) } else { // No scheme - treat as plain DNS address + logger.Debug("Parsing %s as plain DNS address", upstream) baseClient, err = createClientFromPlainAddress(upstream, opts) } if err != nil { + logger.Error("Failed to create base client: %v", err) return nil, err } // If DNSSEC is not enabled, return the base client if !opts.DNSSEC { + logger.Debug("DNSSEC disabled, returning base client") return baseClient, nil } - // Wrap with DNSSEC validation - // validator := dnssec.NewValidator(func(qname string, qtype uint16) (*dns.Msg, error) { - // msg := new(dns.Msg) - // msg.SetQuestion(dns.Fqdn(qname), qtype) - // msg.Id = dns.Id() - // msg.RecursionDesired = true - // msg.SetEdns0(4096, true) // Enable DNSSEC - // return baseClient.Query(msg) - // }) + logger.Debug("DNSSEC enabled, wrapping with validator") validator := dnssec.NewValidatorWithAuthoritativeQueries() return &ValidatingDNSClient{ @@ -77,9 +77,16 @@ func New(upstream string, opts Options) (DNSClient, error) { } func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) > 0 { + question := msg.Question[0] + logger.Debug("ValidatingDNSClient query: %s %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v)", + question.Name, dns.TypeToString[question.Qtype], v.options.DNSSEC, v.options.ValidateOnly, v.options.StrictValidation) + } + // Always query the upstream first response, err := v.client.Query(msg) if err != nil { + logger.Debug("Base client query failed: %v", err) return nil, err } @@ -90,6 +97,7 @@ func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) { // Extract question details for validation if len(msg.Question) == 0 { + logger.Debug("No questions in message, skipping DNSSEC validation") return response, nil } @@ -97,6 +105,8 @@ func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) { qname := question.Name qtype := question.Qtype + logger.Debug("Starting DNSSEC validation for %s %s", qname, dns.TypeToString[qtype]) + // Validate the response validationErr := v.validator.ValidateResponse(response, qname, qtype) @@ -104,28 +114,35 @@ func (v *ValidatingDNSClient) Query(msg *dns.Msg) (*dns.Msg, error) { if validationErr != nil { // Check if it's a "not signed" error if validationErr == dnssec.ErrResourceNotSigned { + logger.Debug("Domain %s is not DNSSEC signed", qname) if v.options.ValidateOnly { + logger.Error("Domain %s is not DNSSEC signed (ValidateOnly mode)", qname) return nil, fmt.Errorf("domain %s is not DNSSEC signed", qname) } // Return unsigned response if not in validate-only mode + logger.Debug("Returning unsigned response for %s", qname) return response, nil } // For other validation errors + logger.Debug("DNSSEC validation failed for %s: %v", qname, validationErr) if v.options.StrictValidation { + logger.Error("DNSSEC validation failed for %s (strict mode): %v", qname, validationErr) return nil, fmt.Errorf("DNSSEC validation failed for %s: %w", qname, validationErr) } // In non-strict mode, log the error but return the response - // (You might want to add logging here) + logger.Debug("DNSSEC validation failed for %s (non-strict mode), returning response anyway: %v", qname, validationErr) return response, nil } // Validation successful + logger.Debug("DNSSEC validation successful for %s %s", qname, dns.TypeToString[qtype]) return response, nil } func (v *ValidatingDNSClient) Close() { + logger.Debug("Closing ValidatingDNSClient") if v.client != nil { v.client.Close() } @@ -135,6 +152,7 @@ func createClientFromURL(parsedURL *url.URL, opts Options) (DNSClient, error) { scheme := strings.ToLower(parsedURL.Scheme) host := parsedURL.Hostname() if host == "" { + logger.Error("Missing host in upstream URL: %s", parsedURL.String()) return nil, fmt.Errorf("missing host in upstream URL") } @@ -148,6 +166,7 @@ func createClientFromURL(parsedURL *url.URL, opts Options) (DNSClient, error) { path = getDefaultPath(scheme) } + logger.Debug("Creating client from URL: scheme=%s, host=%s, port=%s, path=%s", scheme, host, port, path) return createClient(scheme, host, port, path, opts) } @@ -162,41 +181,49 @@ func createClientFromPlainAddress(address string, opts Options) (DNSClient, erro } if host == "" { + logger.Error("Empty host in address: %s", address) return nil, fmt.Errorf("empty host in address: %s", address) } + logger.Debug("Creating client from plain address: host=%s, port=%s", host, port) return createClient("", host, port, "", opts) } func getDefaultPort(scheme string) string { + port := "53" switch scheme { case "https", "doh", "doh3": - return "443" + port = "443" case "tls", "dot": - return "853" + port = "853" case "quic", "doq": - return "853" - default: - return "53" + port = "853" } + logger.Debug("Default port for scheme %s: %s", scheme, port) + return port } func getDefaultPath(scheme string) string { + path := "" switch scheme { case "https", "doh", "doh3": - return "/dns-query" - default: - return "" + path = "/dns-query" } + logger.Debug("Default path for scheme %s: %s", scheme, path) + return path } func createClient(scheme, host, port, path string, opts Options) (DNSClient, error) { + logger.Debug("Creating client: scheme=%s, host=%s, port=%s, path=%s, DNSSEC=%v", + scheme, host, port, path, opts.DNSSEC) + switch scheme { case "udp", "tcp", "do53", "": config := do53.Config{ HostAndPort: net.JoinHostPort(host, port), DNSSEC: opts.DNSSEC, } + logger.Debug("Creating DO53 client with config: %+v", config) return do53.New(config) case "http", "doh": @@ -207,6 +234,7 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err DNSSEC: opts.DNSSEC, HTTP3: false, } + logger.Debug("Creating DoH client with config: %+v", config) return doh.New(config) case "https", "doh3": @@ -217,6 +245,7 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err DNSSEC: opts.DNSSEC, HTTP3: true, } + logger.Debug("Creating DoH3 client with config: %+v", config) return doh.New(config) case "tls", "dot": @@ -225,6 +254,7 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err Port: port, DNSSEC: opts.DNSSEC, } + logger.Debug("Creating DoT client with config: %+v", config) return dot.New(config) case "doq": // DNS over QUIC @@ -233,9 +263,11 @@ func createClient(scheme, host, port, path string, opts Options) (DNSClient, err Port: port, DNSSEC: opts.DNSSEC, } + logger.Debug("Creating DoQ client with config: %+v", config) return doq.New(config) default: + logger.Error("Unsupported scheme: %s", scheme) return nil, fmt.Errorf("unsupported scheme: %s", scheme) } } diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go deleted file mode 100644 index c770c7d..0000000 --- a/cmd/resolver/main.go +++ /dev/null @@ -1,138 +0,0 @@ -package main - -import ( - "fmt" - "log" - "os" - "strings" - "time" - - "github.com/afonsofrancof/sdns-perf/internal/client" - - "github.com/alecthomas/kong" - "github.com/miekg/dns" -) - -var cli struct { - // Global flags - Verbose bool `help:"Enable verbose logging." short:"v"` - - Query QueryCmd `cmd:"" help:"Perform a DNS query (client mode)."` - Listen ListenCmd `cmd:"" help:"Run as a DNS listener/resolver (server mode)."` -} - -type QueryCmd struct { - DomainName string `help:"Domain name to resolve." arg:"" required:""` - Server string `help:"Upstream server address (e.g., https://1.1.1.1/dns-query, tls://1.1.1.1, 8.8.8.8)." short:"s" required:""` - QueryType string `help:"Query type (A, AAAA, MX, TXT, etc.)." short:"t" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"` - DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"` - Timeout time.Duration `help:"Timeout for the query operation." default:"10s"` // Default might be higher now - KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"` -} - -func (q *QueryCmd) Run() error { - log.Printf("Querying %s for %s type %s (DNSSEC: %v, Timeout: %v)\n", - q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.Timeout) - - opts := client.Options{ - Timeout: q.Timeout, - DNSSEC: q.DNSSEC, - KeyLogPath: q.KeyLogFile, - } - - dnsClient, err := client.New(q.Server, opts) - if err != nil { - return err - } - defer dnsClient.Close() - - qTypeUint, ok := dns.StringToType[strings.ToUpper(q.QueryType)] - if !ok { - return fmt.Errorf("invalid query type: %s", q.QueryType) - } - - dnsMsg, err := dnsClient.Query(q.DomainName, qTypeUint) - if err != nil { - return fmt.Errorf("query failed: %w ", err) - } - - printResponse(q.DomainName, q.QueryType, dnsMsg) - - return nil -} - -type ListenCmd struct { - Address string `help:"Address to listen on (e.g., :53, :8053)." default:":53"` - // Add other server-specific flags: default upstream, TLS cert/key paths etc. -} - -func (l *ListenCmd) Run() error { - return fmt.Errorf("server/listen mode not yet implemented") -} - -func printResponse(domain, qtype string, msg *dns.Msg) { - fmt.Println(";; QUESTION SECTION:") - - fmt.Printf(";%s.\tIN\t%s\n", dns.Fqdn(domain), strings.ToUpper(qtype)) - - fmt.Println("\n;; ANSWER SECTION:") - if len(msg.Answer) > 0 { - for _, rr := range msg.Answer { - fmt.Println(rr.String()) - } - } else { - fmt.Println(";; No records found in answer section.") - } - - if len(msg.Ns) > 0 { - fmt.Println("\n;; AUTHORITY SECTION:") - for _, rr := range msg.Ns { - fmt.Println(rr.String()) - } - } - if len(msg.Extra) > 0 { - hasRealExtra := false - for _, rr := range msg.Extra { - if rr.Header().Rrtype != dns.TypeOPT { - hasRealExtra = true - break - } - } - if hasRealExtra { - fmt.Println("\n;; ADDITIONAL SECTION:") - for _, rr := range msg.Extra { - if rr.Header().Rrtype != dns.TypeOPT { - fmt.Println(rr.String()) - } - } - } - } - - fmt.Printf("\n;; RCODE: %s, ID: %d", dns.RcodeToString[msg.Rcode], msg.Id) - opt := msg.IsEdns0() - if opt != nil { - fmt.Printf(", EDNS: version: %d; flags:", opt.Version()) - if opt.Do() { - fmt.Printf(" do;") - } else { - fmt.Printf(";") - } - fmt.Printf(" udp: %d", opt.UDPSize()) - } - fmt.Println() -} - -func main() { - log.SetOutput(os.Stderr) - log.SetFlags(log.Ltime | log.Lshortfile) - - kongCtx := kong.Parse(&cli, - kong.Name("sdns-perf"), - kong.Description("A DNS client/server tool supporting multiple protocols."), - kong.UsageOnError(), - kong.ConfigureHelp(kong.HelpOptions{Compact: true, Summary: true}), - ) - - err := kongCtx.Run() - kongCtx.FatalIfErrorf(err) -} diff --git a/common/dnssec/authchain.go b/common/dnssec/authchain.go index a88b4c4..19b59ab 100644 --- a/common/dnssec/authchain.go +++ b/common/dnssec/authchain.go @@ -20,9 +20,9 @@ package dnssec import ( "fmt" - "log" "strings" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -57,13 +57,13 @@ func (ac *AuthenticationChain) Populate(domainName string, queryFunc func(string zones = append(zones, zone) } - log.Printf("Building DNSSEC chain for zones: %v", zones) + logger.Debug("Building DNSSEC chain for zones: %v", zones) ac.DelegationChain = make([]SignedZone, 0, len(zones)) // Query each zone from root down for i, zoneName := range zones { - log.Printf("Querying zone: %s", zoneName) + logger.Debug("Querying zone: %s", zoneName) delegation, err := ac.queryDelegation(zoneName, queryFunc) if err != nil { @@ -91,13 +91,13 @@ func (ac *AuthenticationChain) queryDelegation(domainName string, queryFunc func } signedZone.DNSKey = dnskeyRRset - log.Printf("Found %d DNSKEY records for %s", len(dnskeyRRset.RRs), domainName) + logger.Debug("Found %d DNSKEY records for %s", len(dnskeyRRset.RRs), domainName) // Populate public key lookup for _, rr := range signedZone.DNSKey.RRs { if dnskey, ok := rr.(*dns.DNSKEY); ok { signedZone.AddPubKey(dnskey) - log.Printf("Added DNSKEY for %s: keytag=%d, flags=%d, algorithm=%d", domainName, dnskey.KeyTag(), dnskey.Flags, dnskey.Algorithm) + logger.Debug("Added DNSKEY for %s: keytag=%d, flags=%d, algorithm=%d", domainName, dnskey.KeyTag(), dnskey.Flags, dnskey.Algorithm) } } @@ -106,17 +106,17 @@ func (ac *AuthenticationChain) queryDelegation(domainName string, queryFunc func dsRRset, _ := ac.queryRRset(domainName, dns.TypeDS, queryFunc) signedZone.DS = dsRRset if dsRRset != nil && len(dsRRset.RRs) > 0 { - log.Printf("Found %d DS records for %s", len(dsRRset.RRs), domainName) + logger.Debug("Found %d DS records for %s", len(dsRRset.RRs), domainName) for _, rr := range dsRRset.RRs { if ds, ok := rr.(*dns.DS); ok { - log.Printf("DS record for %s: keytag=%d", domainName, ds.KeyTag) + logger.Debug("DS record for %s: keytag=%d", domainName, ds.KeyTag) } } } } else { // Root zone has no DS records - trusted by default signedZone.DS = NewRRSet() - log.Printf("Root zone - no DS records, trusted by default") + logger.Debug("Root zone - no DS records, trusted by default") } return signedZone, nil @@ -125,12 +125,12 @@ func (ac *AuthenticationChain) queryDelegation(domainName string, queryFunc func func (ac *AuthenticationChain) queryRRset(qname string, qtype uint16, queryFunc func(string, uint16) (*dns.Msg, error)) (*RRSet, error) { r, err := queryFunc(qname, qtype) if err != nil { - log.Printf("cannot lookup %v", err) + logger.Debug("cannot lookup %v", err) return NewRRSet(), nil // Return empty RRSet instead of nil } if r.Rcode == dns.RcodeNameError { - log.Printf("no such domain %s", qname) + logger.Debug("no such domain %s", qname) return NewRRSet(), nil // Return empty RRSet instead of nil } @@ -167,60 +167,60 @@ func (ac *AuthenticationChain) Verify(answerRRset *RRSet) error { // Verify the answer RRset against target zone's keys err := targetZone.VerifyRRSIG(answerRRset) if err != nil { - log.Printf("Answer RRSIG verification failed: %v", err) + logger.Debug("Answer RRSIG verification failed: %v", err) return ErrInvalidRRsig } // Validate the chain from root down for _, zone := range ac.DelegationChain { - log.Printf("Validating zone: %s", zone.Zone) + logger.Debug("Validating zone: %s", zone.Zone) // Verify DNSKEY RRset signature if !zone.HasDNSKeys() { - log.Printf("No DNSKEYs for zone %s", zone.Zone) + logger.Debug("No DNSKEYs for zone %s", zone.Zone) return ErrDnskeyNotAvailable } err := zone.VerifyRRSIG(zone.DNSKey) if err != nil { - log.Printf("DNSKEY validation failed for %s: %v", zone.Zone, err) + logger.Debug("DNSKEY validation failed for %s: %v", zone.Zone, err) return ErrRrsigValidationError } // Skip ALL validation for root - just trust it if zone.Zone == "." { - log.Printf("Root zone - trusted by default, no validation performed") + logger.Debug("Root zone - trusted by default, no validation performed") continue } // For non-root zones, validate DS records against parent zone if zone.ParentZone == nil { - log.Printf("Non-root zone %s has no parent", zone.Zone) + logger.Debug("Non-root zone %s has no parent", zone.Zone) return fmt.Errorf("non-root zone %s has no parent", zone.Zone) } if zone.DS == nil || zone.DS.IsEmpty() { - log.Printf("No DS records for zone %s", zone.Zone) + logger.Debug("No DS records for zone %s", zone.Zone) return ErrDsNotAvailable } // Verify DS signature using parent's key err = zone.ParentZone.VerifyRRSIG(zone.DS) if err != nil { - log.Printf("DS signature validation failed for %s: %v", zone.Zone, err) + logger.Debug("DS signature validation failed for %s: %v", zone.Zone, err) return ErrRrsigValidationError } // Verify DS matches this zone's DNSKEY err = zone.VerifyDS(zone.DS.RRs) if err != nil { - log.Printf("DS-DNSKEY validation failed for %s: %v", zone.Zone, err) + logger.Debug("DS-DNSKEY validation failed for %s: %v", zone.Zone, err) return ErrDsInvalid } - log.Printf("Zone %s validated successfully", zone.Zone) + logger.Debug("Zone %s validated successfully", zone.Zone) } - log.Printf("DNSSEC validation successful for entire chain!") + logger.Debug("DNSSEC validation successful for entire chain!") return nil } diff --git a/common/dnssec/authoritative.go b/common/dnssec/authoritative.go index a9692e0..18c132f 100644 --- a/common/dnssec/authoritative.go +++ b/common/dnssec/authoritative.go @@ -2,11 +2,11 @@ package dnssec import ( "fmt" - "log" "net" "strings" "time" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -28,13 +28,13 @@ func NewAuthoritativeQuerier() *AuthoritativeQuerier { } func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) (*dns.Msg, error) { - log.Printf("Querying authoritative servers for %s type %d", qname, qtype) + logger.Debug("Querying authoritative servers for %s type %d", qname, qtype) var zone string if qtype == dns.TypeDS { zone = aq.getParentZone(qname) if zone == "" { - log.Printf("No parent zone for %s - returning NXDOMAIN for DS query", qname) + logger.Debug("No parent zone for %s - returning NXDOMAIN for DS query", qname) msg := &dns.Msg{} msg.SetRcode(&dns.Msg{}, dns.RcodeNameError) return msg, nil @@ -43,7 +43,7 @@ func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) ( zone = aq.findZone(qname) } - log.Printf("Determined zone: %s for query %s type %d", zone, qname, qtype) + logger.Debug("Determined zone: %s for query %s type %d", zone, qname, qtype) // Get NS names (not IPs yet) nsNames, err := aq.findAuthoritativeNSNames(zone) @@ -59,15 +59,15 @@ func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) ( continue } - log.Printf("Trying server: %s (%s)", server, nsName) + logger.Debug("Trying server: %s (%s)", server, nsName) msg, err := aq.queryServer(server, qname, qtype) if err != nil { - log.Printf("Server %s failed: %v", server, err) + logger.Debug("Server %s failed: %v", server, err) lastErr = err continue } - log.Printf("Server %s responded, authoritative: %v, rcode: %d, answers: %d", server, msg.Authoritative, msg.Rcode, len(msg.Answer)) + logger.Debug("Server %s responded, authoritative: %v, rcode: %d, answers: %d", server, msg.Authoritative, msg.Rcode, len(msg.Answer)) if (msg.Rcode == dns.RcodeSuccess && len(msg.Answer) > 0) || msg.Rcode == dns.RcodeNameError { return msg, nil } @@ -82,11 +82,11 @@ func (aq *AuthoritativeQuerier) QueryAuthoritative(qname string, qtype uint16) ( func (aq *AuthoritativeQuerier) findAuthoritativeNSNames(zone string) ([]string, error) { if nsNames, exists := aq.nsCache[zone]; exists { - log.Printf("Using cached NS names for %s: %v", zone, nsNames) + logger.Debug("Using cached NS names for %s: %v", zone, nsNames) return nsNames, nil } - log.Printf("Looking for NS records for zone: %s", zone) + logger.Debug("Looking for NS records for zone: %s", zone) // Use a public resolver to find the NS records resolver := &dns.Client{Timeout: 5 * time.Second} @@ -125,35 +125,35 @@ func (aq *AuthoritativeQuerier) findAuthoritativeNSNames(zone string) ([]string, return nil, fmt.Errorf("no NS servers found for %s", zone) } - log.Printf("Found NS names for %s: %v", zone, nsNames) + logger.Debug("Found NS names for %s: %v", zone, nsNames) aq.nsCache[zone] = nsNames - log.Printf("Cached NS names for %s: %v", zone, nsNames) + logger.Debug("Cached NS names for %s: %v", zone, nsNames) return nsNames, nil } func (aq *AuthoritativeQuerier) getParentZone(qname string) string { - log.Printf("Getting parent zone for: %s", qname) + logger.Debug("Getting parent zone for: %s", qname) // Clean the qname qname = strings.TrimSuffix(qname, ".") // Root zone has no parent if qname == "" || qname == "." { - log.Printf("Root zone has no parent") + logger.Debug("Root zone has no parent") return "" } labels := dns.SplitDomainName(qname) - log.Printf("Labels for %s: %v", qname, labels) + logger.Debug("Labels for %s: %v", qname, labels) if len(labels) <= 1 { - log.Printf("Parent of TLD %s is root", qname) + logger.Debug("Parent of TLD %s is root", qname) return "." // Parent of TLD is root } parentLabels := labels[1:] parent := dns.Fqdn(strings.Join(parentLabels, ".")) - log.Printf("Parent zone of %s is %s", qname, parent) + logger.Debug("Parent zone of %s is %s", qname, parent) return parent } @@ -167,27 +167,26 @@ func (aq *AuthoritativeQuerier) findZone(qname string) string { return qname } - func (aq *AuthoritativeQuerier) resolveNSToIP(nsName string) string { if ip, exists := aq.ipCache[nsName]; exists { - log.Printf("Using cached IP for %s: %s", nsName, ip) + logger.Debug("Using cached IP for %s: %s", nsName, ip) return ip } nsName = strings.TrimSuffix(nsName, ".") - log.Printf("Resolving NS %s to IP", nsName) + logger.Debug("Resolving NS %s to IP", nsName) ips, err := net.LookupIP(nsName) if err != nil { - log.Printf("Failed to resolve %s: %v", nsName, err) + logger.Debug("Failed to resolve %s: %v", nsName, err) return "" } for _, ip := range ips { if ip.To4() != nil { // Prefer IPv4 result := ip.String() + ":53" - log.Printf("Resolved %s to %s", nsName, result) - + logger.Debug("Resolved %s to %s", nsName, result) + // Cache the result before returning aq.ipCache[nsName] = result return result @@ -202,12 +201,12 @@ func (aq *AuthoritativeQuerier) queryServer(server, qname string, qtype uint16) m.SetQuestion(dns.Fqdn(qname), qtype) m.SetEdns0(4096, true) // Enable DNSSEC - log.Printf("Querying %s for %s type %d", server, qname, qtype) + logger.Debug("Querying %s for %s type %d", server, qname, qtype) msg, _, err := aq.client.Exchange(m, server) if err != nil { return nil, err } - log.Printf("Response from %s: rcode=%d, answers=%d", server, msg.Rcode, len(msg.Answer)) + logger.Debug("Response from %s: rcode=%d, answers=%d", server, msg.Rcode, len(msg.Answer)) return msg, err } diff --git a/common/dnssec/rrset.go b/common/dnssec/rrset.go index 704424a..c49c2fc 100644 --- a/common/dnssec/rrset.go +++ b/common/dnssec/rrset.go @@ -19,9 +19,9 @@ package dnssec // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import ( - "log" "time" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -65,7 +65,7 @@ func (r *RRSet) ValidateSignature(key *dns.DNSKEY) error { err := r.RRSig.Verify(key, r.RRs) if err != nil { - log.Printf("RRSIG verification failed: %v", err) + logger.Debug("RRSIG verification failed: %v", err) return ErrRrsigValidationError } diff --git a/common/dnssec/signedzone.go b/common/dnssec/signedzone.go index 749527a..6ee1e58 100644 --- a/common/dnssec/signedzone.go +++ b/common/dnssec/signedzone.go @@ -19,9 +19,9 @@ package dnssec // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. import ( - "log" "strings" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -61,7 +61,7 @@ func (z *SignedZone) VerifyRRSIG(signedRRset *RRSet) error { key := z.LookupPubKey(signedRRset.RRSig.KeyTag) if key == nil { - log.Printf("DNSKEY keytag %d not found in zone %s", signedRRset.RRSig.KeyTag, z.Zone) + logger.Debug("DNSKEY keytag %d not found in zone %s", signedRRset.RRSig.KeyTag, z.Zone) return ErrDnskeyNotAvailable } @@ -69,36 +69,36 @@ func (z *SignedZone) VerifyRRSIG(signedRRset *RRSet) error { } func (z *SignedZone) VerifyDS(dsRRset []dns.RR) error { - log.Printf("Verifying DS for zone %s with %d DS records", z.Zone, len(dsRRset)) + logger.Debug("Verifying DS for zone %s with %d DS records", z.Zone, len(dsRRset)) for _, rr := range dsRRset { ds, ok := rr.(*dns.DS) if !ok { continue } - log.Printf("Checking DS keytag %d, digestType %d", ds.KeyTag, ds.DigestType) + logger.Debug("Checking DS keytag %d, digestType %d", ds.KeyTag, ds.DigestType) if ds.DigestType != dns.SHA256 { - log.Printf("Unknown digest type (%d) on DS RR", ds.DigestType) + logger.Debug("Unknown digest type (%d) on DS RR", ds.DigestType) continue } parentDsDigest := strings.ToUpper(ds.Digest) key := z.LookupPubKey(ds.KeyTag) if key == nil { - log.Printf("DNSKEY keytag %d not found in zone %s", ds.KeyTag, z.Zone) + logger.Debug("DNSKEY keytag %d not found in zone %s", ds.KeyTag, z.Zone) return ErrDnskeyNotAvailable } dsDigest := strings.ToUpper(key.ToDS(ds.DigestType).Digest) - log.Printf("Parent DS digest: %s, Computed digest: %s", parentDsDigest, dsDigest) + logger.Debug("Parent DS digest: %s, Computed digest: %s", parentDsDigest, dsDigest) if parentDsDigest == dsDigest { - log.Printf("DS validation successful for keytag %d", ds.KeyTag) + logger.Debug("DS validation successful for keytag %d", ds.KeyTag) return nil } - log.Printf("DS does not match DNSKEY for keytag %d", ds.KeyTag) + logger.Debug("DS does not match DNSKEY for keytag %d", ds.KeyTag) } - log.Printf("No matching DS found") + logger.Debug("No matching DS found") return ErrDsInvalid } diff --git a/common/dnssec/validator.go b/common/dnssec/validator.go index 6e85ddd..caeb00c 100644 --- a/common/dnssec/validator.go +++ b/common/dnssec/validator.go @@ -17,10 +17,10 @@ package dnssec // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN // ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF // OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +// ./common/dnssec/validator.go import ( - "log" - + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -35,7 +35,10 @@ func NewValidator(queryFunc func(string, uint16) (*dns.Msg, error)) *Validator { } func (v *Validator) ValidateResponse(msg *dns.Msg, qname string, qtype uint16) error { + logger.Debug("Starting DNSSEC validation for %s %s", qname, dns.TypeToString[qtype]) + if msg == nil || len(msg.Answer) == 0 { + logger.Debug("No result for %s %s", qname, dns.TypeToString[qtype]) return ErrNoResult } @@ -46,41 +49,48 @@ func (v *Validator) ValidateResponse(msg *dns.Msg, qname string, qtype uint16) e case *dns.RRSIG: if t.TypeCovered == qtype { rrset.RRSig = t + logger.Debug("Found RRSIG for %s %s (keytag: %d)", qname, dns.TypeToString[qtype], t.KeyTag) } default: if rr.Header().Rrtype == qtype { rrset.RRs = append(rrset.RRs, rr) + logger.Debug("Found RR for %s %s: %s", qname, dns.TypeToString[qtype], rr.String()) } } } if rrset.IsEmpty() { + logger.Debug("Empty RRSet for %s %s", qname, dns.TypeToString[qtype]) return ErrNoResult } if !rrset.IsSigned() { + logger.Debug("RRSet for %s %s is not signed", qname, dns.TypeToString[qtype]) return ErrResourceNotSigned } // Check header integrity if err := rrset.CheckHeaderIntegrity(qname); err != nil { + logger.Debug("Header integrity check failed for %s %s: %v", qname, dns.TypeToString[qtype], err) return err } // Build and verify authentication chain signerName := rrset.SignerName() + logger.Debug("Building authentication chain for signer: %s", signerName) authChain := NewAuthenticationChain() if err := authChain.Populate(signerName, v.queryFunc); err != nil { - log.Printf("Cannot populate authentication chain: %s", err) + logger.Debug("Cannot populate authentication chain for %s: %v", signerName, err) return err } if err := authChain.Verify(rrset); err != nil { - log.Printf("DNSSEC validation failed: %s", err) + logger.Debug("DNSSEC validation failed for %s %s: %v", qname, dns.TypeToString[qtype], err) return err } + logger.Debug("DNSSEC validation successful for %s %s", qname, dns.TypeToString[qtype]) return nil } diff --git a/common/logger/logger.go b/common/logger/logger.go new file mode 100644 index 0000000..a0cc113 --- /dev/null +++ b/common/logger/logger.go @@ -0,0 +1,41 @@ +package logger + +import ( + "log" + "os" +) + +var ( + debugEnabled bool + infoLogger *log.Logger + debugLogger *log.Logger + errorLogger *log.Logger +) + +func init() { + infoLogger = log.New(os.Stderr, "[INFO] ", log.Ltime|log.Lshortfile) + debugLogger = log.New(os.Stderr, "[DEBUG] ", log.Ltime|log.Lshortfile) + errorLogger = log.New(os.Stderr, "[ERROR] ", log.Ltime|log.Lshortfile) +} + +func SetDebug(enabled bool) { + debugEnabled = enabled +} + +func IsDebugEnabled() bool { + return debugEnabled +} + +func Info(format string, args ...any) { + infoLogger.Printf(format, args...) +} + +func Debug(format string, args ...any) { + if debugEnabled { + debugLogger.Printf(format, args...) + } +} + +func Error(format string, args ...any) { + errorLogger.Printf(format, args...) +} diff --git a/common/protocols/do53/do53.go b/common/protocols/do53/do53.go index 18c498a..da626d3 100644 --- a/common/protocols/do53/do53.go +++ b/common/protocols/do53/do53.go @@ -5,6 +5,7 @@ import ( "net" "time" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -21,7 +22,10 @@ type Client struct { } func New(config Config) (*Client, error) { + logger.Debug("Creating DO53 client: %s", config.HostAndPort) + if config.HostAndPort == "" { + logger.Error("DO53 client creation failed: empty HostAndPort") return nil, fmt.Errorf("do53: HostAndPort cannot be empty") } if config.WriteTimeout <= 0 { @@ -31,6 +35,8 @@ func New(config Config) (*Client, error) { config.ReadTimeout = 5 * time.Second } + logger.Debug("DO53 client created: %s (DNSSEC: %v)", config.HostAndPort, config.DNSSEC) + return &Client{ hostAndPort: config.HostAndPort, config: config, @@ -38,18 +44,32 @@ func New(config Config) (*Client, error) { } func (c *Client) Close() { + logger.Debug("Closing DO53 client") } func (c *Client) createConnection() (*net.UDPConn, error) { udpAddr, err := net.ResolveUDPAddr("udp", c.hostAndPort) if err != nil { + logger.Error("DO53 failed to resolve address %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("failed to resolve UDP address: %w", err) } - return net.DialUDP("udp", nil, udpAddr) + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + logger.Error("DO53 failed to connect to %s: %v", c.hostAndPort, err) + return nil, err + } + + logger.Debug("DO53 connection established to %s", c.hostAndPort) + return conn, nil } func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) > 0 { + question := msg.Question[0] + logger.Debug("DO53 query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort) + } + // Create connection for this query conn, err := c.createConnection() if err != nil { @@ -62,36 +82,45 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { } packedMsg, err := msg.Pack() - if err != nil { + logger.Error("DO53 failed to pack message: %v", err) return nil, fmt.Errorf("do53: failed to pack DNS message: %w", err) } // Send query if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { + logger.Error("DO53 failed to set write deadline: %v", err) return nil, fmt.Errorf("do53: failed to set write deadline: %w", err) } if _, err := conn.Write(packedMsg); err != nil { + logger.Error("DO53 failed to send query to %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("do53: failed to send DNS query: %w", err) } // Read response if err := conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil { + logger.Error("DO53 failed to set read deadline: %v", err) return nil, fmt.Errorf("do53: failed to set read deadline: %w", err) } buffer := make([]byte, dns.MaxMsgSize) n, err := conn.Read(buffer) if err != nil { + logger.Error("DO53 failed to read response from %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("do53: failed to read DNS response: %w", err) } // Parse response response := new(dns.Msg) if err := response.Unpack(buffer[:n]); err != nil { + logger.Error("DO53 failed to unpack response from %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("do53: failed to unpack DNS response: %w", err) } + if len(response.Answer) > 0 { + logger.Debug("DO53 response from %s: %d answers", c.hostAndPort, len(response.Answer)) + } + return response, nil } diff --git a/common/protocols/doh/doh.go b/common/protocols/doh/doh.go index 93a0e91..b42e873 100644 --- a/common/protocols/doh/doh.go +++ b/common/protocols/doh/doh.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/http3" @@ -36,8 +37,10 @@ type Client struct { } func New(config Config) (*Client, error) { + logger.Debug("Creating DoH client: %s:%s%s", config.Host, config.Port, config.Path) + if config.Host == "" || config.Port == "" || config.Path == "" { - fmt.Printf("%v,%v,%v", config.Host, config.Port, config.Path) + logger.Error("DoH client creation failed: missing required fields") return nil, errors.New("doh: host, port, and path must not be empty") } @@ -48,6 +51,7 @@ func New(config Config) (*Client, error) { parsedURL, err := url.Parse(rawURL) if err != nil { + logger.Error("Failed to parse DoH URL %s: %v", rawURL, err) return nil, fmt.Errorf("doh: failed to parse constructed URL %q: %w", rawURL, err) } @@ -67,21 +71,26 @@ func New(config Config) (*Client, error) { Transport: transport, } + var transportType string if config.HTTP2 { httpClient.Transport = &http2.Transport{ TLSClientConfig: tlsConfig, AllowHTTP: true, } - } - - if config.HTTP3 { + transportType = "HTTP/2" + } else if config.HTTP3 { quicTlsConfig := http3.ConfigureTLSConfig(tlsConfig) httpClient.Transport = &http3.Transport{ TLSClientConfig: quicTlsConfig, QUICConfig: quicConfig, } + transportType = "HTTP/3" + } else { + transportType = "HTTP/1.1" } + logger.Debug("DoH client created: %s (%s, DNSSEC: %v)", rawURL, transportType, config.DNSSEC) + return &Client{ httpClient: httpClient, upstreamURL: parsedURL, @@ -90,6 +99,7 @@ func New(config Config) (*Client, error) { } func (c *Client) Close() { + logger.Debug("Closing DoH client") if t, ok := c.httpClient.Transport.(*http.Transport); ok { t.CloseIdleConnections() } else if t3, ok := c.httpClient.Transport.(*http3.Transport); ok { @@ -98,16 +108,24 @@ func (c *Client) Close() { } func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) > 0 { + question := msg.Question[0] + logger.Debug("DoH query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.upstreamURL.Host) + } + if c.config.DNSSEC { msg.SetEdns0(4096, true) } + packedMsg, err := msg.Pack() if err != nil { + logger.Error("DoH failed to pack DNS message: %v", err) return nil, fmt.Errorf("doh: failed to pack DNS message: %w", err) } httpReq, err := http.NewRequest(http.MethodPost, c.upstreamURL.String(), bytes.NewReader(packedMsg)) if err != nil { + logger.Error("DoH failed to create HTTP request: %v", err) return nil, fmt.Errorf("doh: failed to create HTTP request object: %w", err) } @@ -117,29 +135,37 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { httpResp, err := c.httpClient.Do(httpReq) if err != nil { + logger.Error("DoH request failed to %s: %v", c.upstreamURL.Host, err) return nil, fmt.Errorf("doh: failed executing HTTP request to %s: %w", c.upstreamURL.Host, err) } defer httpResp.Body.Close() if httpResp.StatusCode != http.StatusOK { + logger.Error("DoH received non-200 status from %s: %s", c.upstreamURL.Host, httpResp.Status) return nil, fmt.Errorf("doh: received non-200 HTTP status from %s: %s", c.upstreamURL.Host, httpResp.Status) } if ct := httpResp.Header.Get("Content-Type"); ct != dnsMessageContentType { + logger.Error("DoH unexpected Content-Type from %s: %s", c.upstreamURL.Host, ct) return nil, fmt.Errorf("doh: unexpected Content-Type from %s: got %q, want %q", c.upstreamURL.Host, ct, dnsMessageContentType) } responseBody, err := io.ReadAll(httpResp.Body) if err != nil { + logger.Error("DoH failed reading response from %s: %v", c.upstreamURL.Host, err) return nil, fmt.Errorf("doh: failed reading response body from %s: %w", c.upstreamURL.Host, err) } - // Unpack the DNS message recvMsg := new(dns.Msg) err = recvMsg.Unpack(responseBody) if err != nil { + logger.Error("DoH failed to unpack response from %s: %v", c.upstreamURL.Host, err) return nil, fmt.Errorf("doh: failed to unpack DNS response from %s: %w", c.upstreamURL.Host, err) } + if len(recvMsg.Answer) > 0 { + logger.Debug("DoH response from %s: %d answers", c.upstreamURL.Host, len(recvMsg.Answer)) + } + return recvMsg, nil } diff --git a/common/protocols/doq/doq.go b/common/protocols/doq/doq.go index 12a6511..4f272ca 100644 --- a/common/protocols/doq/doq.go +++ b/common/protocols/doq/doq.go @@ -10,6 +10,7 @@ import ( "net" "time" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" "github.com/quic-go/quic-go" ) @@ -32,6 +33,7 @@ type Client struct { } func New(config Config) (*Client, error) { + logger.Debug("Creating DoQ client: %s:%s", config.Host, config.Port) tlsConfig := &tls.Config{ ServerName: config.Host, @@ -42,10 +44,13 @@ func New(config Config) (*Client, error) { targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(config.Host, config.Port)) if err != nil { + logger.Error("DoQ failed to resolve address %s:%s: %v", config.Host, config.Port, err) return nil, err } + udpConn, err := net.ListenUDP("udp", nil) if err != nil { + logger.Error("DoQ failed to create UDP connection: %v", err) return nil, fmt.Errorf("failed to connect to target address: %w", err) } @@ -57,6 +62,8 @@ func New(config Config) (*Client, error) { MaxIdleTimeout: 30 * time.Second, } + logger.Debug("DoQ client created: %s:%s (DNSSEC: %v)", config.Host, config.Port, config.DNSSEC) + return &Client{ targetAddr: targetAddr, tlsConfig: tlsConfig, @@ -69,22 +76,30 @@ func New(config Config) (*Client, error) { } func (c *Client) Close() { + logger.Debug("Closing DoQ client") if c.udpConn != nil { c.udpConn.Close() } } func (c *Client) OpenConnection() error { + logger.Debug("Opening DoQ connection to %s", c.targetAddr) quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig) if err != nil { + logger.Error("DoQ connection failed to %s: %v", c.targetAddr, err) return err } c.quicConn = quicConn + logger.Debug("DoQ connection established to %s", c.targetAddr) return nil } func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) > 0 { + question := msg.Question[0] + logger.Debug("DoQ query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.targetAddr) + } if c.quicConn == nil { err := c.OpenConnection() @@ -100,18 +115,21 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { } packed, err := msg.Pack() if err != nil { + logger.Error("DoQ failed to pack message: %v", err) return nil, fmt.Errorf("doq: failed to pack message: %w", err) } var quicStream quic.Stream quicStream, err = c.quicConn.OpenStream() if err != nil { + logger.Debug("DoQ stream failed, reconnecting: %v", err) err = c.OpenConnection() if err != nil { return nil, err } quicStream, err = c.quicConn.OpenStream() if err != nil { + logger.Error("DoQ failed to open stream after reconnect: %v", err) return nil, err } } @@ -119,42 +137,52 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { var lengthPrefixedMessage bytes.Buffer err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(packed))) if err != nil { + logger.Error("DoQ failed to write message length: %v", err) return nil, fmt.Errorf("failed to write message length: %w", err) } _, err = lengthPrefixedMessage.Write(packed) if err != nil { + logger.Error("DoQ failed to write DNS message: %v", err) return nil, fmt.Errorf("failed to write DNS message: %w", err) } _, err = quicStream.Write(lengthPrefixedMessage.Bytes()) if err != nil { + logger.Error("DoQ failed to write to stream: %v", err) return nil, fmt.Errorf("failed writing to QUIC stream: %w", err) } - // Indicate that no further data will be written from this side quicStream.Close() lengthBuf := make([]byte, 2) _, err = io.ReadFull(quicStream, lengthBuf) if err != nil { + logger.Error("DoQ failed to read response length: %v", err) return nil, fmt.Errorf("failed reading response length: %w", err) } messageLength := binary.BigEndian.Uint16(lengthBuf) if messageLength == 0 { + logger.Error("DoQ received zero-length message") return nil, fmt.Errorf("received zero-length message") } responseBuf := make([]byte, messageLength) _, err = io.ReadFull(quicStream, responseBuf) if err != nil { + logger.Error("DoQ failed to read response data: %v", err) return nil, fmt.Errorf("failed reading response data: %w", err) } recvMsg := new(dns.Msg) err = recvMsg.Unpack(responseBuf) if err != nil { + logger.Error("DoQ failed to parse response: %v", err) return nil, fmt.Errorf("failed to parse DNS response: %w", err) } + if len(recvMsg.Answer) > 0 { + logger.Debug("DoQ response from %s: %d answers", c.targetAddr, len(recvMsg.Answer)) + } + return recvMsg, nil } diff --git a/common/protocols/dot/dot.go b/common/protocols/dot/dot.go index 35b2fe9..b236ab2 100644 --- a/common/protocols/dot/dot.go +++ b/common/protocols/dot/dot.go @@ -8,6 +8,7 @@ import ( "net" "time" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -27,7 +28,10 @@ type Client struct { } func New(config Config) (*Client, error) { + logger.Debug("Creating DoT client: %s:%s", config.Host, config.Port) + if config.Host == "" { + logger.Error("DoT client creation failed: empty host") return nil, fmt.Errorf("dot: Host cannot be empty") } if config.WriteTimeout <= 0 { @@ -43,6 +47,8 @@ func New(config Config) (*Client, error) { ServerName: config.Host, } + logger.Debug("DoT client created: %s (DNSSEC: %v)", hostAndPort, config.DNSSEC) + return &Client{ hostAndPort: hostAndPort, tlsConfig: tlsConfig, @@ -51,6 +57,7 @@ func New(config Config) (*Client, error) { } func (c *Client) Close() { + logger.Debug("Closing DoT client") } func (c *Client) createConnection() (*tls.Conn, error) { @@ -58,10 +65,23 @@ func (c *Client) createConnection() (*tls.Conn, error) { Timeout: c.config.WriteTimeout, } - return tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig) + logger.Debug("Establishing DoT connection to %s", c.hostAndPort) + conn, err := tls.DialWithDialer(dialer, "tcp", c.hostAndPort, c.tlsConfig) + if err != nil { + logger.Error("DoT connection failed to %s: %v", c.hostAndPort, err) + return nil, err + } + + logger.Debug("DoT connection established to %s", c.hostAndPort) + return conn, nil } func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { + if len(msg.Question) > 0 { + question := msg.Question[0] + logger.Debug("DoT query: %s %s to %s", question.Name, dns.TypeToString[question.Qtype], c.hostAndPort) + } + // Create connection for this query conn, err := c.createConnection() if err != nil { @@ -75,6 +95,7 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { } packed, err := msg.Pack() if err != nil { + logger.Error("DoT failed to pack message: %v", err) return nil, fmt.Errorf("dot: failed to pack message: %w", err) } @@ -85,40 +106,51 @@ func (c *Client) Query(msg *dns.Msg) (*dns.Msg, error) { // Write query if err := conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)); err != nil { + logger.Error("DoT failed to set write deadline: %v", err) return nil, fmt.Errorf("dot: failed to set write deadline: %w", err) } if _, err := conn.Write(data); err != nil { + logger.Error("DoT failed to write message to %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("dot: failed to write message: %w", err) } // Read response if err := conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)); err != nil { + logger.Error("DoT failed to set read deadline: %v", err) return nil, fmt.Errorf("dot: failed to set read deadline: %w", err) } // Read message length lengthBuf := make([]byte, 2) if _, err := io.ReadFull(conn, lengthBuf); err != nil { + logger.Error("DoT failed to read response length from %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("dot: failed to read response length: %w", err) } msgLen := binary.BigEndian.Uint16(lengthBuf) if msgLen > dns.MaxMsgSize { + logger.Error("DoT response too large from %s: %d bytes", c.hostAndPort, msgLen) return nil, fmt.Errorf("dot: response message too large: %d", msgLen) } // Read message body buffer := make([]byte, msgLen) if _, err := io.ReadFull(conn, buffer); err != nil { + logger.Error("DoT failed to read response from %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("dot: failed to read response: %w", err) } // Parse response response := new(dns.Msg) if err := response.Unpack(buffer); err != nil { + logger.Error("DoT failed to unpack response from %s: %v", c.hostAndPort, err) return nil, fmt.Errorf("dot: failed to unpack response: %w", err) } + if len(response.Answer) > 0 { + logger.Debug("DoT response from %s: %d answers", c.hostAndPort, len(response.Answer)) + } + return response, nil } diff --git a/internal/client/client.go b/internal/client/client.go deleted file mode 100644 index 0a81128..0000000 --- a/internal/client/client.go +++ /dev/null @@ -1,195 +0,0 @@ -// internal/client/client.go -package client - -import ( - "fmt" - "io" - "net" - "net/url" - "strconv" - "strings" - "time" - - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" - "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" - // "github.com/afonsofrancof/sdns-perf/internal/protocols/doq" - // "github.com/afonsofrancof/sdns-perf/internal/protocols/dot" - - "github.com/miekg/dns" -) - -// DNSClient defines the interface that all specific protocol clients must implement. -type DNSClient interface { - Query(domain string, queryType uint16) (*dns.Msg, error) - Close() -} - -// Options holds common configuration options for creating any DNS client. -type Options struct { - Timeout time.Duration - DNSSEC bool - KeyLogPath string // Path for TLS key logging -} - -type protocolType int - -const ( - protoUnknown protocolType = iota - protoDo53 - protoDoT - protoDoH - protoDoH3 - protoDoQ -) - -// config holds the parsed details of an upstream server string. -// This is internal to the client package. -type config struct { - original string - protocol protocolType - host string - port string - path string -} - -// parseUpstream takes a user-provided upstream string and attempts to determine -// the protocol, host, port, and path. (Internal helper) -func parseUpstream(upstreamStr string) (config, error) { - cfg := config{original: upstreamStr, protocol: protoUnknown} - - // Try parsing as a full URL first - parsedURL, err := url.Parse(upstreamStr) - if err == nil && parsedURL.Scheme != "" && parsedURL.Host != "" { - cfg.host = parsedURL.Hostname() - cfg.port = parsedURL.Port() - cfg.path = parsedURL.Path - if cfg.path == "" { - cfg.path = "/" // Default path - } - - switch strings.ToLower(parsedURL.Scheme) { - case "https", "doh": - cfg.protocol = protoDoH - if cfg.port == "" { - cfg.port = "443" - } - case "h3", "doh3": - cfg.protocol = protoDoH3 - if cfg.port == "" { - cfg.port = "443" - } - case "tls", "dot": - cfg.protocol = protoDoT - if cfg.port == "" { - cfg.port = "853" - } - case "quic", "doq": - cfg.protocol = protoDoQ - if cfg.port == "" { - cfg.port = "853" - } - case "udp", "do53": - cfg.protocol = protoDo53 - if cfg.port == "" { - cfg.port = "53" - } - default: - return cfg, fmt.Errorf("unsupported URL scheme: %q", parsedURL.Scheme) - } - return cfg, nil - } - - // If not a valid URL or no scheme, assume plain DNS (Do53 UDP) - cfg.protocol = protoDo53 - host, port, err := net.SplitHostPort(upstreamStr) - if err == nil { - cfg.host = host - cfg.port = port - if _, pErr := strconv.Atoi(port); pErr != nil { - return cfg, fmt.Errorf("invalid port %q in upstream %q: %w", port, upstreamStr, pErr) - } - } else { - cfg.host = upstreamStr - cfg.port = "53" - // Basic check for likely IPv6 without brackets and port - if strings.Contains(cfg.host, ":") && !strings.Contains(cfg.host, "[") { - _, resolveErr := net.ResolveUDPAddr("udp", net.JoinHostPort(cfg.host, cfg.port)) - if resolveErr != nil { - return cfg, fmt.Errorf("invalid upstream format; could not parse %q as host:port or resolve as host with default port 53: %w", upstreamStr, err) - } - } - } - - if cfg.host == "" { - return cfg, fmt.Errorf("could not extract host from upstream: %q", upstreamStr) - } - - return cfg, nil -} - -// New creates the appropriate DNS client based on the upstream string format. -// It returns an uninitialized client (connections are lazy). -func New(upstreamStr string, opts Options) (DNSClient, error) { - cfg, err := parseUpstream(upstreamStr) - if err != nil { - return nil, fmt.Errorf("client: failed to parse upstream %q: %w", upstreamStr, err) - } - - var client DNSClient - var clientErr error - - switch cfg.protocol { - case protoDo53: - // Ensure do53.New matches this signature - config := do53.Config{HostAndPort: net.JoinHostPort(cfg.host, cfg.port), DNSSEC: false} - client, clientErr = do53.New(config) - - case protoDoH: - // Ensure doh.New matches this signature - config := doh.Config{Host: cfg.host, Port: cfg.port, Path: cfg.path, DNSSEC: false} - client, clientErr = doh.New(config) - - case protoDoT: - // Ensure dot.New matches this signature - // client, clientErr = dot.New(cfg.hostPort(), opts.Timeout, opts.DNSSEC, opts.KeyLogPath) - // if clientErr == nil && client == nil { - // clientErr = fmt.Errorf("client: DoT package returned nil client without error") - // } - - case protoDoQ: - // Ensure doq.New matches this signature - // client, clientErr = doq.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath) - // if clientErr == nil && client == nil { - // clientErr = fmt.Errorf("client: DoQ package returned nil client without error") - // } - - case protoDoH3: - // Decide on DoH3 handling (fallback or error) - // Fallback example: - // fmt.Fprintf(os.Stderr, "Warning: DoH3 protocol (h3://) detected for %s. Attempting connection using standard DoH (HTTPS).\n", cfg.original) - // client, clientErr = doh.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath) - // Error example: - // clientErr = fmt.Errorf("client: DoH3 protocol (h3://) is not yet supported") - - default: - clientErr = fmt.Errorf("client: unknown or unsupported protocol detected for upstream: %s", upstreamStr) - } - - if clientErr != nil { - return nil, fmt.Errorf("client: failed to create client for %s: %w", upstreamStr, clientErr) - } - if client == nil { - // Should be caught by clientErr checks above, but as a safeguard - return nil, fmt.Errorf("client: internal error - nil client returned for %s", upstreamStr) - } - - return client, nil -} - -// Helper function to close key log writer if needed (can be used by specific clients) -func CloseKeyLogWriter(w io.WriteCloser) error { - if w != nil { - return w.Close() - } - return nil -} diff --git a/internal/protocols/dnscrypt/dnscrypt.go b/internal/protocols/dnscrypt/dnscrypt.go deleted file mode 100644 index 9939b27..0000000 --- a/internal/protocols/dnscrypt/dnscrypt.go +++ /dev/null @@ -1,3 +0,0 @@ -package dnscrypt - -// DNSCrypt resolver implementation diff --git a/internal/protocols/dnssec/dnssec.go b/internal/protocols/dnssec/dnssec.go deleted file mode 100644 index fb852f6..0000000 --- a/internal/protocols/dnssec/dnssec.go +++ /dev/null @@ -1,3 +0,0 @@ -package dnssec - -// DNSSEC resolver implementation diff --git a/internal/protocols/do53/do53.go b/internal/protocols/do53/do53.go deleted file mode 100644 index ee659f9..0000000 --- a/internal/protocols/do53/do53.go +++ /dev/null @@ -1,129 +0,0 @@ -package do53 - -import ( - "fmt" - "log" - "net" - "sync" - - "github.com/miekg/dns" -) - -type Config struct { - HostAndPort string - DNSSEC bool -} - -type Client struct { - udpAddr *net.UDPAddr - conn *net.UDPConn - - responseChannels map[uint16]chan *dns.Msg - responseMutex *sync.Mutex - - config Config -} - -func New(config Config) (*Client, error) { - udpAddr, err := net.ResolveUDPAddr("udp", config.HostAndPort) - if err != nil { - return nil, fmt.Errorf("do53: failed to resolve UDP address %q: %w", config.HostAndPort, err) - } - - conn, err := net.DialUDP("udp", nil, udpAddr) - if err != nil { - return nil, fmt.Errorf("do53: failed to dial UDP connection to %s: %w", config.HostAndPort, err) - } - - responseChannels := map[uint16]chan *dns.Msg{} - rcMutex := new(sync.Mutex) - - client := &Client{ - udpAddr: udpAddr, - conn: conn, - responseChannels: responseChannels, - responseMutex: rcMutex, - config: config, - } - - go client.receiveLoop() - - return client, nil -} - -func (c *Client) Close() { - if c.conn != nil { - c.conn.Close() - c.conn = nil - } -} - -func (c *Client) receiveLoop() { - - buffer := make([]byte, dns.MaxMsgSize) - - for { - // Reads one UDP Datagram - n, err := c.conn.Read(buffer) - if err != nil { - log.Printf("do53: failed to read DNS response: %s", err.Error()) - } - - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(buffer[:n]) - if err != nil { - log.Printf("do53: failed to unpack DNS response: %s", err.Error()) - continue - } - - c.responseMutex.Lock() - respChan, ok := c.responseChannels[recvMsg.Id] - delete(c.responseChannels, recvMsg.Id) - c.responseMutex.Unlock() - - if ok { - respChan <- recvMsg - } else { - log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id) - } - } - -} - -func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { - - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain), queryType) - msg.Id = dns.Id() - msg.RecursionDesired = true - - if c.config.DNSSEC { - msg.SetEdns0(4096, true) - } - - respChan := make(chan *dns.Msg) - - c.responseMutex.Lock() - c.responseChannels[msg.Id] = respChan - c.responseMutex.Unlock() - - packedMsg, err := msg.Pack() - if err != nil { - c.responseMutex.Lock() - delete(c.responseChannels, msg.Id) - c.responseMutex.Unlock() - return nil, fmt.Errorf("do53: failed to pack DNS message: %w", err) - } - - _, err = c.conn.Write(packedMsg) - if err != nil { - c.responseMutex.Lock() - delete(c.responseChannels, msg.Id) - c.responseMutex.Unlock() - return nil, fmt.Errorf("do53: failed to send DNS query: %w", err) - } - - recvMsg := <-respChan - - return recvMsg, nil -} diff --git a/internal/protocols/do53/packet.go b/internal/protocols/do53/packet.go deleted file mode 100644 index 29f6a55..0000000 --- a/internal/protocols/do53/packet.go +++ /dev/null @@ -1,40 +0,0 @@ -package do53 - -import ( - "github.com/miekg/dns" -) - -func NewDNSMessage(domain string, queryType string) ([]byte, error) { - - // TODO: Move this somewhere else and receive the type already parsed - var queryTypeValue uint16 - switch queryType { - case "A": - queryTypeValue = dns.TypeA - case "AAAA": - queryTypeValue = dns.TypeAAAA - case "MX": - queryTypeValue = dns.TypeMX - case "CNAME": - queryTypeValue = dns.TypeCNAME - case "TXT": - queryTypeValue = dns.TypeTXT - default: - queryTypeValue = dns.TypeA - } - - message := new(dns.Msg) - - message.Id = dns.Id() - message.Response = false - message.Opcode = dns.OpcodeQuery - message.Question = make([]dns.Question, 1) - message.Question[0] = dns.Question{Name: domain, Qtype: uint16(queryTypeValue), Qclass: dns.ClassINET} - message.Compress = true - wireMsg, err := message.Pack() - if err != nil { - return nil, err - } - - return wireMsg, nil -} diff --git a/internal/protocols/doh/doh.go b/internal/protocols/doh/doh.go deleted file mode 100644 index fa17cdf..0000000 --- a/internal/protocols/doh/doh.go +++ /dev/null @@ -1,125 +0,0 @@ -package doh - -import ( - "bytes" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/url" - "strings" - - "github.com/miekg/dns" -) - -const dnsMessageContentType = "application/dns-message" - -type Config struct { - Host string - Port string - Path string - DNSSEC bool -} - -type Client struct { - httpClient *http.Client - upstreamURL *url.URL - config Config -} - -func New(config Config) (*Client, error) { - if config.Host == "" || config.Port == "" || config.Path == "" { - fmt.Printf("%v,%v,%v", config.Host,config.Port,config.Path) - return nil, errors.New("doh: host, port, and path must not be empty") - } - - if !strings.HasPrefix(config.Path, "/") { - config.Path = "/" + config.Path - } - rawURL := "https://" + net.JoinHostPort(config.Host, config.Port) + config.Path - - parsedURL, err := url.Parse(rawURL) - if err != nil { - return nil, fmt.Errorf("doh: failed to parse constructed URL %q: %w", rawURL, err) - } - - tlsConfig := &tls.Config{ - ServerName: config.Host, - MinVersion: tls.VersionTLS12, - } - - transport := &http.Transport{ - TLSClientConfig: tlsConfig, - ForceAttemptHTTP2: true, - } - - httpClient := &http.Client{ - Transport: transport, - } - - return &Client{ - httpClient: httpClient, - upstreamURL: parsedURL, - config: config, - }, nil -} - -// Close cleans up idle connections held by the underlying HTTP transport. -func (c *Client) Close() { - c.httpClient.CloseIdleConnections() -} - -func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { - msg := new(dns.Msg) - msg.SetQuestion(dns.Fqdn(domain), queryType) - msg.Id = dns.Id() - msg.RecursionDesired = true - - if c.config.DNSSEC { - msg.SetEdns0(4096, true) - } - - packedMsg, err := msg.Pack() - if err != nil { - return nil, fmt.Errorf("doh: failed to pack DNS message: %w", err) - } - - httpReq, err := http.NewRequest("POST", c.upstreamURL.String(), bytes.NewReader(packedMsg)) - if err != nil { - return nil, fmt.Errorf("doh: failed to create HTTP request object: %w", err) - } - - httpReq.Header.Set("User-Agent", "sdns-perf") - httpReq.Header.Set("Content-Type", dnsMessageContentType) - httpReq.Header.Set("Accept", dnsMessageContentType) - - httpResp, err := c.httpClient.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("doh: failed executing HTTP request to %s: %w", c.upstreamURL.Host, err) - } - defer httpResp.Body.Close() - - if httpResp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("doh: received non-200 HTTP status from %s: %s", c.upstreamURL.Host, httpResp.Status) - } - - if ct := httpResp.Header.Get("Content-Type"); ct != dnsMessageContentType { - return nil, fmt.Errorf("doh: unexpected Content-Type from %s: got %q, want %q", c.upstreamURL.Host, ct, dnsMessageContentType) - } - - responseBody, err := io.ReadAll(httpResp.Body) - if err != nil { - return nil, fmt.Errorf("doh: failed reading response body from %s: %w", c.upstreamURL.Host, err) - } - - // Unpack the DNS message - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBody) - if err != nil { - return nil, fmt.Errorf("doh: failed to unpack DNS response from %s: %w", c.upstreamURL.Host, err) - } - - return recvMsg, nil -} diff --git a/internal/protocols/doq/doq.go b/internal/protocols/doq/doq.go deleted file mode 100644 index d99b95b..0000000 --- a/internal/protocols/doq/doq.go +++ /dev/null @@ -1,175 +0,0 @@ -package doq - -import ( - "bytes" - "context" - "crypto/tls" - "encoding/binary" - "fmt" - "io" - "net" - "os" - "time" - - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" - "github.com/miekg/dns" - "github.com/quic-go/quic-go" -) - -type Client struct { - targetAddr *net.UDPAddr - keyLogFile *os.File - tlsConfig *tls.Config - udpConn *net.UDPConn - quicConn quic.Connection - quicTransport *quic.Transport - quicConfig *quic.Config -} - -func New(target string) (*Client, error) { - keyLogFile, err := os.OpenFile( - "tls-key-log.txt", - os.O_APPEND|os.O_CREATE|os.O_WRONLY, - 0600, - ) - if err != nil { - return nil, fmt.Errorf("failed opening key log file: %w", err) - } - - tlsConfig := &tls.Config{ - // FIX: Actually check the domain name - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS13, - ClientSessionCache: tls.NewLRUClientSessionCache(100), - KeyLogWriter: keyLogFile, - NextProtos: []string{"doq"}, - } - - udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:6000") - if err != nil { - return nil, fmt.Errorf("failed to resolve target address: %w", err) - } - targetAddr, err := net.ResolveUDPAddr("udp", target) - if err != nil { - return nil, err - } - udpConn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, fmt.Errorf("failed to connect to target address: %w", err) - } - - quicTransport := quic.Transport{ - Conn: udpConn, - } - - quicConfig := quic.Config{ - // Use the default value of 30 seconds - MaxIdleTimeout: 30 * time.Second, - } - - return &Client{ - targetAddr: targetAddr, - keyLogFile: keyLogFile, - tlsConfig: tlsConfig, - udpConn: udpConn, - quicConn: nil, - quicTransport: &quicTransport, - quicConfig: &quicConfig, - }, nil -} - -func (c *Client) Close() { - if c.keyLogFile != nil { - c.keyLogFile.Close() - } - if c.udpConn != nil { - c.udpConn.Close() - } -} - -func (c *Client) OpenConnection() error { - quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig) - if err != nil { - return err - } - - c.quicConn = quicConn - return nil -} - -func (c *Client) Query(domain, queryType string, dnssec bool) error { - - if c.quicConn == nil { - err := c.OpenConnection() - if err != nil { - return err - } - } - - DNSMessage, err := do53.NewDNSMessage(domain, queryType) - if err != nil { - return err - } - - var quicStream quic.Stream - quicStream, err = c.quicConn.OpenStream() - if err != nil { - err = c.OpenConnection() - if err != nil { - return err - } - quicStream, err = c.quicConn.OpenStream() - if err != nil { - return err - } - } - - var lengthPrefixedMessage bytes.Buffer - err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) - if err != nil { - return fmt.Errorf("failed to write message length: %w", err) - } - _, err = lengthPrefixedMessage.Write(DNSMessage) - if err != nil { - return fmt.Errorf("failed to write DNS message: %w", err) - } - - _, err = quicStream.Write(lengthPrefixedMessage.Bytes()) - if err != nil { - return fmt.Errorf("failed writing to QUIC stream: %w", err) - } - // Indicate that no further data will be written from this side - quicStream.Close() - - lengthBuf := make([]byte, 2) - _, err = io.ReadFull(quicStream, lengthBuf) - if err != nil { - return fmt.Errorf("failed reading response length: %w", err) - } - - messageLength := binary.BigEndian.Uint16(lengthBuf) - if messageLength == 0 { - return fmt.Errorf("received zero-length message") - } - - responseBuf := make([]byte, messageLength) - _, err = io.ReadFull(quicStream, responseBuf) - if err != nil { - return fmt.Errorf("failed reading response data: %w", err) - } - - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBuf) - if err != nil { - return fmt.Errorf("failed to parse DNS response: %w", err) - } - - // TODO: Check if the response had no errors or TD bit set - - fmt.Println(c.quicConn.ConnectionState().Used0RTT) - for _, answer := range recvMsg.Answer { - fmt.Println(answer.String()) - } - - return nil -} diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go deleted file mode 100644 index d314281..0000000 --- a/internal/protocols/dot/dot.go +++ /dev/null @@ -1,161 +0,0 @@ -package dot - -import ( - "context" - "crypto/tls" - "encoding/binary" - "fmt" - "io" - "log" - "net" - "os" - "sync" - "time" - - "github.com/miekg/dns" -) - -type Config struct { - Host string - Port string - DNSSEC bool - Debug bool -} - -type Client struct { - config Config - - serverAddr *net.TCPAddr - - tcpConn *net.TCPConn - tlsConn *tls.Conn - tlsConfig *tls.Config - keyLogFile *os.File - - sendChannel chan *dns.Msg - - responseChannels map[uint16]chan *dns.Msg - responseMutex *sync.Mutex -} - -func New(config Config) (*Client, error) { - serverAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(config.Host, config.Port)) - if err != nil { - return nil, fmt.Errorf("dot: failed to resolve TCP address %q: %w", config.Host, err) - } - - var keyLogFile *os.File - if config.Debug { - keyLogFile, err = os.OpenFile( - "tls-key-log.txt", - os.O_APPEND|os.O_CREATE|os.O_WRONLY, - 0600, - ) - if err != nil { - log.Printf("dot: failed opening TLS key log file: %v", err) - keyLogFile = nil - } - } - - tlsConfig := &tls.Config{ - ServerName: serverAddr.IP.String(), - MinVersion: tls.VersionTLS12, - KeyLogWriter: keyLogFile, - ClientSessionCache: tls.NewLRUClientSessionCache(100), - } - - client := &Client{ - config: config, - serverAddr: serverAddr, - tlsConfig: tlsConfig, - keyLogFile: keyLogFile, - } - - go client.receiveLoop() - - return client, nil -} - -func (c *Client) Close() { - if c.tlsConn != nil { - c.tlsConn.Close() - c.tlsConn = nil - } - - if c.tcpConn != nil { - c.tcpConn.Close() - c.tcpConn = nil - } - - if c.keyLogFile != nil { - c.keyLogFile.Close() - c.keyLogFile = nil - } -} - -func (c *Client) receiveLoop() { - - lengthBuffer := make([]byte, 2) - buffer := make([]byte, dns.MaxMsgSize) - - for { - msgSize, err := io.ReadFull(c.tlsConn, lengthBuffer) - if err != nil { - log.Printf("doh: failed to read the DNS message's size: %s", err.Error()) - // FIX: HANDLE RECONNECTION - } - n, err := io.ReadFull(c.tlsConn, buffer[:msgSize]) - if err != nil { - log.Printf("doh: failed to read the DNS message: %s", err.Error()) - // FIX: HANDLE RECONNECTION - } - - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(buffer[:n]) - if err != nil { - log.Printf("do53: failed to unpack DNS response: %s", err.Error()) - continue - } - - c.responseMutex.Lock() - respChan, ok := c.responseChannels[recvMsg.Id] - delete(c.responseChannels, recvMsg.Id) - c.responseMutex.Unlock() - - if ok { - respChan <- recvMsg - } else { - log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id) - } - - } - -} - -func (c *Client) connect(ctx context.Context) error { - tcpConn, err := net.DialTCP("tcp", nil, c.serverAddr) - if err != nil { - return fmt.Errorf("dot: failed to establish TCP connection: %w", err) - } - - c.tcpConn.SetKeepAlive(true) - c.tcpConn.SetKeepAlivePeriod(1 * time.Minute) - - tlsConn := tls.Client(c.tcpConn, c.tlsConfig) - err = tlsConn.HandshakeContext(ctx) - if err != nil { - c.tcpConn.Close() - c.tcpConn = nil - return fmt.Errorf("dot: failed to execute the TLS handshake: %w", err) - } - - c.tlsConn = tlsConn - - log.Println("dot: TCP/TLS connection established successfully.") - - return nil -} - -func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { - //TODO -} diff --git a/main.go b/main.go index ba0a058..254004b 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,11 @@ package main import ( "fmt" - "log" - "os" "strings" "time" "github.com/afonsofrancof/sdns-proxy/client" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/afonsofrancof/sdns-proxy/server" "github.com/alecthomas/kong" @@ -15,6 +14,7 @@ import ( ) var cli struct { + Debug bool `help:"Enable debug logging globally." short:"D" env:"DEBUG"` Query QueryCmd `cmd:"" help:"Perform a DNS query (client mode)."` Listen ListenCmd `cmd:"" help:"Run as a DNS listener/resolver (server mode)."` } @@ -41,7 +41,7 @@ type ListenCmd struct { } func (q *QueryCmd) Run() error { - log.Printf("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, Timeout: %v)\n", + logger.Info("Querying %s for %s type %s (DNSSEC: %v, ValidateOnly: %v, StrictValidation: %v, Timeout: %v)", q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.ValidateOnly, q.StrictValidation, q.Timeout) opts := client.Options{ @@ -50,14 +50,17 @@ func (q *QueryCmd) Run() error { StrictValidation: q.StrictValidation, } + logger.Debug("Creating DNS client with options: %+v", opts) dnsClient, err := client.New(q.Server, opts) if err != nil { + logger.Error("Failed to create DNS client: %v", err) return err } defer dnsClient.Close() qTypeUint, ok := dns.StringToType[strings.ToUpper(q.QueryType)] if !ok { + logger.Error("Invalid query type: %s", q.QueryType) return fmt.Errorf("invalid query type: %s", q.QueryType) } @@ -67,11 +70,15 @@ func (q *QueryCmd) Run() error { msg.Id = dns.Id() msg.RecursionDesired = true + logger.Debug("Sending DNS query: ID=%d, Question=%s %s", msg.Id, q.DomainName, q.QueryType) recvMsg, err := dnsClient.Query(msg) if err != nil { + logger.Error("DNS query failed: %v", err) return err } + logger.Debug("Received DNS response: ID=%d, Rcode=%s, Answers=%d", + recvMsg.Id, dns.RcodeToString[recvMsg.Rcode], len(recvMsg.Answer)) printResponse(recvMsg.Question[0].Name, q.QueryType, recvMsg) return nil } @@ -87,15 +94,17 @@ func (l *ListenCmd) Run() error { Verbose: l.Verbose, } + logger.Debug("Server config: %+v", config) srv, err := server.New(config) if err != nil { + logger.Error("Failed to create server: %v", err) return fmt.Errorf("failed to create server: %w", err) } - log.Printf("Starting DNS proxy server on %s", l.Address) - log.Printf("Upstream server: %v", l.Upstream) - log.Printf("Fallback server: %v", l.Fallback) - log.Printf("Bootstrap server: %v", l.Bootstrap) + logger.Info("Starting DNS proxy server on %s", l.Address) + logger.Info("Upstream server: %v", l.Upstream) + logger.Info("Fallback server: %v", l.Fallback) + logger.Info("Bootstrap server: %v", l.Bootstrap) return srv.Start() } @@ -153,9 +162,6 @@ func printResponse(domain, qtype string, msg *dns.Msg) { } func main() { - log.SetOutput(os.Stderr) - log.SetFlags(log.Ltime | log.Lshortfile) - kongCtx := kong.Parse(&cli, kong.Name("sdns-proxy"), kong.Description("A DNS client/server tool supporting multiple protocols."), @@ -163,6 +169,10 @@ func main() { kong.ConfigureHelp(kong.HelpOptions{Compact: true, Summary: true}), ) + // Set global debug flag + logger.SetDebug(cli.Debug) + logger.Debug("Debug logging enabled") + err := kongCtx.Run() kongCtx.FatalIfErrorf(err) } diff --git a/server/server.go b/server/server.go index 79d25dc..5bb9581 100644 --- a/server/server.go +++ b/server/server.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "log" "net" "net/url" "os" @@ -14,6 +13,7 @@ import ( "time" "github.com/afonsofrancof/sdns-proxy/client" + "github.com/afonsofrancof/sdns-proxy/common/logger" "github.com/miekg/dns" ) @@ -50,7 +50,10 @@ type Server struct { } func New(config Config) (*Server, error) { + logger.Debug("Creating new server with config: %+v", config) + if config.Upstream == "" { + logger.Error("Upstream server is required") return nil, fmt.Errorf("upstream server is required") } @@ -60,11 +63,17 @@ func New(config Config) (*Server, error) { needsBootstrap = needsBootstrap || containsHostname(config.Fallback) } + logger.Debug("Bootstrap needed: %v (upstream has hostname: %v, fallback has hostname: %v)", + needsBootstrap, containsHostname(config.Upstream), + config.Fallback != "" && containsHostname(config.Fallback)) + if needsBootstrap && config.Bootstrap == "" { + logger.Error("Bootstrap server is required when upstream or fallback contains hostnames") return nil, fmt.Errorf("bootstrap server is required when upstream or fallback contains hostnames") } if config.Bootstrap != "" && containsHostname(config.Bootstrap) { + logger.Error("Bootstrap server cannot contain hostnames: %s", config.Bootstrap) return nil, fmt.Errorf("bootstrap server cannot contain hostnames: %s", config.Bootstrap) } @@ -76,17 +85,21 @@ func New(config Config) (*Server, error) { // Create bootstrap client if needed if config.Bootstrap != "" { + logger.Debug("Creating bootstrap client for %s", config.Bootstrap) bootstrapClient, err := client.New(config.Bootstrap, client.Options{ - DNSSEC: false, // For now + DNSSEC: false, }) if err != nil { + logger.Error("Failed to create bootstrap client: %v", err) return nil, fmt.Errorf("failed to create bootstrap client: %w", err) } s.bootstrapClient = bootstrapClient + logger.Debug("Bootstrap client created successfully") } // Initialize upstream and fallback clients if err := s.initClients(); err != nil { + logger.Error("Failed to initialize clients: %v", err) return nil, fmt.Errorf("failed to initialize clients: %w", err) } @@ -100,86 +113,111 @@ func New(config Config) (*Server, error) { Handler: mux, } + logger.Debug("Server created successfully, listening on %s", config.Address) return s, nil } func containsHostname(serverAddr string) bool { + logger.Debug("Checking if %s contains hostname", serverAddr) + // Use the same parsing logic as the client package parsedURL, err := url.Parse(serverAddr) if err != nil { + logger.Debug("URL parsing failed for %s, treating as plain address", serverAddr) // If URL parsing fails, assume it's a plain address host, _, err := net.SplitHostPort(serverAddr) if err != nil { // Assume it's just a host - return net.ParseIP(serverAddr) == nil + isHostname := net.ParseIP(serverAddr) == nil + logger.Debug("Address %s is hostname: %v", serverAddr, isHostname) + return isHostname } - return net.ParseIP(host) == nil + isHostname := net.ParseIP(host) == nil + logger.Debug("Host %s from %s is hostname: %v", host, serverAddr, isHostname) + return isHostname } host := parsedURL.Hostname() if host == "" { + logger.Debug("No hostname found in URL %s", serverAddr) return false } - return net.ParseIP(host) == nil + isHostname := net.ParseIP(host) == nil + logger.Debug("Host %s from URL %s is hostname: %v", host, serverAddr, isHostname) + return isHostname } func (s *Server) initClients() error { + logger.Debug("Initializing DNS clients") + // Initialize upstream client resolvedUpstream, err := s.resolveServerAddress(s.config.Upstream) if err != nil { + logger.Error("Failed to resolve upstream %s: %v", s.config.Upstream, err) return fmt.Errorf("failed to resolve upstream %s: %w", s.config.Upstream, err) } + logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream) upstreamClient, err := client.New(resolvedUpstream, client.Options{ DNSSEC: s.config.DNSSEC, }) if err != nil { + logger.Error("Failed to create upstream client: %v", err) return fmt.Errorf("failed to create upstream client: %w", err) } s.upstreamClient = upstreamClient if s.config.Verbose { - log.Printf("Initialized upstream client: %s -> %s", s.config.Upstream, resolvedUpstream) + logger.Info("Initialized upstream client: %s -> %s", s.config.Upstream, resolvedUpstream) } // Initialize fallback client if specified if s.config.Fallback != "" { resolvedFallback, err := s.resolveServerAddress(s.config.Fallback) if err != nil { + logger.Error("Failed to resolve fallback %s: %v", s.config.Fallback, err) return fmt.Errorf("failed to resolve fallback %s: %w", s.config.Fallback, err) } + logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback) fallbackClient, err := client.New(resolvedFallback, client.Options{ DNSSEC: s.config.DNSSEC, }) if err != nil { + logger.Error("Failed to create fallback client: %v", err) return fmt.Errorf("failed to create fallback client: %w", err) } s.fallbackClient = fallbackClient if s.config.Verbose { - log.Printf("Initialized fallback client: %s -> %s", s.config.Fallback, resolvedFallback) + logger.Info("Initialized fallback client: %s -> %s", s.config.Fallback, resolvedFallback) } } + logger.Debug("All DNS clients initialized successfully") return nil } func (s *Server) resolveServerAddress(serverAddr string) (string, error) { + logger.Debug("Resolving server address: %s", serverAddr) + // If it doesn't contain hostnames, return as-is if !containsHostname(serverAddr) { + logger.Debug("Address %s contains no hostnames, returning as-is", serverAddr) return serverAddr, nil } // If no bootstrap client, we can't resolve hostnames if s.bootstrapClient == nil { + logger.Error("Cannot resolve hostname in %s: no bootstrap server configured", serverAddr) return "", fmt.Errorf("cannot resolve hostname in %s: no bootstrap server configured", serverAddr) } // Use the same parsing logic as the client package parsedURL, err := url.Parse(serverAddr) if err != nil { + logger.Debug("Parsing %s as plain host:port format", serverAddr) // Handle plain host:port format host, port, err := net.SplitHostPort(serverAddr) if err != nil { @@ -188,6 +226,7 @@ func (s *Server) resolveServerAddress(serverAddr string) (string, error) { if err != nil { return "", err } + logger.Debug("Resolved %s to %s", serverAddr, resolvedIP) return resolvedIP, nil } @@ -195,12 +234,15 @@ func (s *Server) resolveServerAddress(serverAddr string) (string, error) { if err != nil { return "", err } - return net.JoinHostPort(resolvedIP, port), nil + resolved := net.JoinHostPort(resolvedIP, port) + logger.Debug("Resolved %s to %s", serverAddr, resolved) + return resolved, nil } // Handle URL format hostname := parsedURL.Hostname() if hostname == "" { + logger.Error("No hostname in URL: %s", serverAddr) return "", fmt.Errorf("no hostname in URL: %s", serverAddr) } @@ -217,21 +259,26 @@ func (s *Server) resolveServerAddress(serverAddr string) (string, error) { parsedURL.Host = net.JoinHostPort(resolvedIP, port) } - return parsedURL.String(), nil + resolved := parsedURL.String() + logger.Debug("Resolved URL %s to %s", serverAddr, resolved) + return resolved, nil } func (s *Server) resolveHostname(hostname string) (string, error) { + logger.Debug("Resolving hostname: %s", hostname) + // Check cache first s.hostsMutex.RLock() if ip, exists := s.resolvedHosts[hostname]; exists { s.hostsMutex.RUnlock() + logger.Debug("Found cached resolution for %s: %s", hostname, ip) return ip, nil } s.hostsMutex.RUnlock() // Resolve using bootstrap if s.config.Verbose { - log.Printf("Resolving hostname %s using bootstrap server", hostname) + logger.Info("Resolving hostname %s using bootstrap server", hostname) } msg := new(dns.Msg) @@ -239,12 +286,16 @@ func (s *Server) resolveHostname(hostname string) (string, error) { msg.Id = dns.Id() msg.RecursionDesired = true + logger.Debug("Sending bootstrap query for %s (ID: %d)", hostname, msg.Id) msg, err := s.bootstrapClient.Query(msg) if err != nil { + logger.Error("Bootstrap query failed for %s: %v", hostname, err) return "", fmt.Errorf("failed to resolve %s via bootstrap: %w", hostname, err) } + logger.Debug("Bootstrap response for %s: %d answers", hostname, len(msg.Answer)) if len(msg.Answer) == 0 { + logger.Error("No A records found for %s", hostname) return "", fmt.Errorf("no A records found for %s", hostname) } @@ -259,18 +310,21 @@ func (s *Server) resolveHostname(hostname string) (string, error) { s.hostsMutex.Unlock() if s.config.Verbose { - log.Printf("Resolved %s to %s", hostname, ip) + logger.Info("Resolved %s to %s", hostname, ip) } + logger.Debug("Cached resolution: %s -> %s", hostname, ip) return ip, nil } } + logger.Error("No valid A record found for %s", hostname) return "", fmt.Errorf("no valid A record found for %s", hostname) } func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { + logger.Debug("Received request with no questions from %s", w.RemoteAddr()) dns.HandleFailed(w, r) return } @@ -279,8 +333,11 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { domain := strings.ToLower(question.Name) qtype := question.Qtype + logger.Debug("Handling DNS request: %s %s from %s (ID: %d)", + question.Name, dns.TypeToString[qtype], w.RemoteAddr(), r.Id) + if s.config.Verbose { - log.Printf("Query: %s %s from %s", + logger.Info("Query: %s %s from %s", question.Name, dns.TypeToString[qtype], w.RemoteAddr()) @@ -290,40 +347,48 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { if cachedRecords := s.getCachedRecords(domain, qtype); cachedRecords != nil { response := s.buildResponse(r, cachedRecords) if s.config.Verbose { - log.Printf("Cache hit: %s %s -> %d records", + logger.Info("Cache hit: %s %s -> %d records", question.Name, dns.TypeToString[qtype], len(cachedRecords)) } + logger.Debug("Serving cached response for %s %s (%d records)", + question.Name, dns.TypeToString[qtype], len(cachedRecords)) w.WriteMsg(response) return } + logger.Debug("Cache miss for %s %s, querying upstream", question.Name, dns.TypeToString[qtype]) + // Try upstream first response, err := s.queryUpstream(s.upstreamClient, question.Name, qtype) if err != nil { if s.config.Verbose { - log.Printf("Upstream query failed: %v", err) + logger.Info("Upstream query failed: %v", err) } + logger.Debug("Upstream query failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err) // Try fallback if available if s.fallbackClient != nil { if s.config.Verbose { - log.Printf("Trying fallback server") + logger.Info("Trying fallback server") } + logger.Debug("Attempting fallback query for %s %s", question.Name, dns.TypeToString[qtype]) response, err = s.queryUpstream(s.fallbackClient, question.Name, qtype) if err != nil { - log.Printf("Both upstream and fallback failed for %s %s: %v", + logger.Error("Both upstream and fallback failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err) + } else { + logger.Debug("Fallback query succeeded for %s %s", question.Name, dns.TypeToString[qtype]) } } // If still failed, return SERVFAIL if err != nil { - log.Printf("All servers failed for %s %s: %v", + logger.Error("All servers failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err) @@ -334,6 +399,8 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) return } + } else { + logger.Debug("Upstream query succeeded for %s %s", question.Name, dns.TypeToString[qtype]) } // Cache successful response @@ -343,12 +410,14 @@ func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { response.Id = r.Id if s.config.Verbose { - log.Printf("Response: %s %s -> %d answers", + logger.Info("Response: %s %s -> %d answers", question.Name, dns.TypeToString[qtype], len(response.Answer)) } + logger.Debug("Sending response for %s %s: %d answers, rcode: %s", + question.Name, dns.TypeToString[qtype], len(response.Answer), dns.RcodeToString[response.Rcode]) w.WriteMsg(response) } @@ -360,17 +429,22 @@ func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR { s.cacheMutex.RUnlock() if !exists { + logger.Debug("No cache entry for %s %s", domain, dns.TypeToString[qtype]) return nil } // Check if expired and clean up on the spot if time.Now().After(entry.expiresAt) { + logger.Debug("Cache entry expired for %s %s", domain, dns.TypeToString[qtype]) s.cacheMutex.Lock() delete(s.queryCache, key) s.cacheMutex.Unlock() return nil } + logger.Debug("Cache hit for %s %s (%d records, expires in %v)", + domain, dns.TypeToString[qtype], len(entry.records), time.Until(entry.expiresAt)) + // Return a copy of the cached records records := make([]dns.RR, len(entry.records)) for i, rr := range entry.records { @@ -382,14 +456,14 @@ func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR { func (s *Server) buildResponse(request *dns.Msg, records []dns.RR) *dns.Msg { response := new(dns.Msg) response.SetReply(request) - response.Answer = records - + logger.Debug("Built response with %d records", len(records)) return response } func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) { if msg == nil || len(msg.Answer) == 0 { + logger.Debug("Not caching empty response for %s %s", domain, dns.TypeToString[qtype]) return } @@ -408,11 +482,13 @@ func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) { } if len(validRecords) == 0 { + logger.Debug("No valid records to cache for %s %s", domain, dns.TypeToString[qtype]) return } // Don't cache responses with very low TTL if minTTL < 10 { + logger.Debug("TTL too low (%ds) for caching %s %s", minTTL, domain, dns.TypeToString[qtype]) return } @@ -427,12 +503,16 @@ func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) { s.cacheMutex.Unlock() if s.config.Verbose { - log.Printf("Cached %d records for %s %s (TTL: %ds)", + logger.Info("Cached %d records for %s %s (TTL: %ds)", len(validRecords), domain, dns.TypeToString[qtype], minTTL) } + logger.Debug("Cached %d records for %s %s (TTL: %ds, expires: %v)", + len(validRecords), domain, dns.TypeToString[qtype], minTTL, entry.expiresAt) } func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, qtype uint16) (*dns.Msg, error) { + logger.Debug("Querying upstream for %s %s", domain, dns.TypeToString[qtype]) + // Create context with timeout ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout) defer cancel() @@ -450,7 +530,15 @@ func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, q msg.SetQuestion(dns.Fqdn(domain), qtype) msg.Id = dns.Id() msg.RecursionDesired = true + + logger.Debug("Sending upstream query: %s %s (ID: %d)", domain, dns.TypeToString[qtype], msg.Id) recvMsg, err := upstreamClient.Query(msg) + if err != nil { + logger.Debug("Upstream query error for %s %s: %v", domain, dns.TypeToString[qtype], err) + } else { + logger.Debug("Upstream query response for %s %s: %d answers, rcode: %s", + domain, dns.TypeToString[qtype], len(recvMsg.Answer), dns.RcodeToString[recvMsg.Rcode]) + } resultChan <- result{msg: recvMsg, err: err} }() @@ -458,6 +546,7 @@ func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, q case res := <-resultChan: return res.msg, res.err case <-ctx.Done(): + logger.Debug("Upstream query timeout for %s %s after %v", domain, dns.TypeToString[qtype], s.config.Timeout) return nil, fmt.Errorf("upstream query timeout") } } @@ -466,29 +555,38 @@ func (s *Server) Start() error { go func() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - <-sigChan - log.Println("Shutting down DNS server...") + sig := <-sigChan + logger.Info("Received signal %v, shutting down DNS server...", sig) s.Shutdown() }() - log.Printf("DNS proxy server listening on %s", s.config.Address) + logger.Info("DNS proxy server listening on %s", s.config.Address) + logger.Debug("Server starting with timeout: %v, DNSSEC: %v", s.config.Timeout, s.config.DNSSEC) return s.dnsServer.ListenAndServe() } func (s *Server) Shutdown() { + logger.Debug("Shutting down server components") + if s.dnsServer != nil { + logger.Debug("Shutting down DNS server") s.dnsServer.Shutdown() } if s.upstreamClient != nil { + logger.Debug("Closing upstream client") s.upstreamClient.Close() } if s.fallbackClient != nil { + logger.Debug("Closing fallback client") s.fallbackClient.Close() } if s.bootstrapClient != nil { + logger.Debug("Closing bootstrap client") s.bootstrapClient.Close() } + + logger.Info("Server shutdown complete") }