diff --git a/common/protocols/doh/doh.go b/common/protocols/doh/doh.go index c1d7a57..ff67118 100644 --- a/common/protocols/doh/doh.go +++ b/common/protocols/doh/doh.go @@ -38,6 +38,10 @@ type Client struct { } func New(config Config) (*Client, error) { + if config.HTTP3 { + config.KeepAlive = true + } + logger.Debug("Creating DoH client: %s:%s%s (KeepAlive: %v)", config.Host, config.Port, config.Path, config.KeepAlive) if config.Host == "" || config.Port == "" || config.Path == "" { @@ -76,10 +80,6 @@ func New(config Config) (*Client, error) { if config.KeepAlive { transport.MaxIdleConnsPerHost = 10 transport.MaxConnsPerHost = 10 - } else { - transport.MaxIdleConns = 0 - transport.MaxIdleConnsPerHost = 0 - transport.DisableKeepAlives = true } httpClient := &http.Client{ @@ -93,11 +93,6 @@ func New(config Config) (*Client, error) { TLSClientConfig: tlsConfig, AllowHTTP: true, } - - if !config.KeepAlive { - http2Transport.DisableCompression = true - } - httpClient.Transport = http2Transport transportType = "HTTP/2" } else if config.HTTP3 { @@ -106,11 +101,6 @@ func New(config Config) (*Client, error) { TLSClientConfig: quicTlsConfig, QUICConfig: quicConfig, } - - if !config.KeepAlive { - http3Transport.DisableCompression = true - } - httpClient.Transport = http3Transport transportType = "HTTP/3" } else { diff --git a/internal/qol/capture/pcap.go b/internal/qol/capture/pcap.go index e342098..6c638a8 100644 --- a/internal/qol/capture/pcap.go +++ b/internal/qol/capture/pcap.go @@ -3,7 +3,9 @@ package capture import ( "context" "fmt" + "net" "os" + "strings" "sync" "github.com/google/gopacket" @@ -19,12 +21,91 @@ type PacketCapture struct { err error } -func NewPacketCapture(iface, outputPath string) (*PacketCapture, error) { +func getLocalIPs() ([]string, error) { + var localIPs []string + + addrs, err := net.InterfaceAddrs() + if err != nil { + return nil, fmt.Errorf("failed to get network interfaces: %w", err) + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + + // Skip loopback + if ip == nil || ip.IsLoopback() { + continue + } + + localIPs = append(localIPs, ip.String()) + } + + if len(localIPs) == 0 { + return nil, fmt.Errorf("no non-loopback IPs found") + } + + return localIPs, nil +} + +func buildBPFFilter(protocol string, localIPs []string) string { + // Build filter for this machine's IPs + var hostFilters []string + for _, ip := range localIPs { + hostFilters = append(hostFilters, fmt.Sprintf("host %s", ip)) + } + testMachineFilter := "(" + strings.Join(hostFilters, " or ") + ")" + + // Protocol-specific ports + var portFilter string + switch strings.ToLower(protocol) { + case "udp": + portFilter = "(port 53)" + case "tls", "dot": + portFilter = "(port 53 or port 853)" + case "https", "doh": + portFilter = "(port 53 or port 443)" + case "doq": + portFilter = "(port 53 or port 853)" + case "doh3": + portFilter = "(port 53 or port 443)" + default: + portFilter = "(port 53 or port 443 or port 853)" + } + + // Exclude private-to-private traffic (LAN-to-LAN, includes Docker ranges) + privateExclude := "not (src net (10.0.0.0/8 or 172.16.0.0/12 or 192.168.0.0/16) and dst net (10.0.0.0/8 or 172.16.0.0/12 or 192.168.0.0/16))" + + // Combine: test machine AND protocol ports AND NOT (private to private) + return testMachineFilter + " and " + portFilter + " and " + privateExclude +} + +func NewPacketCapture(iface, outputPath, protocol string) (*PacketCapture, error) { handle, err := pcap.OpenLive(iface, 65535, true, pcap.BlockForever) if err != nil { return nil, fmt.Errorf("pcap open (try running as root): %w", err) } + // Get local IPs dynamically + localIPs, err := getLocalIPs() + if err != nil { + handle.Close() + return nil, fmt.Errorf("failed to get local IPs: %w", err) + } + + // Build and apply BPF filter + bpfFilter := buildBPFFilter(protocol, localIPs) + + if err := handle.SetBPFFilter(bpfFilter); err != nil { + handle.Close() + return nil, fmt.Errorf("failed to set BPF filter '%s': %w", bpfFilter, err) + } + file, err := os.Create(outputPath) if err != nil { handle.Close() diff --git a/internal/qol/measurement.go b/internal/qol/measurement.go index 8917bfc..da43fea 100644 --- a/internal/qol/measurement.go +++ b/internal/qol/measurement.go @@ -94,7 +94,10 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT fmt.Printf(">>> Measuring %s (dnssec=%v, auth=%v%s) → %s\n", upstream, r.config.DNSSEC, r.config.AuthoritativeDNSSEC, keepAliveStr, relPath) // Setup packet capture - packetCapture, err := capture.NewPacketCapture(r.config.Interface, pcapPath) + proto := DetectProtocol(upstream) + + // Setup packet capture with protocol-aware filtering + packetCapture, err := capture.NewPacketCapture(r.config.Interface, pcapPath, proto) if err != nil { return err } @@ -107,6 +110,7 @@ func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qT } defer writer.Close() + time.Sleep(time.Second) // Run measurements return r.runQueries(dnsClient, upstream, domains, qType, writer, packetCapture) }