From 7f1b11883210ee5e200be1c660d179d379448550 Mon Sep 17 00:00:00 2001 From: afranco Date: Sat, 27 Sep 2025 22:28:22 +0100 Subject: [PATCH] fix(paths): change output paths --- analyze_dns_metrics.py | 206 ++++++++++++++ internal/qol/measurement.go | 11 +- internal/qol/utils.go | 109 +++++++- server/server.go | 540 +++++++++++++++++++++++++++++++++++- 4 files changed, 851 insertions(+), 15 deletions(-) create mode 100644 analyze_dns_metrics.py diff --git a/analyze_dns_metrics.py b/analyze_dns_metrics.py new file mode 100644 index 0000000..7079945 --- /dev/null +++ b/analyze_dns_metrics.py @@ -0,0 +1,206 @@ +import csv +import os +import statistics +from collections import defaultdict +from pathlib import Path + +def map_server_to_resolver(server): + """Map server address/domain to resolver name""" + server_lower = server.lower() + + if '1.1.1.1' in server_lower or 'cloudflare' in server_lower: + return 'Cloudflare' + elif '8.8.8.8' in server_lower or 'dns.google' in server_lower: + return 'Google' + elif '9.9.9.9' in server_lower or 'quad9' in server_lower: + return 'Quad9' + elif 'adguard' in server_lower: + return 'AdGuard' + else: + return server # Fallback to original server name + +def extract_from_new_format(filename): + """Parse new filename format: protocol[-flags]-timestamp.csv""" + base = filename.replace('.csv', '') + parts = base.split('-') + + if len(parts) < 2: + return None, None, None + + protocol = parts[0] + timestamp = parts[-1] + + # Flags are everything between protocol and timestamp + flags_str = '-'.join(parts[1:-1]) + dnssec_status = 'on' if 'dnssec' in flags_str else 'off' + keepalive_status = 'on' if 'persist' in flags_str else 'off' + + return protocol, dnssec_status, keepalive_status + +def extract_server_info(file_path, dns_server_field): + """Extract info using directory structure and filename""" + path = Path(file_path) + + # Expect structure like: results/date/server/filename.csv + parts = path.parts + if len(parts) >= 3 and parts[-3].isdigit() and len(parts[-3]) == 10: # date folder like 2024-03-01 + server = parts[-2] # server folder + filename = parts[-1] + + protocol, dnssec_status, keepalive_status = extract_from_new_format(filename) + if protocol: + return protocol, server, dnssec_status, keepalive_status + + # Fallback to old parsing if structure doesn't match + filename = path.name + old_parts = filename.replace('.csv', '').split('_') + + if len(old_parts) >= 6: + protocol = old_parts[0] + + try: + dnssec_idx = old_parts.index('dnssec') + keepalive_idx = old_parts.index('keepalive') + + server_parts = old_parts[1:dnssec_idx] + server = '_'.join(server_parts) + + dnssec_status = old_parts[dnssec_idx + 1] if dnssec_idx + 1 < len(old_parts) else 'off' + keepalive_status = old_parts[keepalive_idx + 1] if keepalive_idx + 1 < len(old_parts) else 'off' + + return protocol, server, dnssec_status, keepalive_status + + except ValueError: + pass + + # Even older format fallback + if len(old_parts) >= 4: + protocol = old_parts[0] + dnssec_status = 'on' if 'dnssec_on' in filename else 'off' + keepalive_status = 'on' if 'keepalive_on' in filename else 'off' + server = '_'.join(old_parts[1:-4]) if len(old_parts) > 4 else old_parts[1] + + return protocol, server, dnssec_status, keepalive_status + + return None, None, None, None + +def analyze_dns_data(root_directory, output_file): + """Analyze DNS data and generate metrics""" + + # Dictionary to store measurements: {(resolver, protocol, dnssec, keepalive): [durations]} + measurements = defaultdict(list) + + # Walk through all directories + for root, dirs, files in os.walk(root_directory): + for file in files: + if file.endswith('.csv'): + file_path = os.path.join(root, file) + print(f"Processing: {file_path}") + + try: + with open(file_path, 'r', newline='') as csvfile: + reader = csv.DictReader(csvfile) + + for row_num, row in enumerate(reader, 2): # Start at 2 since header is row 1 + try: + protocol, server, dnssec_status, keepalive_status = extract_server_info( + file_path, row.get('dns_server', '')) + + if protocol and server: + resolver = map_server_to_resolver(server) + duration_ms = float(row.get('duration_ms', 0)) + + # Only include successful queries + if row.get('response_code', '') in ['NOERROR', '']: + key = (resolver, protocol, dnssec_status, keepalive_status) + measurements[key].append(duration_ms) + + except (ValueError, TypeError) as e: + print(f"Data parse error in {file_path} row {row_num}: {e}") + continue + + except Exception as e: + print(f"Error processing file {file_path}: {e}") + continue + + # Calculate statistics and group by resolver, dnssec, and keepalive + resolver_results = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + + for (resolver, protocol, dnssec, keepalive), durations in measurements.items(): + if durations: + stats = { + 'protocol': protocol.upper(), + 'total_queries': len(durations), + 'avg_latency_ms': round(statistics.mean(durations), 3), + 'median_latency_ms': round(statistics.median(durations), 3), + 'min_latency_ms': round(min(durations), 3), + 'max_latency_ms': round(max(durations), 3), + 'std_dev_ms': round(statistics.stdev(durations) if len(durations) > 1 else 0, 3), + 'p95_latency_ms': round(statistics.quantiles(durations, n=20)[18], 3) if len(durations) >= 20 else round(max(durations), 3), + 'p99_latency_ms': round(statistics.quantiles(durations, n=100)[98], 3) if len(durations) >= 100 else round(max(durations), 3) + } + resolver_results[dnssec][keepalive][resolver].append(stats) + + # Sort each resolver's results by average latency + for dnssec in resolver_results: + for keepalive in resolver_results[dnssec]: + for resolver in resolver_results[dnssec][keepalive]: + resolver_results[dnssec][keepalive][resolver].sort(key=lambda x: x['avg_latency_ms']) + + # Write to CSV with all data + all_results = [] + for dnssec in resolver_results: + for keepalive in resolver_results[dnssec]: + for resolver, results in resolver_results[dnssec][keepalive].items(): + for result in results: + result['resolver'] = resolver + result['dnssec'] = dnssec + result['keepalive'] = keepalive + all_results.append(result) + + with open(output_file, 'w', newline='') as csvfile: + fieldnames = [ + 'resolver', 'protocol', 'dnssec', 'keepalive', 'total_queries', + 'avg_latency_ms', 'median_latency_ms', 'min_latency_ms', + 'max_latency_ms', 'std_dev_ms', 'p95_latency_ms', 'p99_latency_ms' + ] + + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(all_results) + + print(f"\nAnalysis complete! Full results written to {output_file}") + print(f"Total measurements: {sum(len(durations) for durations in measurements.values())}") + + def print_resolver_table(resolver, results, dnssec_status, keepalive_status): + """Print a formatted table for a resolver""" + ka_indicator = "PERSISTENT" if keepalive_status == 'on' else "NEW CONNECTION" + print(f"\n{resolver} DNS Resolver (DNSSEC {dnssec_status.upper()}, {ka_indicator})") + print("=" * 100) + print(f"{'Protocol':<12} {'Queries':<8} {'Avg(ms)':<10} {'Median(ms)':<12} {'Min(ms)':<10} {'Max(ms)':<10} {'P95(ms)':<10}") + print("-" * 100) + + for result in results: + print(f"{result['protocol']:<12} {result['total_queries']:<8} " + f"{result['avg_latency_ms']:<10} {result['median_latency_ms']:<12} " + f"{result['min_latency_ms']:<10} {result['max_latency_ms']:<10} " + f"{result['p95_latency_ms']:<10}") + + # Print tables organized by DNSSEC and KeepAlive status + for dnssec_status in ['off', 'on']: + if dnssec_status in resolver_results: + print(f"\n{'#' * 60}") + print(f"# DNS RESOLVERS - DNSSEC {dnssec_status.upper()}") + print(f"{'#' * 60}") + + for keepalive_status in ['off', 'on']: + if keepalive_status in resolver_results[dnssec_status]: + for resolver in sorted(resolver_results[dnssec_status][keepalive_status].keys()): + results = resolver_results[dnssec_status][keepalive_status][resolver] + print_resolver_table(resolver, results, dnssec_status, keepalive_status) + +if __name__ == "__main__": + root_dir = "." + output_file = "dns_metrics.csv" + + analyze_dns_data(root_dir, output_file) diff --git a/internal/qol/measurement.go b/internal/qol/measurement.go index f87ecd5..98fb93f 100644 --- a/internal/qol/measurement.go +++ b/internal/qol/measurement.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "os" + "path/filepath" "strings" "time" @@ -77,13 +78,19 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT // Setup output files csvPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC, r.config.KeepAlive) + // Create directory if it doesn't exist + if err := os.MkdirAll(filepath.Dir(csvPath), 0755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + keepAliveStr := "" if r.config.KeepAlive { keepAliveStr = " (keep-alive)" } - fmt.Printf(">>> Measuring %s (dnssec=%v%s) → %s\n", upstream, r.config.DNSSEC, keepAliveStr, - strings.TrimSuffix(strings.TrimSuffix(csvPath, ".csv"), r.config.OutputDir+"/")) + // Show relative path for cleaner output + relPath, _ := filepath.Rel(r.config.OutputDir, csvPath) + fmt.Printf(">>> Measuring %s (dnssec=%v%s) → %s\n", upstream, r.config.DNSSEC, keepAliveStr, relPath) // Setup packet capture packetCapture, err := capture.NewPacketCapture(r.config.Interface, pcapPath) diff --git a/internal/qol/utils.go b/internal/qol/utils.go index 5d4cfd7..9a8dc74 100644 --- a/internal/qol/utils.go +++ b/internal/qol/utils.go @@ -11,19 +11,110 @@ import ( func GenerateOutputPaths(outputDir, upstream string, dnssec, keepAlive bool) (csvPath, pcapPath string) { proto := DetectProtocol(upstream) serverName := ExtractServerName(upstream) - ts := time.Now().Format("20060102_1504") - dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec] - keepAliveStr := map[bool]string{true: "on", false: "off"}[keepAlive] + cleanServer := cleanServerName(serverName) - base := fmt.Sprintf("%s_%s_dnssec_%s_keepalive_%s_%s", - proto, sanitize(serverName), dnssecStr, keepAliveStr, ts) + // Create date-based subdirectory + date := time.Now().Format("2006-01-02") + timestamp := time.Now().Format("150405") - return filepath.Join(outputDir, base+".csv"), - filepath.Join(outputDir, base+".pcap") + // Organize by date and server + subDir := filepath.Join(outputDir, date, cleanServer) + + // Create simple filename + base := proto + + // Add flags if enabled + var flags []string + if dnssec { + flags = append(flags, "dnssec") + } + if keepAlive { + flags = append(flags, "persist") + } + + if len(flags) > 0 { + base = fmt.Sprintf("%s-%s", base, strings.Join(flags, "-")) + } + + // Add timestamp + filename := fmt.Sprintf("%s-%s", base, timestamp) + + return filepath.Join(subDir, filename+".csv"), + filepath.Join(subDir, filename+".pcap") } -func sanitize(s string) string { - return strings.NewReplacer(":", "_", "/", "_", ".", "_").Replace(s) +func cleanServerName(server string) string { + // Map common servers to short names + serverMap := map[string]string{ + "1.1.1.1": "cloudflare", + "1.0.0.1": "cloudflare", + "cloudflare-dns.com": "cloudflare", + "one.one.one.one": "cloudflare", + "8.8.8.8": "google", + "8.8.4.4": "google", + "dns.google": "google", + "dns.google.com": "google", + "9.9.9.9": "quad9", + "149.112.112.112": "quad9", + "dns.quad9.net": "quad9", + "208.67.222.222": "opendns", + "208.67.220.220": "opendns", + "resolver1.opendns.com": "opendns", + "94.140.14.14": "adguard", + "94.140.15.15": "adguard", + "dns.adguard.com": "adguard", + } + + // Clean the server name first + cleaned := strings.ToLower(server) + cleaned = strings.TrimPrefix(cleaned, "https://") + cleaned = strings.TrimPrefix(cleaned, "http://") + cleaned = strings.Split(cleaned, "/")[0] // Remove path + cleaned = strings.Split(cleaned, ":")[0] // Remove port + + // Check if we have a mapping + if shortName, exists := serverMap[cleaned]; exists { + return shortName + } + + // For unknown servers, create a reasonable short name + parts := strings.Split(cleaned, ".") + if len(parts) >= 2 { + // For domains like dns.example.com, take "example" + if len(parts) >= 3 { + return parts[len(parts)-2] // Second to last part + } + // For IPs or simple domains, take first part + return parts[0] + } + + return sanitizeShort(cleaned) +} + +func sanitizeShort(s string) string { + // Keep only alphanumeric and dash + var result strings.Builder + for _, r := range s { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { + result.WriteRune(r) + } else if r == '.' || r == '_' || r == '-' { + result.WriteRune('-') + } + } + + cleaned := result.String() + // Remove consecutive dashes and trim + for strings.Contains(cleaned, "--") { + cleaned = strings.ReplaceAll(cleaned, "--", "-") + } + cleaned = strings.Trim(cleaned, "-") + + // Limit length + if len(cleaned) > 15 { + cleaned = cleaned[:15] + } + + return cleaned } func DetectProtocol(upstream string) string { diff --git a/server/server.go b/server/server.go index e80e1f4..a021e1f 100644 --- a/server/server.go +++ b/server/server.go @@ -1,15 +1,155 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/url" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "github.com/afonsofrancof/sdns-proxy/client" + "github.com/afonsofrancof/sdns-proxy/common/logger" + "github.com/miekg/dns" +) + type Config struct { Address string Upstream string Fallback string Bootstrap string DNSSEC bool - KeepAlive bool + KeepAlive bool // Added KeepAlive field Timeout time.Duration Verbose bool } -// Update the initClients method: +type cacheKey struct { + domain string + qtype uint16 +} + +type cacheEntry struct { + records []dns.RR + expiresAt time.Time +} + +type Server struct { + config Config + upstreamClient client.DNSClient + fallbackClient client.DNSClient + bootstrapClient client.DNSClient + resolvedHosts map[string]string + queryCache map[cacheKey]*cacheEntry + hostsMutex sync.RWMutex + cacheMutex sync.RWMutex + dnsServer *dns.Server +} + +func New(config Config) (*Server, error) { + logger.Debug("Creating new server with config: %+v", config) + + if config.Upstream == "" { + logger.Error("Upstream server is required") + return nil, fmt.Errorf("upstream server is required") + } + + // Check if we need bootstrap server + needsBootstrap := containsHostname(config.Upstream) + if config.Fallback != "" { + needsBootstrap = needsBootstrap || containsHostname(config.Fallback) + } + + logger.Debug("Bootstrap needed: %v (upstream has hostname: %v, fallback has hostname: %v)", + needsBootstrap, containsHostname(config.Upstream), + config.Fallback != "" && containsHostname(config.Fallback)) + + if needsBootstrap && config.Bootstrap == "" { + logger.Error("Bootstrap server is required when upstream or fallback contains hostnames") + return nil, fmt.Errorf("bootstrap server is required when upstream or fallback contains hostnames") + } + + if config.Bootstrap != "" && containsHostname(config.Bootstrap) { + logger.Error("Bootstrap server cannot contain hostnames: %s", config.Bootstrap) + return nil, fmt.Errorf("bootstrap server cannot contain hostnames: %s", config.Bootstrap) + } + + s := &Server{ + config: config, + resolvedHosts: make(map[string]string), + queryCache: make(map[cacheKey]*cacheEntry), + } + + // Create bootstrap client if needed + if config.Bootstrap != "" { + logger.Debug("Creating bootstrap client for %s", config.Bootstrap) + bootstrapClient, err := client.New(config.Bootstrap, client.Options{ + DNSSEC: false, + KeepAlive: config.KeepAlive, // Pass KeepAlive to bootstrap client + }) + if err != nil { + logger.Error("Failed to create bootstrap client: %v", err) + return nil, fmt.Errorf("failed to create bootstrap client: %w", err) + } + s.bootstrapClient = bootstrapClient + logger.Debug("Bootstrap client created successfully") + } + + // Initialize upstream and fallback clients + if err := s.initClients(); err != nil { + logger.Error("Failed to initialize clients: %v", err) + return nil, fmt.Errorf("failed to initialize clients: %w", err) + } + + // Setup DNS server + mux := dns.NewServeMux() + mux.HandleFunc(".", s.handleDNSRequest) + + s.dnsServer = &dns.Server{ + Addr: config.Address, + Net: "udp", + Handler: mux, + } + + logger.Debug("Server created successfully, listening on %s", config.Address) + return s, nil +} + +func containsHostname(serverAddr string) bool { + logger.Debug("Checking if %s contains hostname", serverAddr) + + // Use the same parsing logic as the client package + parsedURL, err := url.Parse(serverAddr) + if err != nil { + logger.Debug("URL parsing failed for %s, treating as plain address", serverAddr) + // If URL parsing fails, assume it's a plain address + host, _, err := net.SplitHostPort(serverAddr) + if err != nil { + // Assume it's just a host + isHostname := net.ParseIP(serverAddr) == nil + logger.Debug("Address %s is hostname: %v", serverAddr, isHostname) + return isHostname + } + isHostname := net.ParseIP(host) == nil + logger.Debug("Host %s from %s is hostname: %v", host, serverAddr, isHostname) + return isHostname + } + + host := parsedURL.Hostname() + if host == "" { + logger.Debug("No hostname found in URL %s", serverAddr) + return false + } + + isHostname := net.ParseIP(host) == nil + logger.Debug("Host %s from URL %s is hostname: %v", host, serverAddr, isHostname) + return isHostname +} + func (s *Server) initClients() error { logger.Debug("Initializing DNS clients") @@ -23,7 +163,7 @@ func (s *Server) initClients() error { logger.Debug("Creating upstream client for %s (resolved: %s)", s.config.Upstream, resolvedUpstream) upstreamClient, err := client.New(resolvedUpstream, client.Options{ DNSSEC: s.config.DNSSEC, - KeepAlive: s.config.KeepAlive, + KeepAlive: s.config.KeepAlive, // Pass KeepAlive to upstream client }) if err != nil { logger.Error("Failed to create upstream client: %v", err) @@ -46,7 +186,7 @@ func (s *Server) initClients() error { logger.Debug("Creating fallback client for %s (resolved: %s)", s.config.Fallback, resolvedFallback) fallbackClient, err := client.New(resolvedFallback, client.Options{ DNSSEC: s.config.DNSSEC, - KeepAlive: s.config.KeepAlive, + KeepAlive: s.config.KeepAlive, // Pass KeepAlive to fallback client }) if err != nil { logger.Error("Failed to create fallback client: %v", err) @@ -62,3 +202,395 @@ func (s *Server) initClients() error { logger.Debug("All DNS clients initialized successfully") return nil } + +func (s *Server) resolveServerAddress(serverAddr string) (string, error) { + logger.Debug("Resolving server address: %s", serverAddr) + + // If it doesn't contain hostnames, return as-is + if !containsHostname(serverAddr) { + logger.Debug("Address %s contains no hostnames, returning as-is", serverAddr) + return serverAddr, nil + } + + // If no bootstrap client, we can't resolve hostnames + if s.bootstrapClient == nil { + logger.Error("Cannot resolve hostname in %s: no bootstrap server configured", serverAddr) + return "", fmt.Errorf("cannot resolve hostname in %s: no bootstrap server configured", serverAddr) + } + + // Use the same parsing logic as the client package + parsedURL, err := url.Parse(serverAddr) + if err != nil { + logger.Debug("Parsing %s as plain host:port format", serverAddr) + // Handle plain host:port format + host, port, err := net.SplitHostPort(serverAddr) + if err != nil { + // Assume it's just a hostname + resolvedIP, err := s.resolveHostname(serverAddr) + if err != nil { + return "", err + } + logger.Debug("Resolved %s to %s", serverAddr, resolvedIP) + return resolvedIP, nil + } + + resolvedIP, err := s.resolveHostname(host) + if err != nil { + return "", err + } + resolved := net.JoinHostPort(resolvedIP, port) + logger.Debug("Resolved %s to %s", serverAddr, resolved) + return resolved, nil + } + + // Handle URL format + hostname := parsedURL.Hostname() + if hostname == "" { + logger.Error("No hostname in URL: %s", serverAddr) + return "", fmt.Errorf("no hostname in URL: %s", serverAddr) + } + + resolvedIP, err := s.resolveHostname(hostname) + if err != nil { + return "", err + } + + // Replace hostname with IP in the URL + port := parsedURL.Port() + if port == "" { + parsedURL.Host = resolvedIP + } else { + parsedURL.Host = net.JoinHostPort(resolvedIP, port) + } + + resolved := parsedURL.String() + logger.Debug("Resolved URL %s to %s", serverAddr, resolved) + return resolved, nil +} + +func (s *Server) resolveHostname(hostname string) (string, error) { + logger.Debug("Resolving hostname: %s", hostname) + + // Check cache first + s.hostsMutex.RLock() + if ip, exists := s.resolvedHosts[hostname]; exists { + s.hostsMutex.RUnlock() + logger.Debug("Found cached resolution for %s: %s", hostname, ip) + return ip, nil + } + s.hostsMutex.RUnlock() + + // Resolve using bootstrap + if s.config.Verbose { + logger.Info("Resolving hostname %s using bootstrap server", hostname) + } + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(hostname), dns.TypeA) + msg.Id = dns.Id() + msg.RecursionDesired = true + + logger.Debug("Sending bootstrap query for %s (ID: %d)", hostname, msg.Id) + msg, err := s.bootstrapClient.Query(msg) + if err != nil { + logger.Error("Bootstrap query failed for %s: %v", hostname, err) + return "", fmt.Errorf("failed to resolve %s via bootstrap: %w", hostname, err) + } + + logger.Debug("Bootstrap response for %s: %d answers", hostname, len(msg.Answer)) + if len(msg.Answer) == 0 { + logger.Error("No A records found for %s", hostname) + return "", fmt.Errorf("no A records found for %s", hostname) + } + + // Find first A record + for _, rr := range msg.Answer { + if a, ok := rr.(*dns.A); ok { + ip := a.A.String() + + // Cache the result + s.hostsMutex.Lock() + s.resolvedHosts[hostname] = ip + s.hostsMutex.Unlock() + + if s.config.Verbose { + logger.Info("Resolved %s to %s", hostname, ip) + } + logger.Debug("Cached resolution: %s -> %s", hostname, ip) + + return ip, nil + } + } + + logger.Error("No valid A record found for %s", hostname) + return "", fmt.Errorf("no valid A record found for %s", hostname) +} + +func (s *Server) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + logger.Debug("Received request with no questions from %s", w.RemoteAddr()) + dns.HandleFailed(w, r) + return + } + + question := r.Question[0] + domain := strings.ToLower(question.Name) + qtype := question.Qtype + + logger.Debug("Handling DNS request: %s %s from %s (ID: %d)", + question.Name, dns.TypeToString[qtype], w.RemoteAddr(), r.Id) + + if s.config.Verbose { + logger.Info("Query: %s %s from %s", + question.Name, + dns.TypeToString[qtype], + w.RemoteAddr()) + } + + // Check cache first + if cachedRecords := s.getCachedRecords(domain, qtype); cachedRecords != nil { + response := s.buildResponse(r, cachedRecords) + if s.config.Verbose { + logger.Info("Cache hit: %s %s -> %d records", + question.Name, + dns.TypeToString[qtype], + len(cachedRecords)) + } + logger.Debug("Serving cached response for %s %s (%d records)", + question.Name, dns.TypeToString[qtype], len(cachedRecords)) + w.WriteMsg(response) + return + } + + logger.Debug("Cache miss for %s %s, querying upstream", question.Name, dns.TypeToString[qtype]) + + // Try upstream first + response, err := s.queryUpstream(s.upstreamClient, question.Name, qtype) + if err != nil { + if s.config.Verbose { + logger.Info("Upstream query failed: %v", err) + } + logger.Debug("Upstream query failed for %s %s: %v", question.Name, dns.TypeToString[qtype], err) + + // Try fallback if available + if s.fallbackClient != nil { + if s.config.Verbose { + logger.Info("Trying fallback server") + } + logger.Debug("Attempting fallback query for %s %s", question.Name, dns.TypeToString[qtype]) + + response, err = s.queryUpstream(s.fallbackClient, question.Name, qtype) + if err != nil { + logger.Error("Both upstream and fallback failed for %s %s: %v", + question.Name, + dns.TypeToString[qtype], + err) + } else { + logger.Debug("Fallback query succeeded for %s %s", question.Name, dns.TypeToString[qtype]) + } + } + + // If still failed, return SERVFAIL + if err != nil { + logger.Error("All servers failed for %s %s: %v", + question.Name, + dns.TypeToString[qtype], + err) + + m := new(dns.Msg) + m.SetReply(r) + m.Rcode = dns.RcodeServerFailure + w.WriteMsg(m) + return + } + } else { + logger.Debug("Upstream query succeeded for %s %s", question.Name, dns.TypeToString[qtype]) + } + + // Cache successful response + s.cacheResponse(domain, qtype, response) + + // Copy request ID to response + response.Id = r.Id + + if s.config.Verbose { + logger.Info("Response: %s %s -> %d answers", + question.Name, + dns.TypeToString[qtype], + len(response.Answer)) + } + + logger.Debug("Sending response for %s %s: %d answers, rcode: %s", + question.Name, dns.TypeToString[qtype], len(response.Answer), dns.RcodeToString[response.Rcode]) + w.WriteMsg(response) +} + +func (s *Server) getCachedRecords(domain string, qtype uint16) []dns.RR { + key := cacheKey{domain: domain, qtype: qtype} + + s.cacheMutex.RLock() + entry, exists := s.queryCache[key] + s.cacheMutex.RUnlock() + + if !exists { + logger.Debug("No cache entry for %s %s", domain, dns.TypeToString[qtype]) + return nil + } + + // Check if expired and clean up on the spot + if time.Now().After(entry.expiresAt) { + logger.Debug("Cache entry expired for %s %s", domain, dns.TypeToString[qtype]) + s.cacheMutex.Lock() + delete(s.queryCache, key) + s.cacheMutex.Unlock() + return nil + } + + logger.Debug("Cache hit for %s %s (%d records, expires in %v)", + domain, dns.TypeToString[qtype], len(entry.records), time.Until(entry.expiresAt)) + + // Return a copy of the cached records + records := make([]dns.RR, len(entry.records)) + for i, rr := range entry.records { + records[i] = dns.Copy(rr) + } + return records +} + +func (s *Server) buildResponse(request *dns.Msg, records []dns.RR) *dns.Msg { + response := new(dns.Msg) + response.SetReply(request) + response.Answer = records + logger.Debug("Built response with %d records", len(records)) + return response +} + +func (s *Server) cacheResponse(domain string, qtype uint16, msg *dns.Msg) { + if msg == nil || len(msg.Answer) == 0 { + logger.Debug("Not caching empty response for %s %s", domain, dns.TypeToString[qtype]) + return + } + + var validRecords []dns.RR + minTTL := uint32(3600) + + // Find minimum TTL from answer records + for _, rr := range msg.Answer { + // Only cache records that match our query type or are CNAMEs + if rr.Header().Rrtype == qtype || rr.Header().Rrtype == dns.TypeCNAME { + validRecords = append(validRecords, dns.Copy(rr)) + if rr.Header().Ttl < minTTL { + minTTL = rr.Header().Ttl + } + } + } + + if len(validRecords) == 0 { + logger.Debug("No valid records to cache for %s %s", domain, dns.TypeToString[qtype]) + return + } + + // Don't cache responses with very low TTL + if minTTL < 10 { + logger.Debug("TTL too low (%ds) for caching %s %s", minTTL, domain, dns.TypeToString[qtype]) + return + } + + key := cacheKey{domain: domain, qtype: qtype} + entry := &cacheEntry{ + records: validRecords, + expiresAt: time.Now().Add(time.Duration(minTTL) * time.Second), + } + + s.cacheMutex.Lock() + s.queryCache[key] = entry + s.cacheMutex.Unlock() + + if s.config.Verbose { + logger.Info("Cached %d records for %s %s (TTL: %ds)", + len(validRecords), domain, dns.TypeToString[qtype], minTTL) + } + logger.Debug("Cached %d records for %s %s (TTL: %ds, expires: %v)", + len(validRecords), domain, dns.TypeToString[qtype], minTTL, entry.expiresAt) +} + +func (s *Server) queryUpstream(upstreamClient client.DNSClient, domain string, qtype uint16) (*dns.Msg, error) { + logger.Debug("Querying upstream for %s %s", domain, dns.TypeToString[qtype]) + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout) + defer cancel() + + // Channel to receive result + type result struct { + msg *dns.Msg + err error + } + resultChan := make(chan result, 1) + + // Query in goroutine to respect context timeout + go func() { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), qtype) + msg.Id = dns.Id() + msg.RecursionDesired = true + + logger.Debug("Sending upstream query: %s %s (ID: %d)", domain, dns.TypeToString[qtype], msg.Id) + recvMsg, err := upstreamClient.Query(msg) + if err != nil { + logger.Debug("Upstream query error for %s %s: %v", domain, dns.TypeToString[qtype], err) + } else { + logger.Debug("Upstream query response for %s %s: %d answers, rcode: %s", + domain, dns.TypeToString[qtype], len(recvMsg.Answer), dns.RcodeToString[recvMsg.Rcode]) + } + resultChan <- result{msg: recvMsg, err: err} + }() + + select { + case res := <-resultChan: + return res.msg, res.err + case <-ctx.Done(): + logger.Debug("Upstream query timeout for %s %s after %v", domain, dns.TypeToString[qtype], s.config.Timeout) + return nil, fmt.Errorf("upstream query timeout") + } +} + +func (s *Server) Start() error { + go func() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigChan + logger.Info("Received signal %v, shutting down DNS server...", sig) + s.Shutdown() + }() + + logger.Info("DNS proxy server listening on %s", s.config.Address) + logger.Debug("Server starting with timeout: %v, DNSSEC: %v, KeepAlive: %v", s.config.Timeout, s.config.DNSSEC, s.config.KeepAlive) + return s.dnsServer.ListenAndServe() +} + +func (s *Server) Shutdown() { + logger.Debug("Shutting down server components") + + if s.dnsServer != nil { + logger.Debug("Shutting down DNS server") + s.dnsServer.Shutdown() + } + + if s.upstreamClient != nil { + logger.Debug("Closing upstream client") + s.upstreamClient.Close() + } + + if s.fallbackClient != nil { + logger.Debug("Closing fallback client") + s.fallbackClient.Close() + } + + if s.bootstrapClient != nil { + logger.Debug("Closing bootstrap client") + s.bootstrapClient.Close() + } + + logger.Info("Server shutdown complete") +}