refactor: refactor the code for qol.go

This commit is contained in:
2025-09-08 23:27:50 +01:00
parent 640f240b3f
commit de619583ea
8 changed files with 481 additions and 302 deletions

46
cmd/qol/qol.go Normal file
View File

@@ -0,0 +1,46 @@
package main
import (
"time"
"github.com/afonsofrancof/sdns-proxy/internal/qol"
"github.com/alecthomas/kong"
)
type CLI struct {
Run RunCmd `cmd:"" help:"Run measurements for given servers and domains"`
}
type RunCmd struct {
DomainsFile string `arg:"" help:"File with domains (one per line)"`
OutputDir string `short:"o" long:"output" default:"results" help:"Output directory"`
QueryType string `short:"t" long:"type" default:"A" help:"DNS query type"`
Timeout time.Duration `long:"timeout" default:"5s" help:"Query timeout (informational)"`
DNSSEC bool `long:"dnssec" help:"Enable DNSSEC"`
Interface string `long:"iface" default:"any" help:"Capture interface (e.g., eth0, any)"`
Servers []string `short:"s" long:"server" help:"Upstream servers (udp://..., tls://..., https://..., doq://...)"`
}
func (r *RunCmd) Run() error {
config := qol.MeasurementConfig{
DomainsFile: r.DomainsFile,
OutputDir: r.OutputDir,
QueryType: r.QueryType,
DNSSEC: r.DNSSEC,
Interface: r.Interface,
Servers: r.Servers,
}
runner := qol.NewMeasurementRunner(config)
return runner.Run()
}
func main() {
ctx := kong.Parse(&CLI{},
kong.Name("dns-measurer"),
kong.Description("DNS secure protocols measurer with metrics + full pcap capture"),
kong.UsageOnError(),
)
err := ctx.Run()
ctx.FatalIfErrorf(err)
}

View File

