259 lines
6.3 KiB
Go
259 lines
6.3 KiB
Go
package qol
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/afonsofrancof/sdns-proxy/client"
|
|
"github.com/afonsofrancof/sdns-proxy/internal/qol/capture"
|
|
"github.com/afonsofrancof/sdns-proxy/internal/qol/results"
|
|
"github.com/google/gopacket/pcap"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// MeasurementConfig holds the settings for one measurement run.
type MeasurementConfig struct {
	// DomainsFile is the path of the text file listing the domains to query,
	// one per line. Blank lines and lines starting with "#" are skipped.
	DomainsFile string

	// OutputDir is the directory that receives the result files. It is
	// created (including parents) when the run starts.
	OutputDir string

	// QueryType is the DNS record type name, e.g. "A". It is upper-cased and
	// looked up in dns.StringToType; an unknown name aborts the run.
	QueryType string

	// DNSSEC is passed to the DNS client options and recorded in every metric.
	DNSSEC bool

	// KeepAlive is passed to the DNS client options, recorded in every
	// metric, and shortens the pause between consecutive queries.
	KeepAlive bool

	// Interface is the network interface name handed to the packet capture.
	Interface string

	// Servers lists the upstream servers to measure, in order. At least one
	// entry is required.
	Servers []string
}
|
|
|
|
type MeasurementRunner struct {
|
|
config MeasurementConfig
|
|
}
|
|
|
|
func NewMeasurementRunner(config MeasurementConfig) *MeasurementRunner {
|
|
return &MeasurementRunner{config: config}
|
|
}
|
|
|
|
func (r *MeasurementRunner) Run() error {
|
|
if err := r.checkCapturePermissions(); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Warning: %v\n", err)
|
|
fmt.Fprintf(os.Stderr, "Packet capture may fail. Consider running as root/administrator.\n")
|
|
}
|
|
|
|
domains, err := r.readDomainsFile()
|
|
if err != nil {
|
|
return fmt.Errorf("failed reading domains: %w", err)
|
|
}
|
|
|
|
if len(r.config.Servers) == 0 {
|
|
return fmt.Errorf("at least one server must be provided")
|
|
}
|
|
|
|
if err := os.MkdirAll(r.config.OutputDir, 0755); err != nil {
|
|
return fmt.Errorf("mkdir output: %w", err)
|
|
}
|
|
|
|
qType, ok := dns.StringToType[strings.ToUpper(r.config.QueryType)]
|
|
if !ok {
|
|
return fmt.Errorf("invalid qtype: %s", r.config.QueryType)
|
|
}
|
|
|
|
for _, upstream := range r.config.Servers {
|
|
if err := r.runMeasurement(upstream, domains, qType); err != nil {
|
|
fmt.Fprintf(os.Stderr, "error on server %s: %v\n", upstream, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qType uint16) error {
|
|
// Setup DNS client
|
|
dnsClient, err := r.setupDNSClient(upstream)
|
|
if err != nil {
|
|
return fmt.Errorf("failed creating client: %w", err)
|
|
}
|
|
defer dnsClient.Close()
|
|
|
|
// Setup output files
|
|
csvPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC, r.config.KeepAlive)
|
|
|
|
keepAliveStr := ""
|
|
if r.config.KeepAlive {
|
|
keepAliveStr = " (keep-alive)"
|
|
}
|
|
|
|
fmt.Printf(">>> Measuring %s (dnssec=%v%s) → %s\n", upstream, r.config.DNSSEC, keepAliveStr,
|
|
strings.TrimSuffix(strings.TrimSuffix(csvPath, ".csv"), r.config.OutputDir+"/"))
|
|
|
|
// Setup packet capture
|
|
packetCapture, err := capture.NewPacketCapture(r.config.Interface, pcapPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer packetCapture.Close()
|
|
|
|
// Setup results writer
|
|
writer, err := results.NewMetricsWriter(csvPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer writer.Close()
|
|
|
|
// Run measurements
|
|
return r.runQueries(dnsClient, upstream, domains, qType, writer, packetCapture)
|
|
}
|
|
|
|
func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) {
|
|
opts := client.Options{
|
|
DNSSEC: r.config.DNSSEC,
|
|
KeepAlive: r.config.KeepAlive,
|
|
}
|
|
return client.New(upstream, opts)
|
|
}
|
|
|
|
func (r *MeasurementRunner) runQueries(dnsClient client.DNSClient, upstream string,
|
|
domains []string, qType uint16, writer *results.MetricsWriter,
|
|
packetCapture *capture.PacketCapture) error {
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
if err := packetCapture.Start(ctx); err != nil {
|
|
return err
|
|
}
|
|
|
|
failureCount := 0
|
|
const maxFailures = 5
|
|
proto := DetectProtocol(upstream)
|
|
|
|
for _, domain := range domains {
|
|
if failureCount >= maxFailures {
|
|
fmt.Printf("⚠ Skipping remaining domains (too many failures: %d)\n", failureCount)
|
|
break
|
|
}
|
|
|
|
metric := r.performQuery(dnsClient, domain, upstream, proto, qType)
|
|
|
|
if metric.ResponseCode == "ERROR" {
|
|
failureCount++
|
|
} else if metric.ResponseCode == "NOERROR" {
|
|
failureCount = 0
|
|
}
|
|
|
|
if err := writer.WriteMetric(metric); err != nil {
|
|
fmt.Fprintf(os.Stderr, "encode error: %v\n", err)
|
|
}
|
|
|
|
r.printQueryResult(metric)
|
|
|
|
// For keep-alive connections, add smaller delays between queries
|
|
// to better utilize the persistent connection
|
|
if r.config.KeepAlive {
|
|
time.Sleep(5 * time.Millisecond)
|
|
} else {
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
}
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
if err := packetCapture.GetError(); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Warning: packet capture errors occurred: %v\n", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *MeasurementRunner) performQuery(dnsClient client.DNSClient, domain, upstream, proto string, qType uint16) results.DNSMetric {
|
|
metric := results.DNSMetric{
|
|
Domain: domain,
|
|
QueryType: r.config.QueryType,
|
|
Protocol: proto,
|
|
DNSSEC: r.config.DNSSEC,
|
|
KeepAlive: r.config.KeepAlive,
|
|
DNSServer: upstream,
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
msg := new(dns.Msg)
|
|
msg.Id = dns.Id()
|
|
msg.RecursionDesired = true
|
|
msg.SetQuestion(dns.Fqdn(domain), qType)
|
|
|
|
packed, err := msg.Pack()
|
|
if err != nil {
|
|
metric.ResponseCode = "ERROR"
|
|
metric.Error = fmt.Sprintf("pack request: %v", err)
|
|
return metric
|
|
}
|
|
metric.RequestSize = len(packed)
|
|
|
|
start := time.Now()
|
|
resp, err := dnsClient.Query(msg)
|
|
metric.Duration = time.Since(start).Nanoseconds()
|
|
metric.DurationMs = float64(metric.Duration) / 1e6
|
|
|
|
if err != nil {
|
|
metric.ResponseCode = "ERROR"
|
|
metric.Error = err.Error()
|
|
return metric
|
|
}
|
|
|
|
respBytes, err := resp.Pack()
|
|
if err != nil {
|
|
metric.ResponseCode = "ERROR"
|
|
metric.Error = fmt.Sprintf("pack response: %v", err)
|
|
return metric
|
|
}
|
|
|
|
metric.ResponseSize = len(respBytes)
|
|
metric.ResponseCode = dns.RcodeToString[resp.Rcode]
|
|
return metric
|
|
}
|
|
|
|
func (r *MeasurementRunner) printQueryResult(metric results.DNSMetric) {
|
|
statusIcon := "✓"
|
|
if metric.ResponseCode == "ERROR" {
|
|
statusIcon = "✗"
|
|
}
|
|
|
|
keepAliveIndicator := ""
|
|
if metric.KeepAlive {
|
|
keepAliveIndicator = "⟷"
|
|
}
|
|
|
|
fmt.Printf("%s %s%s [%s] %s %.2fms\n",
|
|
statusIcon, metric.Domain, keepAliveIndicator, metric.Protocol, metric.ResponseCode, metric.DurationMs)
|
|
}
|
|
|
|
func (r *MeasurementRunner) readDomainsFile() ([]string, error) {
|
|
f, err := os.Open(r.config.DomainsFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer f.Close()
|
|
|
|
var out []string
|
|
sc := bufio.NewScanner(f)
|
|
for sc.Scan() {
|
|
l := strings.TrimSpace(sc.Text())
|
|
if l != "" && !strings.HasPrefix(l, "#") {
|
|
out = append(out, l)
|
|
}
|
|
}
|
|
return out, sc.Err()
|
|
}
|
|
|
|
func (r *MeasurementRunner) checkCapturePermissions() error {
|
|
handle, err := pcap.OpenLive("any", 65535, false, time.Millisecond*100)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "permission") ||
|
|
strings.Contains(err.Error(), "Operation not permitted") {
|
|
return fmt.Errorf("insufficient permissions for packet capture")
|
|
}
|
|
return nil
|
|
}
|
|
handle.Close()
|
|
return nil
|
|
}
|