From 2e0042153a112dbf0c4523ba1721c80fbcc3a910 Mon Sep 17 00:00:00 2001 From: afonso Date: Thu, 1 May 2025 12:34:30 +0100 Subject: [PATCH] Not finished stuff --- cmd/resolver/main.go | 344 ++++++++++---------------------- internal/client/client.go | 195 ++++++++++++++++++ internal/protocols/do53/do53.go | 150 ++++++++++---- internal/protocols/doh/doh.go | 144 ++++++------- internal/protocols/doq/doq.go | 98 +++++++-- internal/protocols/dot/dot.go | 185 ++++++++++------- 6 files changed, 673 insertions(+), 443 deletions(-) create mode 100644 internal/client/client.go diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go index eef09ff..c770c7d 100644 --- a/cmd/resolver/main.go +++ b/cmd/resolver/main.go @@ -1,268 +1,138 @@ package main import ( - "bufio" + "fmt" "log" "os" "strings" + "time" + + "github.com/afonsofrancof/sdns-perf/internal/client" - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" - "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" - "github.com/afonsofrancof/sdns-perf/internal/protocols/doq" - "github.com/afonsofrancof/sdns-perf/internal/protocols/dot" "github.com/alecthomas/kong" + "github.com/miekg/dns" ) -type CommonFlags struct { - DomainName string `help:"Domain name to resolve" arg:"" required:""` - QueryType string `help:"Query type" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"` - Server string `help:"DNS server to use" required:""` - DNSSEC bool `help:"Enable DNSSEC validation"` -} - -type DoHCmd struct { - CommonFlags `embed:""` - HTTP3 bool `help:"Use HTTP/3" name:"http3"` - Path string `help:"The HTTP path for the POST request" name:"path" required:""` - Proxy string `help:"The Proxy to use with ODoH"` -} - -type DoTCmd struct { - CommonFlags -} - -type DoQCmd struct { - CommonFlags -} - -type Do53Cmd struct { - CommonFlags -} - -type Listen struct { - -} - var cli struct { - Verbose bool `help:"Enable verbose logging" short:"v"` + // Global flags + Verbose bool `help:"Enable verbose logging." short:"v"` - DoH DoHCmd `cmd:"doh" help:"Query using DNS-over-HTTPS" name:"doh"` - DoT DoTCmd `cmd:"dot" help:"Query using DNS-over-TLS" name:"dot"` - DoQ DoQCmd `cmd:"doq" help:"Query using DNS-over-QUIC" name:"doq"` - Do53 Do53Cmd `cmd:"doq" help:"Query using plain DNS over UDP" name:"do53"` - Listen Listen `cmd:"listen"` + Query QueryCmd `cmd:"" help:"Perform a DNS query (client mode)."` + Listen ListenCmd `cmd:"" help:"Run as a DNS listener/resolver (server mode)."` } -func (c *Do53Cmd) Run() error { - do53client, err := do53.New(c.Server) +type QueryCmd struct { + DomainName string `help:"Domain name to resolve." arg:"" required:""` + Server string `help:"Upstream server address (e.g., https://1.1.1.1/dns-query, tls://1.1.1.1, 8.8.8.8)." short:"s" required:""` + QueryType string `help:"Query type (A, AAAA, MX, TXT, etc.)." short:"t" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"` + DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"` + Timeout time.Duration `help:"Timeout for the query operation." default:"10s"` // Default might be higher now + KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"` +} + +func (q *QueryCmd) Run() error { + log.Printf("Querying %s for %s type %s (DNSSEC: %v, Timeout: %v)\n", + q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.Timeout) + + opts := client.Options{ + Timeout: q.Timeout, + DNSSEC: q.DNSSEC, + KeyLogPath: q.KeyLogFile, + } + + dnsClient, err := client.New(q.Server, opts) if err != nil { return err } - defer do53client.Close() - return do53client.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) -} + defer dnsClient.Close() -func (c *DoHCmd) Run() error { - dohclient, err := doh.New(c.Server, c.Path, c.Proxy) - if err != nil { - return err + qTypeUint, ok := dns.StringToType[strings.ToUpper(q.QueryType)] + if !ok { + return fmt.Errorf("invalid query type: %s", q.QueryType) } - defer dohclient.Close() - return dohclient.Query(c.DomainName, c.QueryType, c.DNSSEC) -} -func (c *DoTCmd) Run() error { - dotclient, err := dot.New(c.Server) + dnsMsg, err := dnsClient.Query(q.DomainName, qTypeUint) if err != nil { - return err + return fmt.Errorf("query failed: %w ", err) } - defer dotclient.Close() - return dotclient.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) + + printResponse(q.DomainName, q.QueryType, dnsMsg) + + return nil } -func (c *DoQCmd) Run() error { - doqclient, err := doq.New(c.Server) - if err != nil { - return err +type ListenCmd struct { + Address string `help:"Address to listen on (e.g., :53, :8053)." default:":53"` + // Add other server-specific flags: default upstream, TLS cert/key paths etc. +} + +func (l *ListenCmd) Run() error { + return fmt.Errorf("server/listen mode not yet implemented") +} + +func printResponse(domain, qtype string, msg *dns.Msg) { + fmt.Println(";; QUESTION SECTION:") + + fmt.Printf(";%s.\tIN\t%s\n", dns.Fqdn(domain), strings.ToUpper(qtype)) + + fmt.Println("\n;; ANSWER SECTION:") + if len(msg.Answer) > 0 { + for _, rr := range msg.Answer { + fmt.Println(rr.String()) + } + } else { + fmt.Println(";; No records found in answer section.") } - defer doqclient.Close() - return doqclient.Query(c.DomainName, c.QueryType, c.DNSSEC) -} -func (l *Listen) Run() error { - // Maps to store clients for reuse - do53Clients := make(map[string]*do53.Do53Client) - dotClients := make(map[string]*dot.DoTClient) - doqClients := make(map[string]*doq.DoQClient) - dohClients := make(map[string]*doh.DoHClient) // Using server+path+proxy as key - - scanner := bufio.NewScanner(os.Stdin) - log.Println("Listening for input. Format: protocol domain server [options]") - - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - - if len(fields) < 3 { - log.Printf("Invalid input: %s. Format should be 'protocol domain server [options]'", line) - continue - } - - protocol := fields[0] - domain := fields[1] - server := fields[2] - - // Default query type and DNSSEC setting - queryType := "A" - dnssec := false - - switch protocol { - case "do53": - // Parse additional options - if len(fields) > 3 { - queryType = fields[3] - } - if len(fields) > 4 && fields[4] == "dnssec" { - dnssec = true - } - - // Check if client exists, if not create it - client, exists := do53Clients[server] - if !exists { - var err error - client, err = do53.New(server) - if err != nil { - log.Printf("Error creating Do53 client: %v", err) - continue - } - do53Clients[server] = client - } - - err := client.Query(domain, queryType, server, dnssec) - if err != nil { - log.Printf("Error querying with Do53: %v", err) - } - - case "dot": - // Parse additional options - if len(fields) > 3 { - queryType = fields[3] - } - if len(fields) > 4 && fields[4] == "dnssec" { - dnssec = true - } - - client, exists := dotClients[server] - if !exists { - var err error - client, err = dot.New(server) - if err != nil { - log.Printf("Error creating DoT client: %v", err) - continue - } - dotClients[server] = client - } - - err := client.Query(domain, queryType, server, dnssec) - if err != nil { - log.Printf("Error querying with DoT: %v", err) - } - - case "doq": - // Parse additional options - if len(fields) > 3 { - queryType = fields[3] - } - if len(fields) > 4 && fields[4] == "dnssec" { - dnssec = true - } - - client, exists := doqClients[server] - if !exists { - var err error - client, err = doq.New(server) - if err != nil { - log.Printf("Error creating DoQ client: %v", err) - continue - } - doqClients[server] = client - } - - err := client.Query(domain, queryType, dnssec) - if err != nil { - log.Printf("Error querying with DoQ: %v", err) - } - - case "doh": - // DoH requires path parameter - if len(fields) < 4 { - log.Printf("DoH requires a path parameter") - continue - } - - path := fields[3] - proxy := "" - - // Parse additional options - if len(fields) > 4 { - queryType = fields[4] - } - - if len(fields) > 5 { - if fields[5] == "dnssec" { - dnssec = true - } else { - proxy = fields[5] - } - } - - if len(fields) > 6 && fields[6] == "dnssec" { - dnssec = true - } - - // Create a composite key for DoH clients - key := server + ":" + path + ":" + proxy - client, exists := dohClients[key] - if !exists { - var err error - client, err = doh.New(server, path, proxy) - if err != nil { - log.Printf("Error creating DoH client: %v", err) - continue - } - dohClients[key] = client - } - - err := client.Query(domain, queryType, dnssec) - if err != nil { - log.Printf("Error querying with DoH: %v", err) - } - - default: - log.Printf("Unknown protocol: %s", protocol) - } - } - - if err := scanner.Err(); err != nil { - return err - } - - return nil -} + if len(msg.Ns) > 0 { + fmt.Println("\n;; AUTHORITY SECTION:") + for _, rr := range msg.Ns { + fmt.Println(rr.String()) + } + } + if len(msg.Extra) > 0 { + hasRealExtra := false + for _, rr := range msg.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + hasRealExtra = true + break + } + } + if hasRealExtra { + fmt.Println("\n;; ADDITIONAL SECTION:") + for _, rr := range msg.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + fmt.Println(rr.String()) + } + } + } + } + fmt.Printf("\n;; RCODE: %s, ID: %d", dns.RcodeToString[msg.Rcode], msg.Id) + opt := msg.IsEdns0() + if opt != nil { + fmt.Printf(", EDNS: version: %d; flags:", opt.Version()) + if opt.Do() { + fmt.Printf(" do;") + } else { + fmt.Printf(";") + } + fmt.Printf(" udp: %d", opt.UDPSize()) + } + fmt.Println() +} func main() { - ctx := kong.Parse(&cli, - kong.Name("dns-go"), - kong.Description("A DNS resolver supporting DoH, DoT, and DoQ protocols"), - kong.UsageOnError(), - kong.ConfigureHelp(kong.HelpOptions{ - Compact: true, - Summary: true, - })) + log.SetOutput(os.Stderr) + log.SetFlags(log.Ltime | log.Lshortfile) - err := ctx.Run() - if err != nil { - log.Fatalf("Error: %v", err) - } + kongCtx := kong.Parse(&cli, + kong.Name("sdns-perf"), + kong.Description("A DNS client/server tool supporting multiple protocols."), + kong.UsageOnError(), + kong.ConfigureHelp(kong.HelpOptions{Compact: true, Summary: true}), + ) + + err := kongCtx.Run() + kongCtx.FatalIfErrorf(err) } diff --git a/internal/client/client.go b/internal/client/client.go new file mode 100644 index 0000000..0a81128 --- /dev/null +++ b/internal/client/client.go @@ -0,0 +1,195 @@ +// internal/client/client.go +package client + +import ( + "fmt" + "io" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" + "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" + // "github.com/afonsofrancof/sdns-perf/internal/protocols/doq" + // "github.com/afonsofrancof/sdns-perf/internal/protocols/dot" + + "github.com/miekg/dns" +) + +// DNSClient defines the interface that all specific protocol clients must implement. +type DNSClient interface { + Query(domain string, queryType uint16) (*dns.Msg, error) + Close() +} + +// Options holds common configuration options for creating any DNS client. +type Options struct { + Timeout time.Duration + DNSSEC bool + KeyLogPath string // Path for TLS key logging +} + +type protocolType int + +const ( + protoUnknown protocolType = iota + protoDo53 + protoDoT + protoDoH + protoDoH3 + protoDoQ +) + +// config holds the parsed details of an upstream server string. +// This is internal to the client package. +type config struct { + original string + protocol protocolType + host string + port string + path string +} + +// parseUpstream takes a user-provided upstream string and attempts to determine +// the protocol, host, port, and path. (Internal helper) +func parseUpstream(upstreamStr string) (config, error) { + cfg := config{original: upstreamStr, protocol: protoUnknown} + + // Try parsing as a full URL first + parsedURL, err := url.Parse(upstreamStr) + if err == nil && parsedURL.Scheme != "" && parsedURL.Host != "" { + cfg.host = parsedURL.Hostname() + cfg.port = parsedURL.Port() + cfg.path = parsedURL.Path + if cfg.path == "" { + cfg.path = "/" // Default path + } + + switch strings.ToLower(parsedURL.Scheme) { + case "https", "doh": + cfg.protocol = protoDoH + if cfg.port == "" { + cfg.port = "443" + } + case "h3", "doh3": + cfg.protocol = protoDoH3 + if cfg.port == "" { + cfg.port = "443" + } + case "tls", "dot": + cfg.protocol = protoDoT + if cfg.port == "" { + cfg.port = "853" + } + case "quic", "doq": + cfg.protocol = protoDoQ + if cfg.port == "" { + cfg.port = "853" + } + case "udp", "do53": + cfg.protocol = protoDo53 + if cfg.port == "" { + cfg.port = "53" + } + default: + return cfg, fmt.Errorf("unsupported URL scheme: %q", parsedURL.Scheme) + } + return cfg, nil + } + + // If not a valid URL or no scheme, assume plain DNS (Do53 UDP) + cfg.protocol = protoDo53 + host, port, err := net.SplitHostPort(upstreamStr) + if err == nil { + cfg.host = host + cfg.port = port + if _, pErr := strconv.Atoi(port); pErr != nil { + return cfg, fmt.Errorf("invalid port %q in upstream %q: %w", port, upstreamStr, pErr) + } + } else { + cfg.host = upstreamStr + cfg.port = "53" + // Basic check for likely IPv6 without brackets and port + if strings.Contains(cfg.host, ":") && !strings.Contains(cfg.host, "[") { + _, resolveErr := net.ResolveUDPAddr("udp", net.JoinHostPort(cfg.host, cfg.port)) + if resolveErr != nil { + return cfg, fmt.Errorf("invalid upstream format; could not parse %q as host:port or resolve as host with default port 53: %w", upstreamStr, err) + } + } + } + + if cfg.host == "" { + return cfg, fmt.Errorf("could not extract host from upstream: %q", upstreamStr) + } + + return cfg, nil +} + +// New creates the appropriate DNS client based on the upstream string format. +// It returns an uninitialized client (connections are lazy). +func New(upstreamStr string, opts Options) (DNSClient, error) { + cfg, err := parseUpstream(upstreamStr) + if err != nil { + return nil, fmt.Errorf("client: failed to parse upstream %q: %w", upstreamStr, err) + } + + var client DNSClient + var clientErr error + + switch cfg.protocol { + case protoDo53: + // Ensure do53.New matches this signature + config := do53.Config{HostAndPort: net.JoinHostPort(cfg.host, cfg.port), DNSSEC: false} + client, clientErr = do53.New(config) + + case protoDoH: + // Ensure doh.New matches this signature + config := doh.Config{Host: cfg.host, Port: cfg.port, Path: cfg.path, DNSSEC: false} + client, clientErr = doh.New(config) + + case protoDoT: + // Ensure dot.New matches this signature + // client, clientErr = dot.New(cfg.hostPort(), opts.Timeout, opts.DNSSEC, opts.KeyLogPath) + // if clientErr == nil && client == nil { + // clientErr = fmt.Errorf("client: DoT package returned nil client without error") + // } + + case protoDoQ: + // Ensure doq.New matches this signature + // client, clientErr = doq.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath) + // if clientErr == nil && client == nil { + // clientErr = fmt.Errorf("client: DoQ package returned nil client without error") + // } + + case protoDoH3: + // Decide on DoH3 handling (fallback or error) + // Fallback example: + // fmt.Fprintf(os.Stderr, "Warning: DoH3 protocol (h3://) detected for %s. Attempting connection using standard DoH (HTTPS).\n", cfg.original) + // client, clientErr = doh.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath) + // Error example: + // clientErr = fmt.Errorf("client: DoH3 protocol (h3://) is not yet supported") + + default: + clientErr = fmt.Errorf("client: unknown or unsupported protocol detected for upstream: %s", upstreamStr) + } + + if clientErr != nil { + return nil, fmt.Errorf("client: failed to create client for %s: %w", upstreamStr, clientErr) + } + if client == nil { + // Should be caught by clientErr checks above, but as a safeguard + return nil, fmt.Errorf("client: internal error - nil client returned for %s", upstreamStr) + } + + return client, nil +} + +// Helper function to close key log writer if needed (can be used by specific clients) +func CloseKeyLogWriter(w io.WriteCloser) error { + if w != nil { + return w.Close() + } + return nil +} diff --git a/internal/protocols/do53/do53.go b/internal/protocols/do53/do53.go index 0ed695b..ee659f9 100644 --- a/internal/protocols/do53/do53.go +++ b/internal/protocols/do53/do53.go @@ -2,64 +2,128 @@ package do53 import ( "fmt" + "log" "net" + "sync" "github.com/miekg/dns" ) -type Do53Client struct { - udpConn *net.UDPConn +type Config struct { + HostAndPort string + DNSSEC bool } -func New(dest string) (*Do53Client, error) { +type Client struct { + udpAddr *net.UDPAddr + conn *net.UDPConn - udpAddr, err := net.ResolveUDPAddr("udp", dest) - if err != nil { - return nil, fmt.Errorf("failed to resolve UDP address: %v", err) - } + responseChannels map[uint16]chan *dns.Msg + responseMutex *sync.Mutex - udpConn, err := net.DialUDP("udp", nil, udpAddr) - if err != nil { - return nil, fmt.Errorf("failed to dial UDP connection: %v", err) - } - return &Do53Client{udpConn: udpConn}, nil + config Config } -func (c *Do53Client) Close() { - if c.udpConn != nil { - c.udpConn.Close() +func New(config Config) (*Client, error) { + udpAddr, err := net.ResolveUDPAddr("udp", config.HostAndPort) + if err != nil { + return nil, fmt.Errorf("do53: failed to resolve UDP address %q: %w", config.HostAndPort, err) + } + + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, fmt.Errorf("do53: failed to dial UDP connection to %s: %w", config.HostAndPort, err) + } + + responseChannels := map[uint16]chan *dns.Msg{} + rcMutex := new(sync.Mutex) + + client := &Client{ + udpAddr: udpAddr, + conn: conn, + responseChannels: responseChannels, + responseMutex: rcMutex, + config: config, + } + + go client.receiveLoop() + + return client, nil +} + +func (c *Client) Close() { + if c.conn != nil { + c.conn.Close() + c.conn = nil } } -func (c *Do53Client) Query(domain, queryType, dest string, dnssec bool) error { +func (c *Client) receiveLoop() { - message, err := NewDNSMessage(domain, queryType) - if err != nil { - return err + buffer := make([]byte, dns.MaxMsgSize) + + for { + // Reads one UDP Datagram + n, err := c.conn.Read(buffer) + if err != nil { + log.Printf("do53: failed to read DNS response: %s", err.Error()) + } + + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(buffer[:n]) + if err != nil { + log.Printf("do53: failed to unpack DNS response: %s", err.Error()) + continue + } + + c.responseMutex.Lock() + respChan, ok := c.responseChannels[recvMsg.Id] + delete(c.responseChannels, recvMsg.Id) + c.responseMutex.Unlock() + + if ok { + respChan <- recvMsg + } else { + log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id) + } } - _, err = c.udpConn.Write(message) - if err != nil { - return fmt.Errorf("failed to send DNS query: %v", err) - } - - buf := make([]byte, 4096) - n, err := c.udpConn.Read(buf) - if err != nil { - return fmt.Errorf("failed to read DNS response: %v", err) - } - - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(buf[:n]) - if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) - } - - // TODO: Check if the response had no errors or TD bit set - - for _, answer := range recvMsg.Answer { - fmt.Println(answer.String()) - } - - return nil +} + +func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), queryType) + msg.Id = dns.Id() + msg.RecursionDesired = true + + if c.config.DNSSEC { + msg.SetEdns0(4096, true) + } + + respChan := make(chan *dns.Msg) + + c.responseMutex.Lock() + c.responseChannels[msg.Id] = respChan + c.responseMutex.Unlock() + + packedMsg, err := msg.Pack() + if err != nil { + c.responseMutex.Lock() + delete(c.responseChannels, msg.Id) + c.responseMutex.Unlock() + return nil, fmt.Errorf("do53: failed to pack DNS message: %w", err) + } + + _, err = c.conn.Write(packedMsg) + if err != nil { + c.responseMutex.Lock() + delete(c.responseChannels, msg.Id) + c.responseMutex.Unlock() + return nil, fmt.Errorf("do53: failed to send DNS query: %w", err) + } + + recvMsg := <-respChan + + return recvMsg, nil } diff --git a/internal/protocols/doh/doh.go b/internal/protocols/doh/doh.go index 63ca869..fa17cdf 100644 --- a/internal/protocols/doh/doh.go +++ b/internal/protocols/doh/doh.go @@ -1,121 +1,125 @@ package doh import ( - "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "net" "net/http" - "os" + "net/url" + "strings" - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/miekg/dns" ) -type DoHClient struct { - tcpConn *net.TCPConn - tlsConn *tls.Conn - keyLogFile *os.File - target string - path string - proxy string +const dnsMessageContentType = "application/dns-message" + +type Config struct { + Host string + Port string + Path string + DNSSEC bool } -func New(target, path, proxy string) (*DoHClient, error) { +type Client struct { + httpClient *http.Client + upstreamURL *url.URL + config Config +} - tcpAddr, err := net.ResolveTCPAddr("tcp", target) - if err != nil { - return nil, fmt.Errorf("failed to resolve TCP address: %v", err) +func New(config Config) (*Client, error) { + if config.Host == "" || config.Port == "" || config.Path == "" { + fmt.Printf("%v,%v,%v", config.Host,config.Port,config.Path) + return nil, errors.New("doh: host, port, and path must not be empty") } - tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) - if err != nil { - return nil, fmt.Errorf("failed to establish TCP connection: %v", err) + if !strings.HasPrefix(config.Path, "/") { + config.Path = "/" + config.Path } + rawURL := "https://" + net.JoinHostPort(config.Host, config.Port) + config.Path - keyLogFile, err := os.OpenFile( - "tls-key-log.txt", - os.O_APPEND|os.O_CREATE|os.O_WRONLY, - 0600, - ) + parsedURL, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("failed opening key log file: %v", err) + return nil, fmt.Errorf("doh: failed to parse constructed URL %q: %w", rawURL, err) } tlsConfig := &tls.Config{ - // FIX: Actually check the domain name - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS12, - KeyLogWriter: keyLogFile, + ServerName: config.Host, + MinVersion: tls.VersionTLS12, } - tlsConn := tls.Client(tcpConn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - return nil, fmt.Errorf("failed to execute the TLS handshake: %v", err) + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + ForceAttemptHTTP2: true, } - return &DoHClient{tcpConn: tcpConn, keyLogFile: keyLogFile, tlsConn: tlsConn, target: target, path: path, proxy: proxy}, err + httpClient := &http.Client{ + Transport: transport, + } + return &Client{ + httpClient: httpClient, + upstreamURL: parsedURL, + config: config, + }, nil } -func (c *DoHClient) Close() { - if c.tcpConn != nil { - c.tcpConn.Close() - } - if c.keyLogFile != nil { - c.keyLogFile.Close() - } - if c.tlsConn != nil { - c.tlsConn.Close() - } +// Close cleans up idle connections held by the underlying HTTP transport. +func (c *Client) Close() { + c.httpClient.CloseIdleConnections() } -func (c *DoHClient) Query(domain, queryType string, dnssec bool) error { +func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), queryType) + msg.Id = dns.Id() + msg.RecursionDesired = true - DNSMessage, err := do53.NewDNSMessage(domain, queryType) + if c.config.DNSSEC { + msg.SetEdns0(4096, true) + } + + packedMsg, err := msg.Pack() if err != nil { - return err + return nil, fmt.Errorf("doh: failed to pack DNS message: %w", err) } - httpReq, err := http.NewRequest("POST", "https://"+c.target+"/"+c.path, bytes.NewBuffer(DNSMessage)) + httpReq, err := http.NewRequest("POST", c.upstreamURL.String(), bytes.NewReader(packedMsg)) if err != nil { - return fmt.Errorf("failed to create HTTP request: %v", err) + return nil, fmt.Errorf("doh: failed to create HTTP request object: %w", err) } - httpReq.Header.Add("Content-Type", "application/dns-message") - httpReq.Header.Set("Accept", "application/dns-message") - err = httpReq.Write(c.tlsConn) + httpReq.Header.Set("User-Agent", "sdns-perf") + httpReq.Header.Set("Content-Type", dnsMessageContentType) + httpReq.Header.Set("Accept", dnsMessageContentType) + + httpResp, err := c.httpClient.Do(httpReq) if err != nil { - return fmt.Errorf("failed writing HTTP request: %v", err) + return nil, fmt.Errorf("doh: failed executing HTTP request to %s: %w", c.upstreamURL.Host, err) + } + defer httpResp.Body.Close() + + if httpResp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("doh: received non-200 HTTP status from %s: %s", c.upstreamURL.Host, httpResp.Status) } - reader := bufio.NewReader(c.tlsConn) - resp, err := http.ReadResponse(reader, httpReq) + if ct := httpResp.Header.Get("Content-Type"); ct != dnsMessageContentType { + return nil, fmt.Errorf("doh: unexpected Content-Type from %s: got %q, want %q", c.upstreamURL.Host, ct, dnsMessageContentType) + } + + responseBody, err := io.ReadAll(httpResp.Body) if err != nil { - return fmt.Errorf("failed reading HTTP response: %v", err) - } - defer resp.Body.Close() - - responseBody := make([]byte, 4096) - n, err := resp.Body.Read(responseBody) - if err != nil && err != io.EOF { - return fmt.Errorf("failed reading response body: %v", err) + return nil, fmt.Errorf("doh: failed reading response body from %s: %w", c.upstreamURL.Host, err) } + // Unpack the DNS message recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBody[:n]) + err = recvMsg.Unpack(responseBody) if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) + return nil, fmt.Errorf("doh: failed to unpack DNS response from %s: %w", c.upstreamURL.Host, err) } - // TODO: Check if the response had no errors or TD bit set - - for _, answer := range recvMsg.Answer { - fmt.Println(answer.String()) - } - - return nil + return recvMsg, nil } diff --git a/internal/protocols/doq/doq.go b/internal/protocols/doq/doq.go index 936d0eb..d99b95b 100644 --- a/internal/protocols/doq/doq.go +++ b/internal/protocols/doq/doq.go @@ -7,87 +7,144 @@ import ( "encoding/binary" "fmt" "io" + "net" "os" + "time" "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/miekg/dns" "github.com/quic-go/quic-go" ) -type DoQClient struct { - target string +type Client struct { + targetAddr *net.UDPAddr keyLogFile *os.File tlsConfig *tls.Config + udpConn *net.UDPConn + quicConn quic.Connection + quicTransport *quic.Transport + quicConfig *quic.Config } -func New(target string) (*DoQClient, error) { +func New(target string) (*Client, error) { keyLogFile, err := os.OpenFile( "tls-key-log.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600, ) if err != nil { - return nil, fmt.Errorf("failed opening key log file: %v", err) + return nil, fmt.Errorf("failed opening key log file: %w", err) } tlsConfig := &tls.Config{ // FIX: Actually check the domain name InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, + ClientSessionCache: tls.NewLRUClientSessionCache(100), KeyLogWriter: keyLogFile, NextProtos: []string{"doq"}, } - return &DoQClient{ - target: target, + udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:6000") + if err != nil { + return nil, fmt.Errorf("failed to resolve target address: %w", err) + } + targetAddr, err := net.ResolveUDPAddr("udp", target) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, fmt.Errorf("failed to connect to target address: %w", err) + } + + quicTransport := quic.Transport{ + Conn: udpConn, + } + + quicConfig := quic.Config{ + // Use the default value of 30 seconds + MaxIdleTimeout: 30 * time.Second, + } + + return &Client{ + targetAddr: targetAddr, keyLogFile: keyLogFile, tlsConfig: tlsConfig, + udpConn: udpConn, + quicConn: nil, + quicTransport: &quicTransport, + quicConfig: &quicConfig, }, nil } -func (c *DoQClient) Close() { +func (c *Client) Close() { if c.keyLogFile != nil { c.keyLogFile.Close() } + if c.udpConn != nil { + c.udpConn.Close() + } } -func (c *DoQClient) Query(domain, queryType string, dnssec bool) error { - quicConn, err := quic.DialAddr(context.Background(), c.target, c.tlsConfig, &quic.Config{}) +func (c *Client) OpenConnection() error { + quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig) if err != nil { - return fmt.Errorf("failed to establish QUIC connection: %v", err) + return err + } + + c.quicConn = quicConn + return nil +} + +func (c *Client) Query(domain, queryType string, dnssec bool) error { + + if c.quicConn == nil { + err := c.OpenConnection() + if err != nil { + return err + } } - defer quicConn.CloseWithError(0, "") DNSMessage, err := do53.NewDNSMessage(domain, queryType) if err != nil { return err } - quicStream, err := quicConn.OpenStreamSync(context.Background()) + var quicStream quic.Stream + quicStream, err = c.quicConn.OpenStream() if err != nil { - return fmt.Errorf("failed to opening QUIC stream: %v", err) + err = c.OpenConnection() + if err != nil { + return err + } + quicStream, err = c.quicConn.OpenStream() + if err != nil { + return err + } } - defer quicStream.Close() var lengthPrefixedMessage bytes.Buffer err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) if err != nil { - return fmt.Errorf("failed to write message length: %v", err) + return fmt.Errorf("failed to write message length: %w", err) } _, err = lengthPrefixedMessage.Write(DNSMessage) if err != nil { - return fmt.Errorf("failed to write DNS message: %v", err) + return fmt.Errorf("failed to write DNS message: %w", err) } _, err = quicStream.Write(lengthPrefixedMessage.Bytes()) if err != nil { - return fmt.Errorf("failed writing to QUIC stream: %v", err) + return fmt.Errorf("failed writing to QUIC stream: %w", err) } + // Indicate that no further data will be written from this side + quicStream.Close() lengthBuf := make([]byte, 2) _, err = io.ReadFull(quicStream, lengthBuf) if err != nil { - return fmt.Errorf("failed reading response length: %v", err) + return fmt.Errorf("failed reading response length: %w", err) } messageLength := binary.BigEndian.Uint16(lengthBuf) @@ -98,17 +155,18 @@ func (c *DoQClient) Query(domain, queryType string, dnssec bool) error { responseBuf := make([]byte, messageLength) _, err = io.ReadFull(quicStream, responseBuf) if err != nil { - return fmt.Errorf("failed reading response data: %v", err) + return fmt.Errorf("failed reading response data: %w", err) } recvMsg := new(dns.Msg) err = recvMsg.Unpack(responseBuf) if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) + return fmt.Errorf("failed to parse DNS response: %w", err) } // TODO: Check if the response had no errors or TD bit set + fmt.Println(c.quicConn.ConnectionState().Used0RTT) for _, answer := range recvMsg.Answer { fmt.Println(answer.String()) } diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go index 99ebe2c..d314281 100644 --- a/internal/protocols/dot/dot.go +++ b/internal/protocols/dot/dot.go @@ -1,122 +1,161 @@ package dot import ( - "bytes" + "context" "crypto/tls" "encoding/binary" "fmt" "io" + "log" "net" "os" + "sync" + "time" - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/miekg/dns" ) -type DoTClient struct { - tcpConn *net.TCPConn - tlsConn *tls.Conn - keyLogFile *os.File +type Config struct { + Host string + Port string + DNSSEC bool + Debug bool } -func New(target string) (*DoTClient, error) { +type Client struct { + config Config - tcpAddr, err := net.ResolveTCPAddr("tcp", target) + serverAddr *net.TCPAddr + + tcpConn *net.TCPConn + tlsConn *tls.Conn + tlsConfig *tls.Config + keyLogFile *os.File + + sendChannel chan *dns.Msg + + responseChannels map[uint16]chan *dns.Msg + responseMutex *sync.Mutex +} + +func New(config Config) (*Client, error) { + serverAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(config.Host, config.Port)) if err != nil { - return nil, fmt.Errorf("failed to resolve TCP address: %v", err) + return nil, fmt.Errorf("dot: failed to resolve TCP address %q: %w", config.Host, err) } - tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) - if err != nil { - return nil, fmt.Errorf("failed to establish TCP connection: %v", err) - } - - keyLogFile, err := os.OpenFile( - "tls-key-log.txt", - os.O_APPEND|os.O_CREATE|os.O_WRONLY, - 0600, - ) - if err != nil { - return nil, fmt.Errorf("failed opening key log file: %v", err) + var keyLogFile *os.File + if config.Debug { + keyLogFile, err = os.OpenFile( + "tls-key-log.txt", + os.O_APPEND|os.O_CREATE|os.O_WRONLY, + 0600, + ) + if err != nil { + log.Printf("dot: failed opening TLS key log file: %v", err) + keyLogFile = nil + } } tlsConfig := &tls.Config{ - InsecureSkipVerify: true, + ServerName: serverAddr.IP.String(), MinVersion: tls.VersionTLS12, KeyLogWriter: keyLogFile, + ClientSessionCache: tls.NewLRUClientSessionCache(100), } - tlsConn := tls.Client(tcpConn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - return nil, fmt.Errorf("failed to execute the TLS handshake: %v", err) + client := &Client{ + config: config, + serverAddr: serverAddr, + tlsConfig: tlsConfig, + keyLogFile: keyLogFile, } - return &DoTClient{tcpConn: tcpConn, tlsConn: tlsConn, keyLogFile: keyLogFile}, nil + go client.receiveLoop() + + return client, nil } -func (c *DoTClient) Close() { - if c.tcpConn != nil { - c.tcpConn.Close() - } +func (c *Client) Close() { if c.tlsConn != nil { c.tlsConn.Close() + c.tlsConn = nil } + + if c.tcpConn != nil { + c.tcpConn.Close() + c.tcpConn = nil + } + if c.keyLogFile != nil { c.keyLogFile.Close() + c.keyLogFile = nil } } -func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error { +func (c *Client) receiveLoop() { - DNSMessage, err := do53.NewDNSMessage(domain, queryType) + lengthBuffer := make([]byte, 2) + buffer := make([]byte, dns.MaxMsgSize) + + for { + msgSize, err := io.ReadFull(c.tlsConn, lengthBuffer) + if err != nil { + log.Printf("doh: failed to read the DNS message's size: %s", err.Error()) + // FIX: HANDLE RECONNECTION + } + n, err := io.ReadFull(c.tlsConn, buffer[:msgSize]) + if err != nil { + log.Printf("doh: failed to read the DNS message: %s", err.Error()) + // FIX: HANDLE RECONNECTION + } + + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(buffer[:n]) + if err != nil { + log.Printf("do53: failed to unpack DNS response: %s", err.Error()) + continue + } + + c.responseMutex.Lock() + respChan, ok := c.responseChannels[recvMsg.Id] + delete(c.responseChannels, recvMsg.Id) + c.responseMutex.Unlock() + + if ok { + respChan <- recvMsg + } else { + log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id) + } + + } + +} + +func (c *Client) connect(ctx context.Context) error { + tcpConn, err := net.DialTCP("tcp", nil, c.serverAddr) if err != nil { - return err + return fmt.Errorf("dot: failed to establish TCP connection: %w", err) } - var lengthPrefixedMessage bytes.Buffer - err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) + c.tcpConn.SetKeepAlive(true) + c.tcpConn.SetKeepAlivePeriod(1 * time.Minute) + + tlsConn := tls.Client(c.tcpConn, c.tlsConfig) + err = tlsConn.HandshakeContext(ctx) if err != nil { - return fmt.Errorf("failed to write message length: %v", err) - } - _, err = lengthPrefixedMessage.Write(DNSMessage) - if err != nil { - return fmt.Errorf("failed to write DNS message: %v", err) + c.tcpConn.Close() + c.tcpConn = nil + return fmt.Errorf("dot: failed to execute the TLS handshake: %w", err) } - _, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes()) - if err != nil { - return fmt.Errorf("failed writing TLS request: %v", err) - } + c.tlsConn = tlsConn - lengthBuf := make([]byte, 2) - _, err = io.ReadFull(c.tlsConn, lengthBuf) - if err != nil { - return fmt.Errorf("failed reading response length: %v", err) - } - - messageLength := binary.BigEndian.Uint16(lengthBuf) - if messageLength == 0 { - return fmt.Errorf("received zero-length message") - } - - responseBuf := make([]byte, messageLength) - _, err = io.ReadFull(c.tlsConn, responseBuf) - if err != nil { - return fmt.Errorf("failed reading TLS response: %v", err) - } - - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBuf) - if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) - } - - // TODO: Check if the response had no errors or TD bit set - - for _, answer := range recvMsg.Answer { - fmt.Println(answer.String()) - } + log.Println("dot: TCP/TLS connection established successfully.") return nil } + +func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { + //TODO +}