@@ -0,0 +1,101 @@
package capture
import (
"context"
"fmt"
"os"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/pcap"
"github.com/google/gopacket/pcapgo"
)
type PacketCapture struct {
handle *pcap.Handle
writer *pcapgo.Writer
file *os.File
mu sync.Mutex
err error
}
func NewPacketCapture(iface, outputPath string) (*PacketCapture, error) {
handle, err := pcap.OpenLive(iface, 65535, true, pcap.BlockForever)
if err != nil {
return nil, fmt.Errorf("pcap open (try running as root): %w", err)
}
file, err := os.Create(outputPath)
if err != nil {
handle.Close()
return nil, fmt.Errorf("create pcap file: %w", err)
}
writer := pcapgo.NewWriter(file)
if err := writer.WriteFileHeader(65535, handle.LinkType()); err != nil {
handle.Close()
file.Close()
return nil, fmt.Errorf("pcap header: %w", err)
}
return &PacketCapture{
handle: handle,
writer: writer,
file: file,
}, nil
}
func (pc *PacketCapture) Start(ctx context.Context) error {
psrc := gopacket.NewPacketSource(pc.handle, pc.handle.LinkType())
pktCh := psrc.Packets()
go func() {
for {
select {
case pkt, ok := <-pktCh:
if !ok {
return
}
ci := pkt.Metadata().CaptureInfo
if err := pc.writer.WritePacket(ci, pkt.Data()); err != nil {
pc.mu.Lock()
if pc.err == nil {
pc.err = fmt.Errorf("pcap write error: %w", err)
}
pc.mu.Unlock()
}
case <-ctx.Done():
return
}
}
}()
return nil
}
func (pc *PacketCapture) GetError() error {
pc.mu.Lock()
defer pc.mu.Unlock()
return pc.err
}
func (pc *PacketCapture) Close() error {
var errs []error
if pc.handle != nil {
pc.handle.Close()
}
if pc.file != nil {
if err := pc.file.Close(); err != nil {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return errs[0]
}
return nil
}

233
internal/qol/measurement.go Normal file
View File

@@ -0,0 +1,233 @@
package qol
import (
"bufio"
"context"
"fmt"
"os"
"strings"
"time"
"github.com/afonsofrancof/sdns-proxy/client"
"github.com/afonsofrancof/sdns-proxy/internal/qol/capture"
"github.com/afonsofrancof/sdns-proxy/internal/qol/results"
"github.com/google/gopacket/pcap"
"github.com/miekg/dns"
)
type MeasurementConfig struct {
DomainsFile string
OutputDir string
QueryType string
DNSSEC bool
Interface string
Servers []string
}
type MeasurementRunner struct {
config MeasurementConfig
}
func NewMeasurementRunner(config MeasurementConfig) *MeasurementRunner {
return &MeasurementRunner{config: config}
}
func (r *MeasurementRunner) Run() error {
if err := r.checkCapturePermissions(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: %v\n", err)
fmt.Fprintf(os.Stderr, "Packet capture may fail. Consider running as root/administrator.\n")
}
domains, err := r.readDomainsFile()
if err != nil {
return fmt.Errorf("failed reading domains: %w", err)
}
if len(r.config.Servers) == 0 {
return fmt.Errorf("at least one server must be provided")
}
if err := os.MkdirAll(r.config.OutputDir, 0755); err != nil {
return fmt.Errorf("mkdir output: %w", err)
}
qType, ok := dns.StringToType[strings.ToUpper(r.config.QueryType)]
if !ok {
return fmt.Errorf("invalid qtype: %s", r.config.QueryType)
}
for _, upstream := range r.config.Servers {
if err := r.runMeasurement(upstream, domains, qType); err != nil {
fmt.Fprintf(os.Stderr, "error on server %s: %v\n", upstream, err)
}
}
return nil
}
func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qType uint16) error {
// Setup DNS client
dnsClient, err := r.setupDNSClient(upstream)
if err != nil {
return fmt.Errorf("failed creating client: %w", err)
}
defer dnsClient.Close()
// Setup output files
jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC)
fmt.Printf(">>> Measuring %s (dnssec=%v) → %s\n", upstream, r.config.DNSSEC,
strings.TrimSuffix(strings.TrimSuffix(jsonPath, ".jsonl"), r.config.OutputDir+"/"))
// Setup packet capture
packetCapture, err := capture.NewPacketCapture(r.config.Interface, pcapPath)
if err != nil {
return err
}
defer packetCapture.Close()
// Setup results writer
writer, err := results.NewMetricsWriter(jsonPath)
if err != nil {
return err
}
defer writer.Close()
// Run measurements
return r.runQueries(dnsClient, upstream, domains, qType, writer, packetCapture)
}
func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) {
opts := client.Options{DNSSEC: r.config.DNSSEC}
return client.New(upstream, opts)
}
func (r *MeasurementRunner) runQueries(dnsClient client.DNSClient, upstream string,
domains []string, qType uint16, writer *results.MetricsWriter,
packetCapture *capture.PacketCapture) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := packetCapture.Start(ctx); err != nil {
return err
}
failureCount := 0
const maxFailures = 5
proto := DetectProtocol(upstream)
for _, domain := range domains {
if failureCount >= maxFailures {
fmt.Printf("⚠ Skipping remaining domains (too many failures: %d)\n", failureCount)
break
}
metric := r.performQuery(dnsClient, domain, upstream, proto, qType)
if metric.ResponseCode == "ERROR" {
failureCount++
}
if err := writer.WriteMetric(metric); err != nil {
fmt.Fprintf(os.Stderr, "encode error: %v\n", err)
}
r.printQueryResult(metric)
time.Sleep(10 * time.Millisecond)
}
time.Sleep(100 * time.Millisecond)
if err := packetCapture.GetError(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: packet capture errors occurred: %v\n", err)
}
return nil
}
func (r *MeasurementRunner) performQuery(dnsClient client.DNSClient, domain, upstream, proto string, qType uint16) results.DNSMetric {
metric := results.DNSMetric{
Domain: domain,
QueryType: r.config.QueryType,
Protocol: proto,
DNSSEC: r.config.DNSSEC,
DNSServer: upstream,
Timestamp: time.Now(),
}
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.SetQuestion(dns.Fqdn(domain), qType)
packed, err := msg.Pack()
if err != nil {
metric.ResponseCode = "ERROR"
metric.Error = fmt.Sprintf("pack request: %v", err)
return metric
}
metric.RequestSize = len(packed)
start := time.Now()
resp, err := dnsClient.Query(msg)
metric.Duration = time.Since(start).Nanoseconds()
metric.DurationMs = float64(metric.Duration) / 1e6
if err != nil {
metric.ResponseCode = "ERROR"
metric.Error = err.Error()
return metric
}
respBytes, err := resp.Pack()
if err != nil {
metric.ResponseCode = "ERROR"
metric.Error = fmt.Sprintf("pack response: %v", err)
return metric
}
metric.ResponseSize = len(respBytes)
metric.ResponseCode = dns.RcodeToString[resp.Rcode]
return metric
}
func (r *MeasurementRunner) printQueryResult(metric results.DNSMetric) {
statusIcon := "✓"
if metric.ResponseCode == "ERROR" {
statusIcon = "✗"
}
fmt.Printf("%s %s [%s] %s %.2fms\n",
statusIcon, metric.Domain, metric.Protocol, metric.ResponseCode, metric.DurationMs)
}
func (r *MeasurementRunner) readDomainsFile() ([]string, error) {
f, err := os.Open(r.config.DomainsFile)
if err != nil {
return nil, err
}
defer f.Close()
var out []string
sc := bufio.NewScanner(f)
for sc.Scan() {
l := strings.TrimSpace(sc.Text())
if l != "" && !strings.HasPrefix(l, "#") {
out = append(out, l)
}
}
return out, sc.Err()
}
func (r *MeasurementRunner) checkCapturePermissions() error {
handle, err := pcap.OpenLive("any", 65535, false, time.Millisecond*100)
if err != nil {
if strings.Contains(err.Error(), "permission") ||
strings.Contains(err.Error(), "Operation not permitted") {
return fmt.Errorf("insufficient permissions for packet capture")
}
return nil
}
handle.Close()
return nil
}

