refactor: refactor the code for qol.go
This commit is contained in:
101
internal/qol/capture/pcap.go
Normal file
101
internal/qol/capture/pcap.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package capture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/pcap"
|
||||
"github.com/google/gopacket/pcapgo"
|
||||
)
|
||||
|
||||
type PacketCapture struct {
|
||||
handle *pcap.Handle
|
||||
writer *pcapgo.Writer
|
||||
file *os.File
|
||||
mu sync.Mutex
|
||||
err error
|
||||
}
|
||||
|
||||
func NewPacketCapture(iface, outputPath string) (*PacketCapture, error) {
|
||||
handle, err := pcap.OpenLive(iface, 65535, true, pcap.BlockForever)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pcap open (try running as root): %w", err)
|
||||
}
|
||||
|
||||
file, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
handle.Close()
|
||||
return nil, fmt.Errorf("create pcap file: %w", err)
|
||||
}
|
||||
|
||||
writer := pcapgo.NewWriter(file)
|
||||
if err := writer.WriteFileHeader(65535, handle.LinkType()); err != nil {
|
||||
handle.Close()
|
||||
file.Close()
|
||||
return nil, fmt.Errorf("pcap header: %w", err)
|
||||
}
|
||||
|
||||
return &PacketCapture{
|
||||
handle: handle,
|
||||
writer: writer,
|
||||
file: file,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pc *PacketCapture) Start(ctx context.Context) error {
|
||||
psrc := gopacket.NewPacketSource(pc.handle, pc.handle.LinkType())
|
||||
pktCh := psrc.Packets()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case pkt, ok := <-pktCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ci := pkt.Metadata().CaptureInfo
|
||||
if err := pc.writer.WritePacket(ci, pkt.Data()); err != nil {
|
||||
pc.mu.Lock()
|
||||
if pc.err == nil {
|
||||
pc.err = fmt.Errorf("pcap write error: %w", err)
|
||||
}
|
||||
pc.mu.Unlock()
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc *PacketCapture) GetError() error {
|
||||
pc.mu.Lock()
|
||||
defer pc.mu.Unlock()
|
||||
return pc.err
|
||||
}
|
||||
|
||||
func (pc *PacketCapture) Close() error {
|
||||
var errs []error
|
||||
|
||||
if pc.handle != nil {
|
||||
pc.handle.Close()
|
||||
}
|
||||
|
||||
if pc.file != nil {
|
||||
if err := pc.file.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
233
internal/qol/measurement.go
Normal file
233
internal/qol/measurement.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package qol
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/afonsofrancof/sdns-proxy/client"
|
||||
"github.com/afonsofrancof/sdns-proxy/internal/qol/capture"
|
||||
"github.com/afonsofrancof/sdns-proxy/internal/qol/results"
|
||||
"github.com/google/gopacket/pcap"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type MeasurementConfig struct {
|
||||
DomainsFile string
|
||||
OutputDir string
|
||||
QueryType string
|
||||
DNSSEC bool
|
||||
Interface string
|
||||
Servers []string
|
||||
}
|
||||
|
||||
type MeasurementRunner struct {
|
||||
config MeasurementConfig
|
||||
}
|
||||
|
||||
func NewMeasurementRunner(config MeasurementConfig) *MeasurementRunner {
|
||||
return &MeasurementRunner{config: config}
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) Run() error {
|
||||
if err := r.checkCapturePermissions(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: %v\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Packet capture may fail. Consider running as root/administrator.\n")
|
||||
}
|
||||
|
||||
domains, err := r.readDomainsFile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading domains: %w", err)
|
||||
}
|
||||
|
||||
if len(r.config.Servers) == 0 {
|
||||
return fmt.Errorf("at least one server must be provided")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(r.config.OutputDir, 0755); err != nil {
|
||||
return fmt.Errorf("mkdir output: %w", err)
|
||||
}
|
||||
|
||||
qType, ok := dns.StringToType[strings.ToUpper(r.config.QueryType)]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid qtype: %s", r.config.QueryType)
|
||||
}
|
||||
|
||||
for _, upstream := range r.config.Servers {
|
||||
if err := r.runMeasurement(upstream, domains, qType); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error on server %s: %v\n", upstream, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) runMeasurement(upstream string, domains []string, qType uint16) error {
|
||||
// Setup DNS client
|
||||
dnsClient, err := r.setupDNSClient(upstream)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating client: %w", err)
|
||||
}
|
||||
defer dnsClient.Close()
|
||||
|
||||
// Setup output files
|
||||
jsonPath, pcapPath := GenerateOutputPaths(r.config.OutputDir, upstream, r.config.DNSSEC)
|
||||
|
||||
fmt.Printf(">>> Measuring %s (dnssec=%v) → %s\n", upstream, r.config.DNSSEC,
|
||||
strings.TrimSuffix(strings.TrimSuffix(jsonPath, ".jsonl"), r.config.OutputDir+"/"))
|
||||
|
||||
// Setup packet capture
|
||||
packetCapture, err := capture.NewPacketCapture(r.config.Interface, pcapPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer packetCapture.Close()
|
||||
|
||||
// Setup results writer
|
||||
writer, err := results.NewMetricsWriter(jsonPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
// Run measurements
|
||||
return r.runQueries(dnsClient, upstream, domains, qType, writer, packetCapture)
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) setupDNSClient(upstream string) (client.DNSClient, error) {
|
||||
opts := client.Options{DNSSEC: r.config.DNSSEC}
|
||||
return client.New(upstream, opts)
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) runQueries(dnsClient client.DNSClient, upstream string,
|
||||
domains []string, qType uint16, writer *results.MetricsWriter,
|
||||
packetCapture *capture.PacketCapture) error {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := packetCapture.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
failureCount := 0
|
||||
const maxFailures = 5
|
||||
proto := DetectProtocol(upstream)
|
||||
|
||||
for _, domain := range domains {
|
||||
if failureCount >= maxFailures {
|
||||
fmt.Printf("⚠ Skipping remaining domains (too many failures: %d)\n", failureCount)
|
||||
break
|
||||
}
|
||||
|
||||
metric := r.performQuery(dnsClient, domain, upstream, proto, qType)
|
||||
|
||||
if metric.ResponseCode == "ERROR" {
|
||||
failureCount++
|
||||
}
|
||||
|
||||
if err := writer.WriteMetric(metric); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "encode error: %v\n", err)
|
||||
}
|
||||
|
||||
r.printQueryResult(metric)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if err := packetCapture.GetError(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Warning: packet capture errors occurred: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) performQuery(dnsClient client.DNSClient, domain, upstream, proto string, qType uint16) results.DNSMetric {
|
||||
metric := results.DNSMetric{
|
||||
Domain: domain,
|
||||
QueryType: r.config.QueryType,
|
||||
Protocol: proto,
|
||||
DNSSEC: r.config.DNSSEC,
|
||||
DNSServer: upstream,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.Id = dns.Id()
|
||||
msg.RecursionDesired = true
|
||||
msg.SetQuestion(dns.Fqdn(domain), qType)
|
||||
|
||||
packed, err := msg.Pack()
|
||||
if err != nil {
|
||||
metric.ResponseCode = "ERROR"
|
||||
metric.Error = fmt.Sprintf("pack request: %v", err)
|
||||
return metric
|
||||
}
|
||||
metric.RequestSize = len(packed)
|
||||
|
||||
start := time.Now()
|
||||
resp, err := dnsClient.Query(msg)
|
||||
metric.Duration = time.Since(start).Nanoseconds()
|
||||
metric.DurationMs = float64(metric.Duration) / 1e6
|
||||
|
||||
if err != nil {
|
||||
metric.ResponseCode = "ERROR"
|
||||
metric.Error = err.Error()
|
||||
return metric
|
||||
}
|
||||
|
||||
respBytes, err := resp.Pack()
|
||||
if err != nil {
|
||||
metric.ResponseCode = "ERROR"
|
||||
metric.Error = fmt.Sprintf("pack response: %v", err)
|
||||
return metric
|
||||
}
|
||||
|
||||
metric.ResponseSize = len(respBytes)
|
||||
metric.ResponseCode = dns.RcodeToString[resp.Rcode]
|
||||
return metric
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) printQueryResult(metric results.DNSMetric) {
|
||||
statusIcon := "✓"
|
||||
if metric.ResponseCode == "ERROR" {
|
||||
statusIcon = "✗"
|
||||
}
|
||||
fmt.Printf("%s %s [%s] %s %.2fms\n",
|
||||
statusIcon, metric.Domain, metric.Protocol, metric.ResponseCode, metric.DurationMs)
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) readDomainsFile() ([]string, error) {
|
||||
f, err := os.Open(r.config.DomainsFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var out []string
|
||||
sc := bufio.NewScanner(f)
|
||||
for sc.Scan() {
|
||||
l := strings.TrimSpace(sc.Text())
|
||||
if l != "" && !strings.HasPrefix(l, "#") {
|
||||
out = append(out, l)
|
||||
}
|
||||
}
|
||||
return out, sc.Err()
|
||||
}
|
||||
|
||||
func (r *MeasurementRunner) checkCapturePermissions() error {
|
||||
handle, err := pcap.OpenLive("any", 65535, false, time.Millisecond*100)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "permission") ||
|
||||
strings.Contains(err.Error(), "Operation not permitted") {
|
||||
return fmt.Errorf("insufficient permissions for packet capture")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
handle.Close()
|
||||
return nil
|
||||
}
|
||||
51
internal/qol/results/writer.go
Normal file
51
internal/qol/results/writer.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package results
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DNSMetric struct {
|
||||
Domain string `json:"domain"`
|
||||
QueryType string `json:"query_type"`
|
||||
Protocol string `json:"protocol"`
|
||||
DNSSEC bool `json:"dnssec"`
|
||||
DNSServer string `json:"dns_server"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Duration int64 `json:"duration_ns"`
|
||||
DurationMs float64 `json:"duration_ms"`
|
||||
RequestSize int `json:"request_size_bytes"`
|
||||
ResponseSize int `json:"response_size_bytes"`
|
||||
ResponseCode string `json:"response_code"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type MetricsWriter struct {
|
||||
encoder *json.Encoder
|
||||
file *os.File
|
||||
}
|
||||
|
||||
func NewMetricsWriter(path string) (*MetricsWriter, error) {
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create json output: %w", err)
|
||||
}
|
||||
|
||||
return &MetricsWriter{
|
||||
encoder: json.NewEncoder(file),
|
||||
file: file,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (mw *MetricsWriter) WriteMetric(metric DNSMetric) error {
|
||||
return mw.encoder.Encode(metric)
|
||||
}
|
||||
|
||||
func (mw *MetricsWriter) Close() error {
|
||||
if mw.file != nil {
|
||||
return mw.file.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
49
internal/qol/utils.go
Normal file
49
internal/qol/utils.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package qol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GenerateOutputPaths(outputDir, upstream string, dnssec bool) (jsonPath, pcapPath string) {
|
||||
proto := DetectProtocol(upstream)
|
||||
serverName := ExtractServerName(upstream)
|
||||
ts := time.Now().Format("20060102_1504")
|
||||
dnssecStr := map[bool]string{true: "on", false: "off"}[dnssec]
|
||||
|
||||
base := fmt.Sprintf("%s_%s_dnssec_%s_%s",
|
||||
proto, sanitize(serverName), dnssecStr, ts)
|
||||
|
||||
return filepath.Join(outputDir, base+".jsonl"),
|
||||
filepath.Join(outputDir, base+".pcap")
|
||||
}
|
||||
|
||||
func sanitize(s string) string {
|
||||
return strings.NewReplacer(":", "_", "/", "_", ".", "_").Replace(s)
|
||||
}
|
||||
|
||||
func DetectProtocol(upstream string) string {
|
||||
if strings.Contains(upstream, "://") {
|
||||
u, err := url.Parse(upstream)
|
||||
if err == nil && u.Scheme != "" {
|
||||
return strings.ToLower(u.Scheme)
|
||||
}
|
||||
}
|
||||
return "do53"
|
||||
}
|
||||
|
||||
func ExtractServerName(upstream string) string {
|
||||
if strings.Contains(upstream, "://") {
|
||||
u, err := url.Parse(upstream)
|
||||
if err == nil {
|
||||
if u.Scheme == "https" && u.Path != "" && u.Path != "/" {
|
||||
return u.Host + u.Path
|
||||
}
|
||||
return u.Host
|
||||
}
|
||||
}
|
||||
return upstream
|
||||
}
|
||||
Reference in New Issue
Block a user