View File

@@ -0,0 +1,51 @@
package results
import (
"encoding/json"
"fmt"
"os"
"time"
)
type DNSMetric struct {
Domain string `json:"domain"`
QueryType string `json:"query_type"`
Protocol string `json:"protocol"`
DNSSEC bool `json:"dnssec"`
DNSServer string `json:"dns_server"`
Timestamp time.Time `json:"timestamp"`
Duration int64 `json:"duration_ns"`
DurationMs float64 `json:"duration_ms"`
RequestSize int `json:"request_size_bytes"`
ResponseSize int `json:"response_size_bytes"`
ResponseCode string `json:"response_code"`
Error string `json:"error,omitempty"`
}
type MetricsWriter struct {
encoder *json.Encoder
file *os.File
}
func NewMetricsWriter(path string) (*MetricsWriter, error) {
file, err := os.Create(path)
if err != nil {
return nil, fmt.Errorf("create json output: %w", err)
}
return &MetricsWriter{
encoder: json.NewEncoder(file),
file: file,
}, nil
}
func (mw *MetricsWriter) WriteMetric(metric DNSMetric) error {
return mw.encoder.Encode(metric)
}
func (mw *MetricsWriter) Close() error {
if mw.file != nil {
return mw.file.Close()
}
return nil
}

49
internal/qol/utils.go Normal file
View File

@@ -0,0 +1,49 @@
package qol
import (
"fmt"
"net/url"
"path/filepath"
"strings"
"time"
)
func GenerateOutputPaths(outputDir, upstream string, dnssec bool) (jsonPath, pcapPath string) {
proto := DetectProtocol(upstream)
serverName := ExtractServerName(upstream)
ts := time.Now().Format("20060102_1504")
dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec]
base := fmt.Sprintf("%s_%s_dnssec_%s_%s",
proto, sanitize(serverName), dnssecStr, ts)
return filepath.Join(outputDir, base+".jsonl"),
filepath.Join(outputDir, base+".pcap")
}
func sanitize(s string) string {
return strings.NewReplacer(":", "_", "/", "_", ".", "_").Replace(s)
}
func DetectProtocol(upstream string) string {
if strings.Contains(upstream, "://") {
u, err := url.Parse(upstream)
if err == nil && u.Scheme != "" {
return strings.ToLower(u.Scheme)
}
}
return "do53"
}
func ExtractServerName(upstream string) string {
if strings.Contains(upstream, "://") {
u, err := url.Parse(upstream)
if err == nil {
if u.Scheme == "https" && u.Path != "" && u.Path != "/" {
return u.Host + u.Path
}
return u.Host
}
}
return upstream
}

301
qol.go
View File

@@ -1,301 +0,0 @@
package main
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/afonsofrancof/sdns-proxy/client"
"github.com/alecthomas/kong"
"github.com/google/gopacket"
"github.com/google/gopacket/pcap"
"github.com/google/gopacket/pcapgo"
"github.com/miekg/dns"
)
type CLI struct {
Run RunCmd `cmd:"" help:"Run measurements for given servers and domains"`
}
type RunCmd struct {
DomainsFile string `arg:"" help:"File with domains (one per line)"`
OutputDir string `short:"o" long:"output" default:"results" help:"Output directory"`
QueryType string `short:"t" long:"type" default:"A" help:"DNS query type"`
Repeat int `short:"r" long:"repeat" default:"5" help:"Queries per domain (sequential)"`
Timeout time.Duration `long:"timeout" default:"5s" help:"Query timeout (informational)"`
DNSSEC bool `long:"dnssec" help:"Enable DNSSEC"`
Interface string `long:"iface" default:"any" help:"Capture interface (e.g., eth0, any)"`
Servers []string `short:"s" long:"server" help:"Upstream servers (udp://..., tls://..., https://..., doq://...)"`
}
type DNSMetric struct {
Domain string `json:"domain"`
QueryType string `json:"query_type"`
Protocol string `json:"protocol"`
DNSSEC bool `json:"dnssec"`
DNSServer string `json:"dns_server"`
Timestamp time.Time `json:"timestamp"`
Duration int64 `json:"duration_ns"`
DurationMs float64 `json:"duration_ms"`
RequestSize int `json:"request_size_bytes"`
ResponseSize int `json:"response_size_bytes"`
ResponseCode string `json:"response_code"`
Error string `json:"error,omitempty"`
}
func (r *RunCmd) Run() error {
// Check if running with sufficient privileges for packet capture
if err := checkCapturePermissions(); err != nil {
fmt.Fprintf(os.Stderr, "Warning: %v\n", err)
fmt.Fprintf(os.Stderr, "Packet capture may fail. Consider running as root/administrator.\n")
}
domains, err := readDomainsFile(r.DomainsFile)
if err != nil {
return fmt.Errorf("failed reading domains: %w", err)
}
if len(r.Servers) == 0 {
return fmt.Errorf("at least one --server must be provided")
}
if err := os.MkdirAll(r.OutputDir, 0755); err != nil {
return fmt.Errorf("mkdir output: %w", err)
}
qType, ok := dns.StringToType[strings.ToUpper(r.QueryType)]
if !ok {
return fmt.Errorf("invalid qtype: %s", r.QueryType)
}
for _, upstream := range r.Servers {
if err := r.runOne(upstream, domains, qType); err != nil {
fmt.Fprintf(os.Stderr, "error on server %s: %v\n", upstream, err)
}
}
return nil
}
func (r *RunCmd) runOne(upstream string, domains []string, qType uint16) error {
opts := client.Options{DNSSEC: r.DNSSEC}
dnsClient, err := client.New(upstream, opts)
if err != nil {
return fmt.Errorf("failed creating client: %w", err)
}
defer dnsClient.Close()
// file naming
proto := detectProtocol(upstream)
ts := time.Now().Format("20060102_1504")
dnssecStr := "off"
if r.DNSSEC {
dnssecStr = "on"
}
base := fmt.Sprintf("%s_%s_dnssec_%s_%s",
proto, sanitize(upstream), dnssecStr, ts)
jsonPath := filepath.Join(r.OutputDir, base+".jsonl")
pcapPath := filepath.Join(r.OutputDir, base+".pcap")
fmt.Printf(">>> Measuring %s (dnssec=%v) → %s\n", upstream, r.DNSSEC, base)
// setup pcap capture
handle, err := pcap.OpenLive(r.Interface, 65535, true, pcap.BlockForever)
if err != nil {
return fmt.Errorf("pcap open (try running as root): %w", err)
}
defer handle.Close()
pcapFile, err := os.Create(pcapPath)
if err != nil {
return fmt.Errorf("create pcap file: %w", err)
}
defer pcapFile.Close()
writer := pcapgo.NewWriter(pcapFile)
if err := writer.WriteFileHeader(65535, handle.LinkType()); err != nil {
return fmt.Errorf("pcap header: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
psrc := gopacket.NewPacketSource(handle, handle.LinkType())
pktCh := psrc.Packets()
var wg sync.WaitGroup
var captureErr error
captureMutex := sync.Mutex{}
// Start packet capture goroutine
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case pkt, ok := <-pktCh:
if !ok {
return
}
ci := pkt.Metadata().CaptureInfo
if err := writer.WritePacket(ci, pkt.Data()); err != nil {
captureMutex.Lock()
if captureErr == nil {
captureErr = fmt.Errorf("pcap write error: %w", err)
}
captureMutex.Unlock()
fmt.Fprintf(os.Stderr, "pcap write error: %v\n", err)
}
case <-ctx.Done():
return
}
}
}()
// open JSONL output
out, err := os.Create(jsonPath)
if err != nil {
cancel()
wg.Wait()
return fmt.Errorf("create json out: %w", err)
}
defer out.Close()
enc := json.NewEncoder(out)
// sequential measurement
for _, domain := range domains {
for rep := 0; rep < r.Repeat; rep++ {
metric := performQuery(dnsClient, domain, upstream, proto, qType, r.QueryType, r.DNSSEC)
if err := enc.Encode(metric); err != nil {
fmt.Fprintf(os.Stderr, "encode error: %v\n", err)
}
fmt.Printf("✓ %s [%s] %s %.2fms\n",
metric.Domain, metric.Protocol, metric.ResponseCode, metric.DurationMs)
// Small delay to allow packet capture to catch up
time.Sleep(10 * time.Millisecond)
}
}
// Allow some time for final packets to be captured
time.Sleep(100 * time.Millisecond)
cancel()
wg.Wait()
// Check if there were capture errors
captureMutex.Lock()
defer captureMutex.Unlock()
if captureErr != nil {
fmt.Fprintf(os.Stderr, "Warning: packet capture errors occurred: %v\n", captureErr)
}
return nil
}
func performQuery(dnsClient client.DNSClient, domain, upstream, proto string,
qType uint16, qTypeStr string, dnssec bool) DNSMetric {
metric := DNSMetric{
Domain: domain,
QueryType: qTypeStr,
Protocol: proto,
DNSSEC: dnssec,
DNSServer: upstream,
Timestamp: time.Now(),
}
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.RecursionDesired = true
msg.SetQuestion(dns.Fqdn(domain), qType)
packed, err := msg.Pack()
if err != nil {
metric.ResponseCode = "ERROR"
metric.Error = fmt.Sprintf("pack request: %v", err)
return metric
}
metric.RequestSize = len(packed)
start := time.Now()
resp, err := dnsClient.Query(msg)
metric.Duration = time.Since(start).Nanoseconds()
metric.DurationMs = float64(metric.Duration) / 1e6
if err != nil {
metric.ResponseCode = "ERROR"
metric.Error = err.Error()
return metric
}
respBytes, err := resp.Pack()
if err != nil {
metric.ResponseCode = "ERROR"
metric.Error = fmt.Sprintf("pack response: %v", err)
return metric
}
metric.ResponseSize = len(respBytes)
metric.ResponseCode = dns.RcodeToString[resp.Rcode]
return metric
}
func readDomainsFile(path string) ([]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
var out []string
sc := bufio.NewScanner(f)
for sc.Scan() {
l := strings.TrimSpace(sc.Text())
if l != "" && !strings.HasPrefix(l, "#") {
out = append(out, l)
}
}
return out, sc.Err()
}
func sanitize(s string) string {
return strings.NewReplacer(":", "_", "/", "_").Replace(s)
}
func detectProtocol(upstream string) string {
if strings.Contains(upstream, "://") {
u, err := url.Parse(upstream)
if err == nil && u.Scheme != "" {
return strings.ToLower(u.Scheme)
}
}
return "do53"
}
func checkCapturePermissions() error {
// Try to open a test interface to check permissions
handle, err := pcap.OpenLive("any", 65535, false, time.Millisecond*100)
if err != nil {
if strings.Contains(err.Error(), "permission") ||
strings.Contains(err.Error(), "Operation not permitted") {
return fmt.Errorf("insufficient permissions for packet capture")
}
// Other errors might be due to interface availability, which is acceptable
return nil
}
handle.Close()
return nil
}
func main() {
ctx := kong.Parse(&CLI{},
kong.Name("dns-measurer"),
kong.Description("DNS secure protocols measurer with metrics + full pcap capture"),
kong.UsageOnError(),
)
err := ctx.Run()
ctx.FatalIfErrorf(err)
}

2
run.sh
View File

@@ -17,7 +17,7 @@ SERVERS=(
-s "tls://dns.adguard-dns.com:853" -s "tls://dns.adguard-dns.com:853"
-s "https://dns.google/dns-query" -s "https://dns.google/dns-query"
-s "https://cloudflare-dns.com/dns-query" -s "https://cloudflare-dns.com/dns-query"
-s "https://dns.quad9.net/dns-query" -s "https://dns10.quad9.net/dns-query"
-s "https://dns.adguard-dns.com/dns-query" -s "https://dns.adguard-dns.com/dns-query"
-s "doq://dns.adguard-dns.com:853" -s "doq://dns.adguard-dns.com:853"
) )