From 18e3b47c0740a3dd5da812384cdde715dbd52ccf Mon Sep 17 00:00:00 2001 From: afonso Date: Wed, 26 Feb 2025 17:55:14 +0000 Subject: [PATCH 1/8] Do53 and DoH (POST) basic queries implemented --- .gitignore | 17 +++ README.md | 3 + cmd/resolver/main.go | 78 +++++++++++++ go.mod | 8 ++ go.sum | 10 ++ internal/protocols/dnscrypt/dnscrypt.go | 3 + internal/protocols/dnssec/dnssec.go | 3 + internal/protocols/do53/do53.go | 63 ++++++++++ internal/protocols/do53/packet.go | 60 ++++++++++ internal/protocols/doh/doh.go | 148 ++++++++++++++++++++++++ internal/protocols/doq/doq.go | 3 + internal/protocols/dot/dot.go | 3 + 12 files changed, 399 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 cmd/resolver/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/protocols/dnscrypt/dnscrypt.go create mode 100644 internal/protocols/dnssec/dnssec.go create mode 100644 internal/protocols/do53/do53.go create mode 100644 internal/protocols/do53/packet.go create mode 100644 internal/protocols/doh/doh.go create mode 100644 internal/protocols/doq/doq.go create mode 100644 internal/protocols/dot/dot.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..30f3f28 --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +**/tls-key-log.txt diff --git a/README.md b/README.md new file mode 100644 index 0000000..31d7daf --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# DNS Resolver + +A DNS resolver supporting multiple protocols including DoH, DoT, DoQ, DNSSEC, ODoH, and DNSCrypt. diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go new file mode 100644 index 0000000..2731df6 --- /dev/null +++ b/cmd/resolver/main.go @@ -0,0 +1,78 @@ +package main + +import ( + "log" + + "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" + "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" + "github.com/alecthomas/kong" +) + +type CommonFlags struct { + DomainName string `help:"Domain name to resolve" arg:"" required:""` + QueryType string `help:"Query type" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"` + Server string `help:"DNS server to use"` + DNSSEC bool `help:"Enable DNSSEC validation"` +} + +type DoHCmd struct { + CommonFlags `embed:""` + HTTP3 bool `help:"Use HTTP/3" name:"http3"` + Path string `help:"The HTTP path for the POST request" name:"path" required:""` + Proxy string `help:"The Proxy to use with ODoH"` +} + +type DoTCmd struct { + CommonFlags +} + +type DoQCmd struct { + CommonFlags +} + +type Do53Cmd struct { + CommonFlags +} + +var cli struct { + Verbose bool `help:"Enable verbose logging" short:"v"` + + DoH DoHCmd `cmd:"doh" help:"Query using DNS-over-HTTPS" name:"doh"` + DoT DoTCmd `cmd:"dot" help:"Query using DNS-over-TLS" name:"dot"` + DoQ DoQCmd `cmd:"doq" help:"Query using DNS-over-QUIC" name:"doq"` + Do53 Do53Cmd `cmd:"doq" help:"Query using plain DNS over UDP" name:"do53"` +} + +func (c *Do53Cmd) Run() error { + return do53.Run(c.DomainName, c.QueryType, c.Server, c.DNSSEC) +} + +func (c *DoHCmd) Run() error { + return doh.Run(c.DomainName, c.QueryType, c.Server, c.Path,c.Proxy, c.DNSSEC) +} + +func (c *DoTCmd) Run() error { + // TODO: Implement DoT query + return nil +} + +func (c *DoQCmd) Run() error { + // TODO: Implement DoQ query + return nil +} + +func main() { + ctx := kong.Parse(&cli, + kong.Name("dns-go"), + kong.Description("A DNS resolver supporting DoH, DoT, and DoQ protocols"), + kong.UsageOnError(), + kong.ConfigureHelp(kong.HelpOptions{ + Compact: true, + Summary: true, + })) + + err := ctx.Run() + if err != nil { + log.Fatalf("Error: %v", err) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7b1813b --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/afonsofrancof/sdns-perf + +go 1.24.0 + +require ( + github.com/alecthomas/kong v1.8.1 + golang.org/x/net v0.35.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f14bc2b --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= +github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= +github.com/alecthomas/kong v1.8.1 h1:6aamvWBE/REnR/BCq10EcozmcpUPc5aGI1lPAWdB0EE= +github.com/alecthomas/kong v1.8.1/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= +github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= +github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= +github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= diff --git a/internal/protocols/dnscrypt/dnscrypt.go b/internal/protocols/dnscrypt/dnscrypt.go new file mode 100644 index 0000000..9939b27 --- /dev/null +++ b/internal/protocols/dnscrypt/dnscrypt.go @@ -0,0 +1,3 @@ +package dnscrypt + +// DNSCrypt resolver implementation diff --git a/internal/protocols/dnssec/dnssec.go b/internal/protocols/dnssec/dnssec.go new file mode 100644 index 0000000..fb852f6 --- /dev/null +++ b/internal/protocols/dnssec/dnssec.go @@ -0,0 +1,3 @@ +package dnssec + +// DNSSEC resolver implementation diff --git a/internal/protocols/do53/do53.go b/internal/protocols/do53/do53.go new file mode 100644 index 0000000..8728edd --- /dev/null +++ b/internal/protocols/do53/do53.go @@ -0,0 +1,63 @@ +package do53 + +import ( + "fmt" + "net" + + "golang.org/x/net/dns/dnsmessage" +) + +func Run(domain, queryType, dest string, dnssec bool) error { + + message, err := MakeDNSMessage(domain, queryType) + if err != nil { + return err + } + + udpAddr, err := net.ResolveUDPAddr("udp", dest) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %v", err) + } + + udpConn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return fmt.Errorf("failed to dial UDP connection: %v", err) + } + defer udpConn.Close() + + _, err = udpConn.Write(message) + if err != nil { + return fmt.Errorf("failed to send DNS query: %v", err) + } + + buf := make([]byte, 4096) + n, err := udpConn.Read(buf) + if err != nil { + return fmt.Errorf("failed to read DNS response: %v", err) + } + + var parser dnsmessage.Parser + _, err = parser.Start(buf[:n]) + if err != nil { + return fmt.Errorf("failed to parse DNS response: %v", err) + } + + // TODO: Check if the response had no errors or TD bit set + + err = parser.SkipAllQuestions() + if err != nil { + return fmt.Errorf("failed to skip questions: %v", err) + } + + answers, err := parser.AllAnswers() + if err != nil { + return err + } + + for _, answer := range answers { + fmt.Println(answer.GoString()) + } + + return nil +} + diff --git a/internal/protocols/do53/packet.go b/internal/protocols/do53/packet.go new file mode 100644 index 0000000..d287218 --- /dev/null +++ b/internal/protocols/do53/packet.go @@ -0,0 +1,60 @@ +package do53 + +import ( + "fmt" + + "golang.org/x/net/dns/dnsmessage" +) + +func MakeDNSMessage(domain string, queryType string) ([]byte, error) { + messageHeader := dnsmessage.Header{ + ID: 1234, // FIX: Use a random ID + Response: false, + OpCode: dnsmessage.OpCode(0), + RecursionDesired: true, + } + + messageBuilder := dnsmessage.NewBuilder(nil, messageHeader) + queryName, err := dnsmessage.NewName(domain) + if err != nil { + return nil, fmt.Errorf("failed to create query name: %v", err) + } + + // Determine query type + var queryTypeValue dnsmessage.Type + switch queryType { + case "A": + queryTypeValue = dnsmessage.TypeA + case "AAAA": + queryTypeValue = dnsmessage.TypeAAAA + case "MX": + queryTypeValue = dnsmessage.TypeMX + case "CNAME": + queryTypeValue = dnsmessage.TypeCNAME + case "TXT": + queryTypeValue = dnsmessage.TypeTXT + default: + queryTypeValue = dnsmessage.TypeA + } + + messageQuestion := dnsmessage.Question{ + Name: queryName, + Type: queryTypeValue, + Class: dnsmessage.ClassINET, + } + + err = messageBuilder.StartQuestions() + if err != nil { + return nil, err + } + err = messageBuilder.Question(messageQuestion) + if err != nil { + return nil, fmt.Errorf("failed to add question: %v", err) + } + + message, err := messageBuilder.Finish() + if err != nil { + return nil, fmt.Errorf("failed to build message: %v", err) + } + return message, nil +} diff --git a/internal/protocols/doh/doh.go b/internal/protocols/doh/doh.go new file mode 100644 index 0000000..915bac5 --- /dev/null +++ b/internal/protocols/doh/doh.go @@ -0,0 +1,148 @@ +package doh + +import ( + "bufio" + "bytes" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "os" + + "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" + "golang.org/x/net/dns/dnsmessage" +) + +func Run(domain, queryType, server, path, proxy string, dnssec bool) error { + + DNSMessage, err := do53.MakeDNSMessage(domain, queryType) + if err != nil { + return err + } + + // Step 1 - Establish a TCP Connection + tcpConn, err := net.Dial("tcp", server) + if err != nil { + return fmt.Errorf("failed to establish TCP connection: %v", err) + } + defer tcpConn.Close() + + // Step 2 - Upgrade it to a TLS Connection + + // Temporary keylog file to allow traffic inspection + keyLogFile, err := os.OpenFile( + "tls-key-log.txt", + os.O_APPEND|os.O_CREATE|os.O_WRONLY, + 0600, + ) + if err != nil { + return fmt.Errorf("failed opening key log file: %v", err) + } + defer keyLogFile.Close() + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + KeyLogWriter: keyLogFile, + } + + tlsConn := tls.Client(tcpConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return fmt.Errorf("failed to execute the TLS handshake: %v", err) + } + defer tlsConn.Close() + + // Step 3 - Create an HTTP request with the do53 message in the body + httpReq, err := http.NewRequest("POST", "https://"+server+"/"+path, bytes.NewBuffer(DNSMessage)) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %v", err) + } + httpReq.Header.Add("Content-Type", "application/dns-message") + httpReq.Header.Set("Accept", "application/dns-message") + + err = httpReq.Write(tlsConn) + if err != nil { + return fmt.Errorf("failed writing HTTP request: %v", err) + } + + reader := bufio.NewReader(tlsConn) + resp, err := http.ReadResponse(reader, httpReq) + if err != nil { + return fmt.Errorf("failed reading HTTP response: %v", err) + } + defer resp.Body.Close() + + responseBody := make([]byte, 4096) + n, err := resp.Body.Read(responseBody) + if err != nil && err != io.EOF { + return fmt.Errorf("failed reading response body: %v", err) + } + + // Parse the response + var parser dnsmessage.Parser + header, err := parser.Start(responseBody[:n]) + if err != nil { + return fmt.Errorf("failed to parse DNS response: %v", err) + } + + fmt.Printf("DNS Response Header:\n") + fmt.Printf(" ID: %d\n", header.ID) + fmt.Printf(" Response: %v\n", header.Response) + fmt.Printf(" RCode: %v\n", header.RCode) + + // Skip all questions before reading answers + err = parser.SkipAllQuestions() + if err != nil { + return fmt.Errorf("failed to skip questions: %v", err) + } + + // Parse answers + fmt.Printf("\nAnswers:\n") + answers, err := parser.AllAnswers() + + for i, answer := range answers { + + if err != nil { + return fmt.Errorf("failed to parse answer %d: %v", i, err) + } + + fmt.Printf(" Answer %d:\n", i+1) + fmt.Printf(" Name: %v\n", answer.Header.Name) + fmt.Printf(" Type: %v\n", answer.Header.Type) + fmt.Printf(" TTL: %v seconds\n", answer.Header.TTL) + + // Handle different record types + switch answer.Header.Type { + case dnsmessage.TypeA: + if r, ok := answer.Body.(*dnsmessage.AResource); ok { + fmt.Printf(" IPv4: %d.%d.%d.%d\n", r.A[0], r.A[1], r.A[2], r.A[3]) + } + case dnsmessage.TypeAAAA: + if r, ok := answer.Body.(*dnsmessage.AAAAResource); ok { + ip := r.AAAA + fmt.Printf(" IPv6: %02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x\n", + ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], + ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15]) + } + case dnsmessage.TypeCNAME: + if r, ok := answer.Body.(*dnsmessage.CNAMEResource); ok { + fmt.Printf(" CNAME: %v\n", r.CNAME) + } + case dnsmessage.TypeMX: + if r, ok := answer.Body.(*dnsmessage.MXResource); ok { + fmt.Printf(" Preference: %v\n", r.Pref) + fmt.Printf(" MX: %v\n", r.MX) + } + case dnsmessage.TypeTXT: + if r, ok := answer.Body.(*dnsmessage.TXTResource); ok { + fmt.Printf(" TXT: %v\n", r.TXT) + } + default: + fmt.Printf(" [Unsupported record type]\n") + } + } + + return nil +} diff --git a/internal/protocols/doq/doq.go b/internal/protocols/doq/doq.go new file mode 100644 index 0000000..123b02c --- /dev/null +++ b/internal/protocols/doq/doq.go @@ -0,0 +1,3 @@ +package doq + +// DoQ (DNS over QUIC) resolver implementation diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go new file mode 100644 index 0000000..d1f5d05 --- /dev/null +++ b/internal/protocols/dot/dot.go @@ -0,0 +1,3 @@ +package dot + +// DoT (DNS over TLS) resolver implementation From f17ff6123c5d7c4cd6586cdd068fb4cf339df0b9 Mon Sep 17 00:00:00 2001 From: afonso Date: Wed, 26 Feb 2025 21:39:22 +0000 Subject: [PATCH 2/8] Added basic DoT support --- cmd/resolver/main.go | 6 +- internal/protocols/dot/dot.go | 144 +++++++++++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 4 deletions(-) diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go index 2731df6..6127d83 100644 --- a/cmd/resolver/main.go +++ b/cmd/resolver/main.go @@ -5,13 +5,14 @@ import ( "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" + "github.com/afonsofrancof/sdns-perf/internal/protocols/dot" "github.com/alecthomas/kong" ) type CommonFlags struct { DomainName string `help:"Domain name to resolve" arg:"" required:""` QueryType string `help:"Query type" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"` - Server string `help:"DNS server to use"` + Server string `help:"DNS server to use" required:""` DNSSEC bool `help:"Enable DNSSEC validation"` } @@ -52,8 +53,7 @@ func (c *DoHCmd) Run() error { } func (c *DoTCmd) Run() error { - // TODO: Implement DoT query - return nil + return dot.Run(c.DomainName, c.QueryType, c.Server, c.DNSSEC) } func (c *DoQCmd) Run() error { diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go index d1f5d05..eaf6ebd 100644 --- a/internal/protocols/dot/dot.go +++ b/internal/protocols/dot/dot.go @@ -1,3 +1,145 @@ package dot -// DoT (DNS over TLS) resolver implementation +import ( + "crypto/tls" + "fmt" + "net" + "os" + + "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" + "golang.org/x/net/dns/dnsmessage" +) + +func Run(domain, queryType, server string, dnssec bool) error { + + DNSMessage, err := do53.MakeDNSMessage(domain, queryType) + if err != nil { + return err + } + + // Step 1 - Establish a TCP Connection + tcpConn, err := net.Dial("tcp", server) + if err != nil { + return fmt.Errorf("failed to establish TCP connection: %v", err) + } + defer tcpConn.Close() + + // Step 2 - Upgrade it to a TLS Connection + + // Temporary keylog file to allow traffic inspection + keyLogFile, err := os.OpenFile( + "tls-key-log.txt", + os.O_APPEND|os.O_CREATE|os.O_WRONLY, + 0600, + ) + if err != nil { + return fmt.Errorf("failed opening key log file: %v", err) + } + defer keyLogFile.Close() + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + KeyLogWriter: keyLogFile, + } + + tlsConn := tls.Client(tcpConn, tlsConfig) + err = tlsConn.Handshake() + if err != nil { + return fmt.Errorf("failed to execute the TLS handshake: %v", err) + } + defer tlsConn.Close() + + // Before sending the DNS message over TLS, prepend the 2-byte length field + lengthPrefixedMessage := make([]byte, len(DNSMessage)+2) + lengthPrefixedMessage[0] = byte(len(DNSMessage) >> 8) // High byte + lengthPrefixedMessage[1] = byte(len(DNSMessage) & 0xFF) // Low byte + copy(lengthPrefixedMessage[2:], DNSMessage) + + _, err = tlsConn.Write(lengthPrefixedMessage) + if err != nil { + return fmt.Errorf("failed writing TLS request: %v", err) + } + + // Read the 2-byte length prefix + lengthBuf := make([]byte, 2) + _, err = tlsConn.Read(lengthBuf) + if err != nil { + return fmt.Errorf("failed reading response length: %v", err) + } + + // Calculate the message length from the 2-byte prefix + messageLength := int(lengthBuf[0])<<8 | int(lengthBuf[1]) + + responseBuf := make([]byte, messageLength) + n, err := tlsConn.Read(responseBuf) + if err != nil { + return fmt.Errorf("failed reading TLS response: %v", err) + } + + // Parse the response + var parser dnsmessage.Parser + header, err := parser.Start(responseBuf[:n]) + if err != nil { + return fmt.Errorf("failed to parse DNS response: %v", err) + } + + fmt.Printf("DNS Response Header:\n") + fmt.Printf(" ID: %d\n", header.ID) + fmt.Printf(" Response: %v\n", header.Response) + fmt.Printf(" RCode: %v\n", header.RCode) + + // Skip all questions before reading answers + err = parser.SkipAllQuestions() + if err != nil { + return fmt.Errorf("failed to skip questions: %v", err) + } + + // Parse answers + fmt.Printf("\nAnswers:\n") + answers, err := parser.AllAnswers() + + for i, answer := range answers { + + if err != nil { + return fmt.Errorf("failed to parse answer %d: %v", i, err) + } + + fmt.Printf(" Answer %d:\n", i+1) + fmt.Printf(" Name: %v\n", answer.Header.Name) + fmt.Printf(" Type: %v\n", answer.Header.Type) + fmt.Printf(" TTL: %v seconds\n", answer.Header.TTL) + + // Handle different record types + switch answer.Header.Type { + case dnsmessage.TypeA: + if r, ok := answer.Body.(*dnsmessage.AResource); ok { + fmt.Printf(" IPv4: %d.%d.%d.%d\n", r.A[0], r.A[1], r.A[2], r.A[3]) + } + case dnsmessage.TypeAAAA: + if r, ok := answer.Body.(*dnsmessage.AAAAResource); ok { + ip := r.AAAA + fmt.Printf(" IPv6: %02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x\n", + ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], + ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15]) + } + case dnsmessage.TypeCNAME: + if r, ok := answer.Body.(*dnsmessage.CNAMEResource); ok { + fmt.Printf(" CNAME: %v\n", r.CNAME) + } + case dnsmessage.TypeMX: + if r, ok := answer.Body.(*dnsmessage.MXResource); ok { + fmt.Printf(" Preference: %v\n", r.Pref) + fmt.Printf(" MX: %v\n", r.MX) + } + case dnsmessage.TypeTXT: + if r, ok := answer.Body.(*dnsmessage.TXTResource); ok { + fmt.Printf(" TXT: %v\n", r.TXT) + } + default: + fmt.Printf(" [Unsupported record type]\n") + } + } + + return nil +} From dfdf518ea29d9886b7beddd4baf0e83e92b4b236 Mon Sep 17 00:00:00 2001 From: afonso Date: Thu, 27 Feb 2025 03:51:00 +0000 Subject: [PATCH 3/8] Made DoT length prefix operations more readable --- internal/protocols/dot/dot.go | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go index eaf6ebd..448027b 100644 --- a/internal/protocols/dot/dot.go +++ b/internal/protocols/dot/dot.go @@ -1,7 +1,9 @@ package dot import ( + "bytes" "crypto/tls" + "encoding/binary" "fmt" "net" "os" @@ -50,13 +52,11 @@ func Run(domain, queryType, server string, dnssec bool) error { } defer tlsConn.Close() - // Before sending the DNS message over TLS, prepend the 2-byte length field - lengthPrefixedMessage := make([]byte, len(DNSMessage)+2) - lengthPrefixedMessage[0] = byte(len(DNSMessage) >> 8) // High byte - lengthPrefixedMessage[1] = byte(len(DNSMessage) & 0xFF) // Low byte - copy(lengthPrefixedMessage[2:], DNSMessage) + var lengthPrefixedMessage bytes.Buffer + binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) + lengthPrefixedMessage.Write(DNSMessage) - _, err = tlsConn.Write(lengthPrefixedMessage) + _, err = tlsConn.Write(lengthPrefixedMessage.Bytes()) if err != nil { return fmt.Errorf("failed writing TLS request: %v", err) } @@ -68,8 +68,7 @@ func Run(domain, queryType, server string, dnssec bool) error { return fmt.Errorf("failed reading response length: %v", err) } - // Calculate the message length from the 2-byte prefix - messageLength := int(lengthBuf[0])<<8 | int(lengthBuf[1]) + messageLength := binary.BigEndian.Uint16(lengthBuf) responseBuf := make([]byte, messageLength) n, err := tlsConn.Read(responseBuf) From f5fa15b70177f781711f095dcf0b1d6bf4fa9a8d Mon Sep 17 00:00:00 2001 From: afonso Date: Sat, 1 Mar 2025 05:47:46 +0000 Subject: [PATCH 4/8] Made clients for each protocol to reuse connections --- cmd/resolver/main.go | 26 +++++-- go.mod | 17 +++- go.sum | 56 +++++++++++++ internal/protocols/do53/do53.go | 54 ++++++------- internal/protocols/do53/packet.go | 58 +++++--------- internal/protocols/doh/doh.go | 125 ++++++++++++------------------ internal/protocols/dot/dot.go | 119 +++++++++++----------------- 7 files changed, 233 insertions(+), 222 deletions(-) diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go index 6127d83..4ee21a3 100644 --- a/cmd/resolver/main.go +++ b/cmd/resolver/main.go @@ -5,6 +5,7 @@ import ( "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" + "github.com/afonsofrancof/sdns-perf/internal/protocols/doq" "github.com/afonsofrancof/sdns-perf/internal/protocols/dot" "github.com/alecthomas/kong" ) @@ -45,20 +46,35 @@ var cli struct { } func (c *Do53Cmd) Run() error { - return do53.Run(c.DomainName, c.QueryType, c.Server, c.DNSSEC) + do53client, err := do53.New(c.Server) + if err != nil { + return err + } + return do53client.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) } func (c *DoHCmd) Run() error { - return doh.Run(c.DomainName, c.QueryType, c.Server, c.Path,c.Proxy, c.DNSSEC) + dohclient, err := doh.New(c.Server, c.Path, c.Proxy) + if err != nil { + return err + } + return dohclient.Query(c.DomainName, c.QueryType, c.DNSSEC) } func (c *DoTCmd) Run() error { - return dot.Run(c.DomainName, c.QueryType, c.Server, c.DNSSEC) + dotclient, err := dot.New(c.Server) + if err != nil { + return err + } + return dotclient.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) } func (c *DoQCmd) Run() error { - // TODO: Implement DoQ query - return nil + doqclient, err := doq.New(c.Server) + if err != nil { + return err + } + return doqclient.Query(c.DomainName, c.QueryType, c.DNSSEC) } func main() { diff --git a/go.mod b/go.mod index 7b1813b..307936c 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,20 @@ go 1.24.0 require ( github.com/alecthomas/kong v1.8.1 - golang.org/x/net v0.35.0 + github.com/miekg/dns v1.1.63 + github.com/quic-go/quic-go v0.50.0 +) + +require ( + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect + go.uber.org/mock v0.5.0 // indirect + golang.org/x/crypto v0.33.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect + golang.org/x/mod v0.18.0 // indirect + golang.org/x/net v0.35.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/tools v0.22.0 // indirect ) diff --git a/go.sum b/go.sum index f14bc2b..25e4870 100644 --- a/go.sum +++ b/go.sum @@ -4,7 +4,63 @@ github.com/alecthomas/kong v1.8.1 h1:6aamvWBE/REnR/BCq10EcozmcpUPc5aGI1lPAWdB0EE github.com/alecthomas/kong v1.8.1/go.mod h1:p2vqieVMeTAnaC83txKtXe8FLke2X07aruPWXyMPQrU= github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc= github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= +github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= +github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/miekg/dns v1.1.63 h1:8M5aAw6OMZfFXTT7K5V0Eu5YiiL8l7nUAkyN6C9YwaY= +github.com/miekg/dns v1.1.63/go.mod h1:6NGHfjhpmr5lt3XPLuyfDJi5AXbNIPM9PY6H6sF1Nfs= +github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo= +github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= +golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/tools v0.22.0 h1:gqSGLZqv+AI9lIQzniJ0nZDRG5GBPsSi+DRNHWNz6yA= +golang.org/x/tools v0.22.0/go.mod h1:aCwcsjqvq7Yqt6TNyX7QMU2enbQ/Gt0bo6krSeEri+c= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/protocols/do53/do53.go b/internal/protocols/do53/do53.go index 8728edd..0ed695b 100644 --- a/internal/protocols/do53/do53.go +++ b/internal/protocols/do53/do53.go @@ -4,60 +4,62 @@ import ( "fmt" "net" - "golang.org/x/net/dns/dnsmessage" + "github.com/miekg/dns" ) -func Run(domain, queryType, dest string, dnssec bool) error { +type Do53Client struct { + udpConn *net.UDPConn +} - message, err := MakeDNSMessage(domain, queryType) - if err != nil { - return err - } +func New(dest string) (*Do53Client, error) { udpAddr, err := net.ResolveUDPAddr("udp", dest) if err != nil { - return fmt.Errorf("failed to resolve UDP address: %v", err) + return nil, fmt.Errorf("failed to resolve UDP address: %v", err) } udpConn, err := net.DialUDP("udp", nil, udpAddr) if err != nil { - return fmt.Errorf("failed to dial UDP connection: %v", err) + return nil, fmt.Errorf("failed to dial UDP connection: %v", err) } - defer udpConn.Close() + return &Do53Client{udpConn: udpConn}, nil +} - _, err = udpConn.Write(message) +func (c *Do53Client) Close() { + if c.udpConn != nil { + c.udpConn.Close() + } +} + +func (c *Do53Client) Query(domain, queryType, dest string, dnssec bool) error { + + message, err := NewDNSMessage(domain, queryType) + if err != nil { + return err + } + + _, err = c.udpConn.Write(message) if err != nil { return fmt.Errorf("failed to send DNS query: %v", err) } buf := make([]byte, 4096) - n, err := udpConn.Read(buf) + n, err := c.udpConn.Read(buf) if err != nil { return fmt.Errorf("failed to read DNS response: %v", err) } - var parser dnsmessage.Parser - _, err = parser.Start(buf[:n]) + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(buf[:n]) if err != nil { return fmt.Errorf("failed to parse DNS response: %v", err) } // TODO: Check if the response had no errors or TD bit set - err = parser.SkipAllQuestions() - if err != nil { - return fmt.Errorf("failed to skip questions: %v", err) - } - - answers, err := parser.AllAnswers() - if err != nil { - return err - } - - for _, answer := range answers { - fmt.Println(answer.GoString()) + for _, answer := range recvMsg.Answer { + fmt.Println(answer.String()) } return nil } - diff --git a/internal/protocols/do53/packet.go b/internal/protocols/do53/packet.go index d287218..29f6a55 100644 --- a/internal/protocols/do53/packet.go +++ b/internal/protocols/do53/packet.go @@ -1,60 +1,40 @@ package do53 import ( - "fmt" - - "golang.org/x/net/dns/dnsmessage" + "github.com/miekg/dns" ) -func MakeDNSMessage(domain string, queryType string) ([]byte, error) { - messageHeader := dnsmessage.Header{ - ID: 1234, // FIX: Use a random ID - Response: false, - OpCode: dnsmessage.OpCode(0), - RecursionDesired: true, - } +func NewDNSMessage(domain string, queryType string) ([]byte, error) { - messageBuilder := dnsmessage.NewBuilder(nil, messageHeader) - queryName, err := dnsmessage.NewName(domain) - if err != nil { - return nil, fmt.Errorf("failed to create query name: %v", err) - } - - // Determine query type - var queryTypeValue dnsmessage.Type + // TODO: Move this somewhere else and receive the type already parsed + var queryTypeValue uint16 switch queryType { case "A": - queryTypeValue = dnsmessage.TypeA + queryTypeValue = dns.TypeA case "AAAA": - queryTypeValue = dnsmessage.TypeAAAA + queryTypeValue = dns.TypeAAAA case "MX": - queryTypeValue = dnsmessage.TypeMX + queryTypeValue = dns.TypeMX case "CNAME": - queryTypeValue = dnsmessage.TypeCNAME + queryTypeValue = dns.TypeCNAME case "TXT": - queryTypeValue = dnsmessage.TypeTXT + queryTypeValue = dns.TypeTXT default: - queryTypeValue = dnsmessage.TypeA + queryTypeValue = dns.TypeA } - messageQuestion := dnsmessage.Question{ - Name: queryName, - Type: queryTypeValue, - Class: dnsmessage.ClassINET, - } + message := new(dns.Msg) - err = messageBuilder.StartQuestions() + message.Id = dns.Id() + message.Response = false + message.Opcode = dns.OpcodeQuery + message.Question = make([]dns.Question, 1) + message.Question[0] = dns.Question{Name: domain, Qtype: uint16(queryTypeValue), Qclass: dns.ClassINET} + message.Compress = true + wireMsg, err := message.Pack() if err != nil { return nil, err } - err = messageBuilder.Question(messageQuestion) - if err != nil { - return nil, fmt.Errorf("failed to add question: %v", err) - } - message, err := messageBuilder.Finish() - if err != nil { - return nil, fmt.Errorf("failed to build message: %v", err) - } - return message, nil + return wireMsg, nil } diff --git a/internal/protocols/doh/doh.go b/internal/protocols/doh/doh.go index 915bac5..63ca869 100644 --- a/internal/protocols/doh/doh.go +++ b/internal/protocols/doh/doh.go @@ -11,37 +11,41 @@ import ( "os" "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" - "golang.org/x/net/dns/dnsmessage" + "github.com/miekg/dns" ) -func Run(domain, queryType, server, path, proxy string, dnssec bool) error { +type DoHClient struct { + tcpConn *net.TCPConn + tlsConn *tls.Conn + keyLogFile *os.File + target string + path string + proxy string +} - DNSMessage, err := do53.MakeDNSMessage(domain, queryType) +func New(target, path, proxy string) (*DoHClient, error) { + + tcpAddr, err := net.ResolveTCPAddr("tcp", target) if err != nil { - return err + return nil, fmt.Errorf("failed to resolve TCP address: %v", err) } - // Step 1 - Establish a TCP Connection - tcpConn, err := net.Dial("tcp", server) + tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) if err != nil { - return fmt.Errorf("failed to establish TCP connection: %v", err) + return nil, fmt.Errorf("failed to establish TCP connection: %v", err) } - defer tcpConn.Close() - // Step 2 - Upgrade it to a TLS Connection - - // Temporary keylog file to allow traffic inspection keyLogFile, err := os.OpenFile( "tls-key-log.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600, ) if err != nil { - return fmt.Errorf("failed opening key log file: %v", err) + return nil, fmt.Errorf("failed opening key log file: %v", err) } - defer keyLogFile.Close() tlsConfig := &tls.Config{ + // FIX: Actually check the domain name InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, KeyLogWriter: keyLogFile, @@ -50,24 +54,45 @@ func Run(domain, queryType, server, path, proxy string, dnssec bool) error { tlsConn := tls.Client(tcpConn, tlsConfig) err = tlsConn.Handshake() if err != nil { - return fmt.Errorf("failed to execute the TLS handshake: %v", err) + return nil, fmt.Errorf("failed to execute the TLS handshake: %v", err) } - defer tlsConn.Close() - // Step 3 - Create an HTTP request with the do53 message in the body - httpReq, err := http.NewRequest("POST", "https://"+server+"/"+path, bytes.NewBuffer(DNSMessage)) + return &DoHClient{tcpConn: tcpConn, keyLogFile: keyLogFile, tlsConn: tlsConn, target: target, path: path, proxy: proxy}, err + +} + +func (c *DoHClient) Close() { + if c.tcpConn != nil { + c.tcpConn.Close() + } + if c.keyLogFile != nil { + c.keyLogFile.Close() + } + if c.tlsConn != nil { + c.tlsConn.Close() + } +} + +func (c *DoHClient) Query(domain, queryType string, dnssec bool) error { + + DNSMessage, err := do53.NewDNSMessage(domain, queryType) + if err != nil { + return err + } + + httpReq, err := http.NewRequest("POST", "https://"+c.target+"/"+c.path, bytes.NewBuffer(DNSMessage)) if err != nil { return fmt.Errorf("failed to create HTTP request: %v", err) } httpReq.Header.Add("Content-Type", "application/dns-message") httpReq.Header.Set("Accept", "application/dns-message") - err = httpReq.Write(tlsConn) + err = httpReq.Write(c.tlsConn) if err != nil { return fmt.Errorf("failed writing HTTP request: %v", err) } - reader := bufio.NewReader(tlsConn) + reader := bufio.NewReader(c.tlsConn) resp, err := http.ReadResponse(reader, httpReq) if err != nil { return fmt.Errorf("failed reading HTTP response: %v", err) @@ -80,68 +105,16 @@ func Run(domain, queryType, server, path, proxy string, dnssec bool) error { return fmt.Errorf("failed reading response body: %v", err) } - // Parse the response - var parser dnsmessage.Parser - header, err := parser.Start(responseBody[:n]) + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(responseBody[:n]) if err != nil { return fmt.Errorf("failed to parse DNS response: %v", err) } - fmt.Printf("DNS Response Header:\n") - fmt.Printf(" ID: %d\n", header.ID) - fmt.Printf(" Response: %v\n", header.Response) - fmt.Printf(" RCode: %v\n", header.RCode) + // TODO: Check if the response had no errors or TD bit set - // Skip all questions before reading answers - err = parser.SkipAllQuestions() - if err != nil { - return fmt.Errorf("failed to skip questions: %v", err) - } - - // Parse answers - fmt.Printf("\nAnswers:\n") - answers, err := parser.AllAnswers() - - for i, answer := range answers { - - if err != nil { - return fmt.Errorf("failed to parse answer %d: %v", i, err) - } - - fmt.Printf(" Answer %d:\n", i+1) - fmt.Printf(" Name: %v\n", answer.Header.Name) - fmt.Printf(" Type: %v\n", answer.Header.Type) - fmt.Printf(" TTL: %v seconds\n", answer.Header.TTL) - - // Handle different record types - switch answer.Header.Type { - case dnsmessage.TypeA: - if r, ok := answer.Body.(*dnsmessage.AResource); ok { - fmt.Printf(" IPv4: %d.%d.%d.%d\n", r.A[0], r.A[1], r.A[2], r.A[3]) - } - case dnsmessage.TypeAAAA: - if r, ok := answer.Body.(*dnsmessage.AAAAResource); ok { - ip := r.AAAA - fmt.Printf(" IPv6: %02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x\n", - ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], - ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15]) - } - case dnsmessage.TypeCNAME: - if r, ok := answer.Body.(*dnsmessage.CNAMEResource); ok { - fmt.Printf(" CNAME: %v\n", r.CNAME) - } - case dnsmessage.TypeMX: - if r, ok := answer.Body.(*dnsmessage.MXResource); ok { - fmt.Printf(" Preference: %v\n", r.Pref) - fmt.Printf(" MX: %v\n", r.MX) - } - case dnsmessage.TypeTXT: - if r, ok := answer.Body.(*dnsmessage.TXTResource); ok { - fmt.Printf(" TXT: %v\n", r.TXT) - } - default: - fmt.Printf(" [Unsupported record type]\n") - } + for _, answer := range recvMsg.Answer { + fmt.Println(answer.String()) } return nil diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go index 448027b..3b2d14d 100644 --- a/internal/protocols/dot/dot.go +++ b/internal/protocols/dot/dot.go @@ -9,35 +9,35 @@ import ( "os" "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" - "golang.org/x/net/dns/dnsmessage" + "github.com/miekg/dns" ) -func Run(domain, queryType, server string, dnssec bool) error { +type DoTClient struct { + tcpConn *net.TCPConn + tlsConn *tls.Conn + keyLogFile *os.File +} - DNSMessage, err := do53.MakeDNSMessage(domain, queryType) +func New(target string) (*DoTClient, error) { + + tcpAddr, err := net.ResolveTCPAddr("tcp", target) if err != nil { - return err + return nil, fmt.Errorf("failed to resolve TCP address: %v", err) } - // Step 1 - Establish a TCP Connection - tcpConn, err := net.Dial("tcp", server) + tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) if err != nil { - return fmt.Errorf("failed to establish TCP connection: %v", err) + return nil, fmt.Errorf("failed to establish TCP connection: %v", err) } - defer tcpConn.Close() - // Step 2 - Upgrade it to a TLS Connection - - // Temporary keylog file to allow traffic inspection keyLogFile, err := os.OpenFile( "tls-key-log.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600, ) if err != nil { - return fmt.Errorf("failed opening key log file: %v", err) + return nil, fmt.Errorf("failed opening key log file: %v", err) } - defer keyLogFile.Close() tlsConfig := &tls.Config{ InsecureSkipVerify: true, @@ -48,22 +48,43 @@ func Run(domain, queryType, server string, dnssec bool) error { tlsConn := tls.Client(tcpConn, tlsConfig) err = tlsConn.Handshake() if err != nil { - return fmt.Errorf("failed to execute the TLS handshake: %v", err) + return nil, fmt.Errorf("failed to execute the TLS handshake: %v", err) + } + + return &DoTClient{tcpConn: tcpConn, tlsConn: tlsConn, keyLogFile: keyLogFile}, nil +} + +func (c *DoTClient) Close() { + if c.tcpConn != nil { + c.tcpConn.Close() + } + if c.tlsConn != nil { + c.tlsConn.Close() + } + if c.keyLogFile != nil { + c.keyLogFile.Close() + } +} + +func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error { + + DNSMessage, err := do53.NewDNSMessage(domain, queryType) + if err != nil { + return err } - defer tlsConn.Close() var lengthPrefixedMessage bytes.Buffer binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) lengthPrefixedMessage.Write(DNSMessage) - _, err = tlsConn.Write(lengthPrefixedMessage.Bytes()) + _, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes()) if err != nil { return fmt.Errorf("failed writing TLS request: %v", err) } // Read the 2-byte length prefix lengthBuf := make([]byte, 2) - _, err = tlsConn.Read(lengthBuf) + _, err = c.tlsConn.Read(lengthBuf) if err != nil { return fmt.Errorf("failed reading response length: %v", err) } @@ -71,73 +92,21 @@ func Run(domain, queryType, server string, dnssec bool) error { messageLength := binary.BigEndian.Uint16(lengthBuf) responseBuf := make([]byte, messageLength) - n, err := tlsConn.Read(responseBuf) + n, err := c.tlsConn.Read(responseBuf) if err != nil { return fmt.Errorf("failed reading TLS response: %v", err) } - // Parse the response - var parser dnsmessage.Parser - header, err := parser.Start(responseBuf[:n]) + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(responseBuf[:n]) if err != nil { return fmt.Errorf("failed to parse DNS response: %v", err) } - fmt.Printf("DNS Response Header:\n") - fmt.Printf(" ID: %d\n", header.ID) - fmt.Printf(" Response: %v\n", header.Response) - fmt.Printf(" RCode: %v\n", header.RCode) + // TODO: Check if the response had no errors or TD bit set - // Skip all questions before reading answers - err = parser.SkipAllQuestions() - if err != nil { - return fmt.Errorf("failed to skip questions: %v", err) - } - - // Parse answers - fmt.Printf("\nAnswers:\n") - answers, err := parser.AllAnswers() - - for i, answer := range answers { - - if err != nil { - return fmt.Errorf("failed to parse answer %d: %v", i, err) - } - - fmt.Printf(" Answer %d:\n", i+1) - fmt.Printf(" Name: %v\n", answer.Header.Name) - fmt.Printf(" Type: %v\n", answer.Header.Type) - fmt.Printf(" TTL: %v seconds\n", answer.Header.TTL) - - // Handle different record types - switch answer.Header.Type { - case dnsmessage.TypeA: - if r, ok := answer.Body.(*dnsmessage.AResource); ok { - fmt.Printf(" IPv4: %d.%d.%d.%d\n", r.A[0], r.A[1], r.A[2], r.A[3]) - } - case dnsmessage.TypeAAAA: - if r, ok := answer.Body.(*dnsmessage.AAAAResource); ok { - ip := r.AAAA - fmt.Printf(" IPv6: %02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x\n", - ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], - ip[8], ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15]) - } - case dnsmessage.TypeCNAME: - if r, ok := answer.Body.(*dnsmessage.CNAMEResource); ok { - fmt.Printf(" CNAME: %v\n", r.CNAME) - } - case dnsmessage.TypeMX: - if r, ok := answer.Body.(*dnsmessage.MXResource); ok { - fmt.Printf(" Preference: %v\n", r.Pref) - fmt.Printf(" MX: %v\n", r.MX) - } - case dnsmessage.TypeTXT: - if r, ok := answer.Body.(*dnsmessage.TXTResource); ok { - fmt.Printf(" TXT: %v\n", r.TXT) - } - default: - fmt.Printf(" [Unsupported record type]\n") - } + for _, answer := range recvMsg.Answer { + fmt.Println(answer.String()) } return nil From a2fb149e7a24cbb865fd59ca8b4efd5c5a41b808 Mon Sep 17 00:00:00 2001 From: afonso Date: Sat, 1 Mar 2025 16:45:15 +0000 Subject: [PATCH 5/8] basic stdin reader on loop to reuse connections --- cmd/resolver/main.go | 174 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go index 4ee21a3..eef09ff 100644 --- a/cmd/resolver/main.go +++ b/cmd/resolver/main.go @@ -1,7 +1,10 @@ package main import ( + "bufio" "log" + "os" + "strings" "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" @@ -36,6 +39,10 @@ type Do53Cmd struct { CommonFlags } +type Listen struct { + +} + var cli struct { Verbose bool `help:"Enable verbose logging" short:"v"` @@ -43,6 +50,7 @@ var cli struct { DoT DoTCmd `cmd:"dot" help:"Query using DNS-over-TLS" name:"dot"` DoQ DoQCmd `cmd:"doq" help:"Query using DNS-over-QUIC" name:"doq"` Do53 Do53Cmd `cmd:"doq" help:"Query using plain DNS over UDP" name:"do53"` + Listen Listen `cmd:"listen"` } func (c *Do53Cmd) Run() error { @@ -50,6 +58,7 @@ func (c *Do53Cmd) Run() error { if err != nil { return err } + defer do53client.Close() return do53client.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) } @@ -58,6 +67,7 @@ func (c *DoHCmd) Run() error { if err != nil { return err } + defer dohclient.Close() return dohclient.Query(c.DomainName, c.QueryType, c.DNSSEC) } @@ -66,6 +76,7 @@ func (c *DoTCmd) Run() error { if err != nil { return err } + defer dotclient.Close() return dotclient.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) } @@ -74,9 +85,172 @@ func (c *DoQCmd) Run() error { if err != nil { return err } + defer doqclient.Close() return doqclient.Query(c.DomainName, c.QueryType, c.DNSSEC) } +func (l *Listen) Run() error { + // Maps to store clients for reuse + do53Clients := make(map[string]*do53.Do53Client) + dotClients := make(map[string]*dot.DoTClient) + doqClients := make(map[string]*doq.DoQClient) + dohClients := make(map[string]*doh.DoHClient) // Using server+path+proxy as key + + scanner := bufio.NewScanner(os.Stdin) + log.Println("Listening for input. Format: protocol domain server [options]") + + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + + if len(fields) < 3 { + log.Printf("Invalid input: %s. Format should be 'protocol domain server [options]'", line) + continue + } + + protocol := fields[0] + domain := fields[1] + server := fields[2] + + // Default query type and DNSSEC setting + queryType := "A" + dnssec := false + + switch protocol { + case "do53": + // Parse additional options + if len(fields) > 3 { + queryType = fields[3] + } + if len(fields) > 4 && fields[4] == "dnssec" { + dnssec = true + } + + // Check if client exists, if not create it + client, exists := do53Clients[server] + if !exists { + var err error + client, err = do53.New(server) + if err != nil { + log.Printf("Error creating Do53 client: %v", err) + continue + } + do53Clients[server] = client + } + + err := client.Query(domain, queryType, server, dnssec) + if err != nil { + log.Printf("Error querying with Do53: %v", err) + } + + case "dot": + // Parse additional options + if len(fields) > 3 { + queryType = fields[3] + } + if len(fields) > 4 && fields[4] == "dnssec" { + dnssec = true + } + + client, exists := dotClients[server] + if !exists { + var err error + client, err = dot.New(server) + if err != nil { + log.Printf("Error creating DoT client: %v", err) + continue + } + dotClients[server] = client + } + + err := client.Query(domain, queryType, server, dnssec) + if err != nil { + log.Printf("Error querying with DoT: %v", err) + } + + case "doq": + // Parse additional options + if len(fields) > 3 { + queryType = fields[3] + } + if len(fields) > 4 && fields[4] == "dnssec" { + dnssec = true + } + + client, exists := doqClients[server] + if !exists { + var err error + client, err = doq.New(server) + if err != nil { + log.Printf("Error creating DoQ client: %v", err) + continue + } + doqClients[server] = client + } + + err := client.Query(domain, queryType, dnssec) + if err != nil { + log.Printf("Error querying with DoQ: %v", err) + } + + case "doh": + // DoH requires path parameter + if len(fields) < 4 { + log.Printf("DoH requires a path parameter") + continue + } + + path := fields[3] + proxy := "" + + // Parse additional options + if len(fields) > 4 { + queryType = fields[4] + } + + if len(fields) > 5 { + if fields[5] == "dnssec" { + dnssec = true + } else { + proxy = fields[5] + } + } + + if len(fields) > 6 && fields[6] == "dnssec" { + dnssec = true + } + + // Create a composite key for DoH clients + key := server + ":" + path + ":" + proxy + client, exists := dohClients[key] + if !exists { + var err error + client, err = doh.New(server, path, proxy) + if err != nil { + log.Printf("Error creating DoH client: %v", err) + continue + } + dohClients[key] = client + } + + err := client.Query(domain, queryType, dnssec) + if err != nil { + log.Printf("Error querying with DoH: %v", err) + } + + default: + log.Printf("Unknown protocol: %s", protocol) + } + } + + if err := scanner.Err(); err != nil { + return err + } + + return nil +} + + func main() { ctx := kong.Parse(&cli, kong.Name("dns-go"), From 606309f4b16e5dafb044879a17347cddda667012 Mon Sep 17 00:00:00 2001 From: afonso Date: Sat, 1 Mar 2025 16:45:35 +0000 Subject: [PATCH 6/8] Made DoT use io.ReadFull to make sure it reads all the bytes --- internal/protocols/dot/dot.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go index 3b2d14d..99ebe2c 100644 --- a/internal/protocols/dot/dot.go +++ b/internal/protocols/dot/dot.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "io" "net" "os" @@ -74,31 +75,39 @@ func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error { } var lengthPrefixedMessage bytes.Buffer - binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) - lengthPrefixedMessage.Write(DNSMessage) + err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) + if err != nil { + return fmt.Errorf("failed to write message length: %v", err) + } + _, err = lengthPrefixedMessage.Write(DNSMessage) + if err != nil { + return fmt.Errorf("failed to write DNS message: %v", err) + } _, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes()) if err != nil { return fmt.Errorf("failed writing TLS request: %v", err) } - // Read the 2-byte length prefix lengthBuf := make([]byte, 2) - _, err = c.tlsConn.Read(lengthBuf) + _, err = io.ReadFull(c.tlsConn, lengthBuf) if err != nil { return fmt.Errorf("failed reading response length: %v", err) } messageLength := binary.BigEndian.Uint16(lengthBuf) + if messageLength == 0 { + return fmt.Errorf("received zero-length message") + } responseBuf := make([]byte, messageLength) - n, err := c.tlsConn.Read(responseBuf) + _, err = io.ReadFull(c.tlsConn, responseBuf) if err != nil { return fmt.Errorf("failed reading TLS response: %v", err) } recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBuf[:n]) + err = recvMsg.Unpack(responseBuf) if err != nil { return fmt.Errorf("failed to parse DNS response: %v", err) } From bf190b2396f94ac30e2134cc6345c90764ae1485 Mon Sep 17 00:00:00 2001 From: afonso Date: Sat, 1 Mar 2025 16:46:27 +0000 Subject: [PATCH 7/8] DoQ first draft. Still need to think about connection reusability --- internal/protocols/doq/doq.go | 116 +++++++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/internal/protocols/doq/doq.go b/internal/protocols/doq/doq.go index 123b02c..936d0eb 100644 --- a/internal/protocols/doq/doq.go +++ b/internal/protocols/doq/doq.go @@ -1,3 +1,117 @@ package doq -// DoQ (DNS over QUIC) resolver implementation +import ( + "bytes" + "context" + "crypto/tls" + "encoding/binary" + "fmt" + "io" + "os" + + "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" + "github.com/miekg/dns" + "github.com/quic-go/quic-go" +) + +type DoQClient struct { + target string + keyLogFile *os.File + tlsConfig *tls.Config +} + +func New(target string) (*DoQClient, error) { + keyLogFile, err := os.OpenFile( + "tls-key-log.txt", + os.O_APPEND|os.O_CREATE|os.O_WRONLY, + 0600, + ) + if err != nil { + return nil, fmt.Errorf("failed opening key log file: %v", err) + } + + tlsConfig := &tls.Config{ + // FIX: Actually check the domain name + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + KeyLogWriter: keyLogFile, + NextProtos: []string{"doq"}, + } + + return &DoQClient{ + target: target, + keyLogFile: keyLogFile, + tlsConfig: tlsConfig, + }, nil +} + +func (c *DoQClient) Close() { + if c.keyLogFile != nil { + c.keyLogFile.Close() + } +} + +func (c *DoQClient) Query(domain, queryType string, dnssec bool) error { + quicConn, err := quic.DialAddr(context.Background(), c.target, c.tlsConfig, &quic.Config{}) + if err != nil { + return fmt.Errorf("failed to establish QUIC connection: %v", err) + } + defer quicConn.CloseWithError(0, "") + + DNSMessage, err := do53.NewDNSMessage(domain, queryType) + if err != nil { + return err + } + + quicStream, err := quicConn.OpenStreamSync(context.Background()) + if err != nil { + return fmt.Errorf("failed to opening QUIC stream: %v", err) + } + defer quicStream.Close() + + var lengthPrefixedMessage bytes.Buffer + err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) + if err != nil { + return fmt.Errorf("failed to write message length: %v", err) + } + _, err = lengthPrefixedMessage.Write(DNSMessage) + if err != nil { + return fmt.Errorf("failed to write DNS message: %v", err) + } + + _, err = quicStream.Write(lengthPrefixedMessage.Bytes()) + if err != nil { + return fmt.Errorf("failed writing to QUIC stream: %v", err) + } + + lengthBuf := make([]byte, 2) + _, err = io.ReadFull(quicStream, lengthBuf) + if err != nil { + return fmt.Errorf("failed reading response length: %v", err) + } + + messageLength := binary.BigEndian.Uint16(lengthBuf) + if messageLength == 0 { + return fmt.Errorf("received zero-length message") + } + + responseBuf := make([]byte, messageLength) + _, err = io.ReadFull(quicStream, responseBuf) + if err != nil { + return fmt.Errorf("failed reading response data: %v", err) + } + + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(responseBuf) + if err != nil { + return fmt.Errorf("failed to parse DNS response: %v", err) + } + + // TODO: Check if the response had no errors or TD bit set + + for _, answer := range recvMsg.Answer { + fmt.Println(answer.String()) + } + + return nil +} From 2e0042153a112dbf0c4523ba1721c80fbcc3a910 Mon Sep 17 00:00:00 2001 From: afonso Date: Thu, 1 May 2025 12:34:30 +0100 Subject: [PATCH 8/8] Not finished stuff --- cmd/resolver/main.go | 344 ++++++++++---------------------- internal/client/client.go | 195 ++++++++++++++++++ internal/protocols/do53/do53.go | 150 ++++++++++---- internal/protocols/doh/doh.go | 144 ++++++------- internal/protocols/doq/doq.go | 98 +++++++-- internal/protocols/dot/dot.go | 185 ++++++++++------- 6 files changed, 673 insertions(+), 443 deletions(-) create mode 100644 internal/client/client.go diff --git a/cmd/resolver/main.go b/cmd/resolver/main.go index eef09ff..c770c7d 100644 --- a/cmd/resolver/main.go +++ b/cmd/resolver/main.go @@ -1,268 +1,138 @@ package main import ( - "bufio" + "fmt" "log" "os" "strings" + "time" + + "github.com/afonsofrancof/sdns-perf/internal/client" - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" - "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" - "github.com/afonsofrancof/sdns-perf/internal/protocols/doq" - "github.com/afonsofrancof/sdns-perf/internal/protocols/dot" "github.com/alecthomas/kong" + "github.com/miekg/dns" ) -type CommonFlags struct { - DomainName string `help:"Domain name to resolve" arg:"" required:""` - QueryType string `help:"Query type" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"` - Server string `help:"DNS server to use" required:""` - DNSSEC bool `help:"Enable DNSSEC validation"` -} - -type DoHCmd struct { - CommonFlags `embed:""` - HTTP3 bool `help:"Use HTTP/3" name:"http3"` - Path string `help:"The HTTP path for the POST request" name:"path" required:""` - Proxy string `help:"The Proxy to use with ODoH"` -} - -type DoTCmd struct { - CommonFlags -} - -type DoQCmd struct { - CommonFlags -} - -type Do53Cmd struct { - CommonFlags -} - -type Listen struct { - -} - var cli struct { - Verbose bool `help:"Enable verbose logging" short:"v"` + // Global flags + Verbose bool `help:"Enable verbose logging." short:"v"` - DoH DoHCmd `cmd:"doh" help:"Query using DNS-over-HTTPS" name:"doh"` - DoT DoTCmd `cmd:"dot" help:"Query using DNS-over-TLS" name:"dot"` - DoQ DoQCmd `cmd:"doq" help:"Query using DNS-over-QUIC" name:"doq"` - Do53 Do53Cmd `cmd:"doq" help:"Query using plain DNS over UDP" name:"do53"` - Listen Listen `cmd:"listen"` + Query QueryCmd `cmd:"" help:"Perform a DNS query (client mode)."` + Listen ListenCmd `cmd:"" help:"Run as a DNS listener/resolver (server mode)."` } -func (c *Do53Cmd) Run() error { - do53client, err := do53.New(c.Server) +type QueryCmd struct { + DomainName string `help:"Domain name to resolve." arg:"" required:""` + Server string `help:"Upstream server address (e.g., https://1.1.1.1/dns-query, tls://1.1.1.1, 8.8.8.8)." short:"s" required:""` + QueryType string `help:"Query type (A, AAAA, MX, TXT, etc.)." short:"t" enum:"A,AAAA,MX,TXT,NS,CNAME,SOA,PTR" default:"A"` + DNSSEC bool `help:"Enable DNSSEC (DO bit)." short:"d"` + Timeout time.Duration `help:"Timeout for the query operation." default:"10s"` // Default might be higher now + KeyLogFile string `help:"Path to TLS key log file (for DoT/DoH/DoQ)." env:"SSLKEYLOGFILE"` +} + +func (q *QueryCmd) Run() error { + log.Printf("Querying %s for %s type %s (DNSSEC: %v, Timeout: %v)\n", + q.Server, q.DomainName, q.QueryType, q.DNSSEC, q.Timeout) + + opts := client.Options{ + Timeout: q.Timeout, + DNSSEC: q.DNSSEC, + KeyLogPath: q.KeyLogFile, + } + + dnsClient, err := client.New(q.Server, opts) if err != nil { return err } - defer do53client.Close() - return do53client.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) -} + defer dnsClient.Close() -func (c *DoHCmd) Run() error { - dohclient, err := doh.New(c.Server, c.Path, c.Proxy) - if err != nil { - return err + qTypeUint, ok := dns.StringToType[strings.ToUpper(q.QueryType)] + if !ok { + return fmt.Errorf("invalid query type: %s", q.QueryType) } - defer dohclient.Close() - return dohclient.Query(c.DomainName, c.QueryType, c.DNSSEC) -} -func (c *DoTCmd) Run() error { - dotclient, err := dot.New(c.Server) + dnsMsg, err := dnsClient.Query(q.DomainName, qTypeUint) if err != nil { - return err + return fmt.Errorf("query failed: %w ", err) } - defer dotclient.Close() - return dotclient.Query(c.DomainName, c.QueryType, c.Server, c.DNSSEC) + + printResponse(q.DomainName, q.QueryType, dnsMsg) + + return nil } -func (c *DoQCmd) Run() error { - doqclient, err := doq.New(c.Server) - if err != nil { - return err +type ListenCmd struct { + Address string `help:"Address to listen on (e.g., :53, :8053)." default:":53"` + // Add other server-specific flags: default upstream, TLS cert/key paths etc. +} + +func (l *ListenCmd) Run() error { + return fmt.Errorf("server/listen mode not yet implemented") +} + +func printResponse(domain, qtype string, msg *dns.Msg) { + fmt.Println(";; QUESTION SECTION:") + + fmt.Printf(";%s.\tIN\t%s\n", dns.Fqdn(domain), strings.ToUpper(qtype)) + + fmt.Println("\n;; ANSWER SECTION:") + if len(msg.Answer) > 0 { + for _, rr := range msg.Answer { + fmt.Println(rr.String()) + } + } else { + fmt.Println(";; No records found in answer section.") } - defer doqclient.Close() - return doqclient.Query(c.DomainName, c.QueryType, c.DNSSEC) -} -func (l *Listen) Run() error { - // Maps to store clients for reuse - do53Clients := make(map[string]*do53.Do53Client) - dotClients := make(map[string]*dot.DoTClient) - doqClients := make(map[string]*doq.DoQClient) - dohClients := make(map[string]*doh.DoHClient) // Using server+path+proxy as key - - scanner := bufio.NewScanner(os.Stdin) - log.Println("Listening for input. Format: protocol domain server [options]") - - for scanner.Scan() { - line := scanner.Text() - fields := strings.Fields(line) - - if len(fields) < 3 { - log.Printf("Invalid input: %s. Format should be 'protocol domain server [options]'", line) - continue - } - - protocol := fields[0] - domain := fields[1] - server := fields[2] - - // Default query type and DNSSEC setting - queryType := "A" - dnssec := false - - switch protocol { - case "do53": - // Parse additional options - if len(fields) > 3 { - queryType = fields[3] - } - if len(fields) > 4 && fields[4] == "dnssec" { - dnssec = true - } - - // Check if client exists, if not create it - client, exists := do53Clients[server] - if !exists { - var err error - client, err = do53.New(server) - if err != nil { - log.Printf("Error creating Do53 client: %v", err) - continue - } - do53Clients[server] = client - } - - err := client.Query(domain, queryType, server, dnssec) - if err != nil { - log.Printf("Error querying with Do53: %v", err) - } - - case "dot": - // Parse additional options - if len(fields) > 3 { - queryType = fields[3] - } - if len(fields) > 4 && fields[4] == "dnssec" { - dnssec = true - } - - client, exists := dotClients[server] - if !exists { - var err error - client, err = dot.New(server) - if err != nil { - log.Printf("Error creating DoT client: %v", err) - continue - } - dotClients[server] = client - } - - err := client.Query(domain, queryType, server, dnssec) - if err != nil { - log.Printf("Error querying with DoT: %v", err) - } - - case "doq": - // Parse additional options - if len(fields) > 3 { - queryType = fields[3] - } - if len(fields) > 4 && fields[4] == "dnssec" { - dnssec = true - } - - client, exists := doqClients[server] - if !exists { - var err error - client, err = doq.New(server) - if err != nil { - log.Printf("Error creating DoQ client: %v", err) - continue - } - doqClients[server] = client - } - - err := client.Query(domain, queryType, dnssec) - if err != nil { - log.Printf("Error querying with DoQ: %v", err) - } - - case "doh": - // DoH requires path parameter - if len(fields) < 4 { - log.Printf("DoH requires a path parameter") - continue - } - - path := fields[3] - proxy := "" - - // Parse additional options - if len(fields) > 4 { - queryType = fields[4] - } - - if len(fields) > 5 { - if fields[5] == "dnssec" { - dnssec = true - } else { - proxy = fields[5] - } - } - - if len(fields) > 6 && fields[6] == "dnssec" { - dnssec = true - } - - // Create a composite key for DoH clients - key := server + ":" + path + ":" + proxy - client, exists := dohClients[key] - if !exists { - var err error - client, err = doh.New(server, path, proxy) - if err != nil { - log.Printf("Error creating DoH client: %v", err) - continue - } - dohClients[key] = client - } - - err := client.Query(domain, queryType, dnssec) - if err != nil { - log.Printf("Error querying with DoH: %v", err) - } - - default: - log.Printf("Unknown protocol: %s", protocol) - } - } - - if err := scanner.Err(); err != nil { - return err - } - - return nil -} + if len(msg.Ns) > 0 { + fmt.Println("\n;; AUTHORITY SECTION:") + for _, rr := range msg.Ns { + fmt.Println(rr.String()) + } + } + if len(msg.Extra) > 0 { + hasRealExtra := false + for _, rr := range msg.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + hasRealExtra = true + break + } + } + if hasRealExtra { + fmt.Println("\n;; ADDITIONAL SECTION:") + for _, rr := range msg.Extra { + if rr.Header().Rrtype != dns.TypeOPT { + fmt.Println(rr.String()) + } + } + } + } + fmt.Printf("\n;; RCODE: %s, ID: %d", dns.RcodeToString[msg.Rcode], msg.Id) + opt := msg.IsEdns0() + if opt != nil { + fmt.Printf(", EDNS: version: %d; flags:", opt.Version()) + if opt.Do() { + fmt.Printf(" do;") + } else { + fmt.Printf(";") + } + fmt.Printf(" udp: %d", opt.UDPSize()) + } + fmt.Println() +} func main() { - ctx := kong.Parse(&cli, - kong.Name("dns-go"), - kong.Description("A DNS resolver supporting DoH, DoT, and DoQ protocols"), - kong.UsageOnError(), - kong.ConfigureHelp(kong.HelpOptions{ - Compact: true, - Summary: true, - })) + log.SetOutput(os.Stderr) + log.SetFlags(log.Ltime | log.Lshortfile) - err := ctx.Run() - if err != nil { - log.Fatalf("Error: %v", err) - } + kongCtx := kong.Parse(&cli, + kong.Name("sdns-perf"), + kong.Description("A DNS client/server tool supporting multiple protocols."), + kong.UsageOnError(), + kong.ConfigureHelp(kong.HelpOptions{Compact: true, Summary: true}), + ) + + err := kongCtx.Run() + kongCtx.FatalIfErrorf(err) } diff --git a/internal/client/client.go b/internal/client/client.go new file mode 100644 index 0000000..0a81128 --- /dev/null +++ b/internal/client/client.go @@ -0,0 +1,195 @@ +// internal/client/client.go +package client + +import ( + "fmt" + "io" + "net" + "net/url" + "strconv" + "strings" + "time" + + "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" + "github.com/afonsofrancof/sdns-perf/internal/protocols/doh" + // "github.com/afonsofrancof/sdns-perf/internal/protocols/doq" + // "github.com/afonsofrancof/sdns-perf/internal/protocols/dot" + + "github.com/miekg/dns" +) + +// DNSClient defines the interface that all specific protocol clients must implement. +type DNSClient interface { + Query(domain string, queryType uint16) (*dns.Msg, error) + Close() +} + +// Options holds common configuration options for creating any DNS client. +type Options struct { + Timeout time.Duration + DNSSEC bool + KeyLogPath string // Path for TLS key logging +} + +type protocolType int + +const ( + protoUnknown protocolType = iota + protoDo53 + protoDoT + protoDoH + protoDoH3 + protoDoQ +) + +// config holds the parsed details of an upstream server string. +// This is internal to the client package. +type config struct { + original string + protocol protocolType + host string + port string + path string +} + +// parseUpstream takes a user-provided upstream string and attempts to determine +// the protocol, host, port, and path. (Internal helper) +func parseUpstream(upstreamStr string) (config, error) { + cfg := config{original: upstreamStr, protocol: protoUnknown} + + // Try parsing as a full URL first + parsedURL, err := url.Parse(upstreamStr) + if err == nil && parsedURL.Scheme != "" && parsedURL.Host != "" { + cfg.host = parsedURL.Hostname() + cfg.port = parsedURL.Port() + cfg.path = parsedURL.Path + if cfg.path == "" { + cfg.path = "/" // Default path + } + + switch strings.ToLower(parsedURL.Scheme) { + case "https", "doh": + cfg.protocol = protoDoH + if cfg.port == "" { + cfg.port = "443" + } + case "h3", "doh3": + cfg.protocol = protoDoH3 + if cfg.port == "" { + cfg.port = "443" + } + case "tls", "dot": + cfg.protocol = protoDoT + if cfg.port == "" { + cfg.port = "853" + } + case "quic", "doq": + cfg.protocol = protoDoQ + if cfg.port == "" { + cfg.port = "853" + } + case "udp", "do53": + cfg.protocol = protoDo53 + if cfg.port == "" { + cfg.port = "53" + } + default: + return cfg, fmt.Errorf("unsupported URL scheme: %q", parsedURL.Scheme) + } + return cfg, nil + } + + // If not a valid URL or no scheme, assume plain DNS (Do53 UDP) + cfg.protocol = protoDo53 + host, port, err := net.SplitHostPort(upstreamStr) + if err == nil { + cfg.host = host + cfg.port = port + if _, pErr := strconv.Atoi(port); pErr != nil { + return cfg, fmt.Errorf("invalid port %q in upstream %q: %w", port, upstreamStr, pErr) + } + } else { + cfg.host = upstreamStr + cfg.port = "53" + // Basic check for likely IPv6 without brackets and port + if strings.Contains(cfg.host, ":") && !strings.Contains(cfg.host, "[") { + _, resolveErr := net.ResolveUDPAddr("udp", net.JoinHostPort(cfg.host, cfg.port)) + if resolveErr != nil { + return cfg, fmt.Errorf("invalid upstream format; could not parse %q as host:port or resolve as host with default port 53: %w", upstreamStr, err) + } + } + } + + if cfg.host == "" { + return cfg, fmt.Errorf("could not extract host from upstream: %q", upstreamStr) + } + + return cfg, nil +} + +// New creates the appropriate DNS client based on the upstream string format. +// It returns an uninitialized client (connections are lazy). +func New(upstreamStr string, opts Options) (DNSClient, error) { + cfg, err := parseUpstream(upstreamStr) + if err != nil { + return nil, fmt.Errorf("client: failed to parse upstream %q: %w", upstreamStr, err) + } + + var client DNSClient + var clientErr error + + switch cfg.protocol { + case protoDo53: + // Ensure do53.New matches this signature + config := do53.Config{HostAndPort: net.JoinHostPort(cfg.host, cfg.port), DNSSEC: false} + client, clientErr = do53.New(config) + + case protoDoH: + // Ensure doh.New matches this signature + config := doh.Config{Host: cfg.host, Port: cfg.port, Path: cfg.path, DNSSEC: false} + client, clientErr = doh.New(config) + + case protoDoT: + // Ensure dot.New matches this signature + // client, clientErr = dot.New(cfg.hostPort(), opts.Timeout, opts.DNSSEC, opts.KeyLogPath) + // if clientErr == nil && client == nil { + // clientErr = fmt.Errorf("client: DoT package returned nil client without error") + // } + + case protoDoQ: + // Ensure doq.New matches this signature + // client, clientErr = doq.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath) + // if clientErr == nil && client == nil { + // clientErr = fmt.Errorf("client: DoQ package returned nil client without error") + // } + + case protoDoH3: + // Decide on DoH3 handling (fallback or error) + // Fallback example: + // fmt.Fprintf(os.Stderr, "Warning: DoH3 protocol (h3://) detected for %s. Attempting connection using standard DoH (HTTPS).\n", cfg.original) + // client, clientErr = doh.New(cfg.hostPort(), cfg.path, opts.Timeout, opts.DNSSEC, opts.KeyLogPath) + // Error example: + // clientErr = fmt.Errorf("client: DoH3 protocol (h3://) is not yet supported") + + default: + clientErr = fmt.Errorf("client: unknown or unsupported protocol detected for upstream: %s", upstreamStr) + } + + if clientErr != nil { + return nil, fmt.Errorf("client: failed to create client for %s: %w", upstreamStr, clientErr) + } + if client == nil { + // Should be caught by clientErr checks above, but as a safeguard + return nil, fmt.Errorf("client: internal error - nil client returned for %s", upstreamStr) + } + + return client, nil +} + +// Helper function to close key log writer if needed (can be used by specific clients) +func CloseKeyLogWriter(w io.WriteCloser) error { + if w != nil { + return w.Close() + } + return nil +} diff --git a/internal/protocols/do53/do53.go b/internal/protocols/do53/do53.go index 0ed695b..ee659f9 100644 --- a/internal/protocols/do53/do53.go +++ b/internal/protocols/do53/do53.go @@ -2,64 +2,128 @@ package do53 import ( "fmt" + "log" "net" + "sync" "github.com/miekg/dns" ) -type Do53Client struct { - udpConn *net.UDPConn +type Config struct { + HostAndPort string + DNSSEC bool } -func New(dest string) (*Do53Client, error) { +type Client struct { + udpAddr *net.UDPAddr + conn *net.UDPConn - udpAddr, err := net.ResolveUDPAddr("udp", dest) - if err != nil { - return nil, fmt.Errorf("failed to resolve UDP address: %v", err) - } + responseChannels map[uint16]chan *dns.Msg + responseMutex *sync.Mutex - udpConn, err := net.DialUDP("udp", nil, udpAddr) - if err != nil { - return nil, fmt.Errorf("failed to dial UDP connection: %v", err) - } - return &Do53Client{udpConn: udpConn}, nil + config Config } -func (c *Do53Client) Close() { - if c.udpConn != nil { - c.udpConn.Close() +func New(config Config) (*Client, error) { + udpAddr, err := net.ResolveUDPAddr("udp", config.HostAndPort) + if err != nil { + return nil, fmt.Errorf("do53: failed to resolve UDP address %q: %w", config.HostAndPort, err) + } + + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + return nil, fmt.Errorf("do53: failed to dial UDP connection to %s: %w", config.HostAndPort, err) + } + + responseChannels := map[uint16]chan *dns.Msg{} + rcMutex := new(sync.Mutex) + + client := &Client{ + udpAddr: udpAddr, + conn: conn, + responseChannels: responseChannels, + responseMutex: rcMutex, + config: config, + } + + go client.receiveLoop() + + return client, nil +} + +func (c *Client) Close() { + if c.conn != nil { + c.conn.Close() + c.conn = nil } } -func (c *Do53Client) Query(domain, queryType, dest string, dnssec bool) error { +func (c *Client) receiveLoop() { - message, err := NewDNSMessage(domain, queryType) - if err != nil { - return err + buffer := make([]byte, dns.MaxMsgSize) + + for { + // Reads one UDP Datagram + n, err := c.conn.Read(buffer) + if err != nil { + log.Printf("do53: failed to read DNS response: %s", err.Error()) + } + + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(buffer[:n]) + if err != nil { + log.Printf("do53: failed to unpack DNS response: %s", err.Error()) + continue + } + + c.responseMutex.Lock() + respChan, ok := c.responseChannels[recvMsg.Id] + delete(c.responseChannels, recvMsg.Id) + c.responseMutex.Unlock() + + if ok { + respChan <- recvMsg + } else { + log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id) + } } - _, err = c.udpConn.Write(message) - if err != nil { - return fmt.Errorf("failed to send DNS query: %v", err) - } - - buf := make([]byte, 4096) - n, err := c.udpConn.Read(buf) - if err != nil { - return fmt.Errorf("failed to read DNS response: %v", err) - } - - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(buf[:n]) - if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) - } - - // TODO: Check if the response had no errors or TD bit set - - for _, answer := range recvMsg.Answer { - fmt.Println(answer.String()) - } - - return nil +} + +func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { + + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), queryType) + msg.Id = dns.Id() + msg.RecursionDesired = true + + if c.config.DNSSEC { + msg.SetEdns0(4096, true) + } + + respChan := make(chan *dns.Msg) + + c.responseMutex.Lock() + c.responseChannels[msg.Id] = respChan + c.responseMutex.Unlock() + + packedMsg, err := msg.Pack() + if err != nil { + c.responseMutex.Lock() + delete(c.responseChannels, msg.Id) + c.responseMutex.Unlock() + return nil, fmt.Errorf("do53: failed to pack DNS message: %w", err) + } + + _, err = c.conn.Write(packedMsg) + if err != nil { + c.responseMutex.Lock() + delete(c.responseChannels, msg.Id) + c.responseMutex.Unlock() + return nil, fmt.Errorf("do53: failed to send DNS query: %w", err) + } + + recvMsg := <-respChan + + return recvMsg, nil } diff --git a/internal/protocols/doh/doh.go b/internal/protocols/doh/doh.go index 63ca869..fa17cdf 100644 --- a/internal/protocols/doh/doh.go +++ b/internal/protocols/doh/doh.go @@ -1,121 +1,125 @@ package doh import ( - "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "net" "net/http" - "os" + "net/url" + "strings" - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/miekg/dns" ) -type DoHClient struct { - tcpConn *net.TCPConn - tlsConn *tls.Conn - keyLogFile *os.File - target string - path string - proxy string +const dnsMessageContentType = "application/dns-message" + +type Config struct { + Host string + Port string + Path string + DNSSEC bool } -func New(target, path, proxy string) (*DoHClient, error) { +type Client struct { + httpClient *http.Client + upstreamURL *url.URL + config Config +} - tcpAddr, err := net.ResolveTCPAddr("tcp", target) - if err != nil { - return nil, fmt.Errorf("failed to resolve TCP address: %v", err) +func New(config Config) (*Client, error) { + if config.Host == "" || config.Port == "" || config.Path == "" { + fmt.Printf("%v,%v,%v", config.Host,config.Port,config.Path) + return nil, errors.New("doh: host, port, and path must not be empty") } - tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) - if err != nil { - return nil, fmt.Errorf("failed to establish TCP connection: %v", err) + if !strings.HasPrefix(config.Path, "/") { + config.Path = "/" + config.Path } + rawURL := "https://" + net.JoinHostPort(config.Host, config.Port) + config.Path - keyLogFile, err := os.OpenFile( - "tls-key-log.txt", - os.O_APPEND|os.O_CREATE|os.O_WRONLY, - 0600, - ) + parsedURL, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("failed opening key log file: %v", err) + return nil, fmt.Errorf("doh: failed to parse constructed URL %q: %w", rawURL, err) } tlsConfig := &tls.Config{ - // FIX: Actually check the domain name - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS12, - KeyLogWriter: keyLogFile, + ServerName: config.Host, + MinVersion: tls.VersionTLS12, } - tlsConn := tls.Client(tcpConn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - return nil, fmt.Errorf("failed to execute the TLS handshake: %v", err) + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + ForceAttemptHTTP2: true, } - return &DoHClient{tcpConn: tcpConn, keyLogFile: keyLogFile, tlsConn: tlsConn, target: target, path: path, proxy: proxy}, err + httpClient := &http.Client{ + Transport: transport, + } + return &Client{ + httpClient: httpClient, + upstreamURL: parsedURL, + config: config, + }, nil } -func (c *DoHClient) Close() { - if c.tcpConn != nil { - c.tcpConn.Close() - } - if c.keyLogFile != nil { - c.keyLogFile.Close() - } - if c.tlsConn != nil { - c.tlsConn.Close() - } +// Close cleans up idle connections held by the underlying HTTP transport. +func (c *Client) Close() { + c.httpClient.CloseIdleConnections() } -func (c *DoHClient) Query(domain, queryType string, dnssec bool) error { +func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { + msg := new(dns.Msg) + msg.SetQuestion(dns.Fqdn(domain), queryType) + msg.Id = dns.Id() + msg.RecursionDesired = true - DNSMessage, err := do53.NewDNSMessage(domain, queryType) + if c.config.DNSSEC { + msg.SetEdns0(4096, true) + } + + packedMsg, err := msg.Pack() if err != nil { - return err + return nil, fmt.Errorf("doh: failed to pack DNS message: %w", err) } - httpReq, err := http.NewRequest("POST", "https://"+c.target+"/"+c.path, bytes.NewBuffer(DNSMessage)) + httpReq, err := http.NewRequest("POST", c.upstreamURL.String(), bytes.NewReader(packedMsg)) if err != nil { - return fmt.Errorf("failed to create HTTP request: %v", err) + return nil, fmt.Errorf("doh: failed to create HTTP request object: %w", err) } - httpReq.Header.Add("Content-Type", "application/dns-message") - httpReq.Header.Set("Accept", "application/dns-message") - err = httpReq.Write(c.tlsConn) + httpReq.Header.Set("User-Agent", "sdns-perf") + httpReq.Header.Set("Content-Type", dnsMessageContentType) + httpReq.Header.Set("Accept", dnsMessageContentType) + + httpResp, err := c.httpClient.Do(httpReq) if err != nil { - return fmt.Errorf("failed writing HTTP request: %v", err) + return nil, fmt.Errorf("doh: failed executing HTTP request to %s: %w", c.upstreamURL.Host, err) + } + defer httpResp.Body.Close() + + if httpResp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("doh: received non-200 HTTP status from %s: %s", c.upstreamURL.Host, httpResp.Status) } - reader := bufio.NewReader(c.tlsConn) - resp, err := http.ReadResponse(reader, httpReq) + if ct := httpResp.Header.Get("Content-Type"); ct != dnsMessageContentType { + return nil, fmt.Errorf("doh: unexpected Content-Type from %s: got %q, want %q", c.upstreamURL.Host, ct, dnsMessageContentType) + } + + responseBody, err := io.ReadAll(httpResp.Body) if err != nil { - return fmt.Errorf("failed reading HTTP response: %v", err) - } - defer resp.Body.Close() - - responseBody := make([]byte, 4096) - n, err := resp.Body.Read(responseBody) - if err != nil && err != io.EOF { - return fmt.Errorf("failed reading response body: %v", err) + return nil, fmt.Errorf("doh: failed reading response body from %s: %w", c.upstreamURL.Host, err) } + // Unpack the DNS message recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBody[:n]) + err = recvMsg.Unpack(responseBody) if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) + return nil, fmt.Errorf("doh: failed to unpack DNS response from %s: %w", c.upstreamURL.Host, err) } - // TODO: Check if the response had no errors or TD bit set - - for _, answer := range recvMsg.Answer { - fmt.Println(answer.String()) - } - - return nil + return recvMsg, nil } diff --git a/internal/protocols/doq/doq.go b/internal/protocols/doq/doq.go index 936d0eb..d99b95b 100644 --- a/internal/protocols/doq/doq.go +++ b/internal/protocols/doq/doq.go @@ -7,87 +7,144 @@ import ( "encoding/binary" "fmt" "io" + "net" "os" + "time" "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/miekg/dns" "github.com/quic-go/quic-go" ) -type DoQClient struct { - target string +type Client struct { + targetAddr *net.UDPAddr keyLogFile *os.File tlsConfig *tls.Config + udpConn *net.UDPConn + quicConn quic.Connection + quicTransport *quic.Transport + quicConfig *quic.Config } -func New(target string) (*DoQClient, error) { +func New(target string) (*Client, error) { keyLogFile, err := os.OpenFile( "tls-key-log.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600, ) if err != nil { - return nil, fmt.Errorf("failed opening key log file: %v", err) + return nil, fmt.Errorf("failed opening key log file: %w", err) } tlsConfig := &tls.Config{ // FIX: Actually check the domain name InsecureSkipVerify: true, MinVersion: tls.VersionTLS13, + ClientSessionCache: tls.NewLRUClientSessionCache(100), KeyLogWriter: keyLogFile, NextProtos: []string{"doq"}, } - return &DoQClient{ - target: target, + udpAddr, err := net.ResolveUDPAddr("udp", "0.0.0.0:6000") + if err != nil { + return nil, fmt.Errorf("failed to resolve target address: %w", err) + } + targetAddr, err := net.ResolveUDPAddr("udp", target) + if err != nil { + return nil, err + } + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, fmt.Errorf("failed to connect to target address: %w", err) + } + + quicTransport := quic.Transport{ + Conn: udpConn, + } + + quicConfig := quic.Config{ + // Use the default value of 30 seconds + MaxIdleTimeout: 30 * time.Second, + } + + return &Client{ + targetAddr: targetAddr, keyLogFile: keyLogFile, tlsConfig: tlsConfig, + udpConn: udpConn, + quicConn: nil, + quicTransport: &quicTransport, + quicConfig: &quicConfig, }, nil } -func (c *DoQClient) Close() { +func (c *Client) Close() { if c.keyLogFile != nil { c.keyLogFile.Close() } + if c.udpConn != nil { + c.udpConn.Close() + } } -func (c *DoQClient) Query(domain, queryType string, dnssec bool) error { - quicConn, err := quic.DialAddr(context.Background(), c.target, c.tlsConfig, &quic.Config{}) +func (c *Client) OpenConnection() error { + quicConn, err := c.quicTransport.DialEarly(context.Background(), c.targetAddr, c.tlsConfig, c.quicConfig) if err != nil { - return fmt.Errorf("failed to establish QUIC connection: %v", err) + return err + } + + c.quicConn = quicConn + return nil +} + +func (c *Client) Query(domain, queryType string, dnssec bool) error { + + if c.quicConn == nil { + err := c.OpenConnection() + if err != nil { + return err + } } - defer quicConn.CloseWithError(0, "") DNSMessage, err := do53.NewDNSMessage(domain, queryType) if err != nil { return err } - quicStream, err := quicConn.OpenStreamSync(context.Background()) + var quicStream quic.Stream + quicStream, err = c.quicConn.OpenStream() if err != nil { - return fmt.Errorf("failed to opening QUIC stream: %v", err) + err = c.OpenConnection() + if err != nil { + return err + } + quicStream, err = c.quicConn.OpenStream() + if err != nil { + return err + } } - defer quicStream.Close() var lengthPrefixedMessage bytes.Buffer err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) if err != nil { - return fmt.Errorf("failed to write message length: %v", err) + return fmt.Errorf("failed to write message length: %w", err) } _, err = lengthPrefixedMessage.Write(DNSMessage) if err != nil { - return fmt.Errorf("failed to write DNS message: %v", err) + return fmt.Errorf("failed to write DNS message: %w", err) } _, err = quicStream.Write(lengthPrefixedMessage.Bytes()) if err != nil { - return fmt.Errorf("failed writing to QUIC stream: %v", err) + return fmt.Errorf("failed writing to QUIC stream: %w", err) } + // Indicate that no further data will be written from this side + quicStream.Close() lengthBuf := make([]byte, 2) _, err = io.ReadFull(quicStream, lengthBuf) if err != nil { - return fmt.Errorf("failed reading response length: %v", err) + return fmt.Errorf("failed reading response length: %w", err) } messageLength := binary.BigEndian.Uint16(lengthBuf) @@ -98,17 +155,18 @@ func (c *DoQClient) Query(domain, queryType string, dnssec bool) error { responseBuf := make([]byte, messageLength) _, err = io.ReadFull(quicStream, responseBuf) if err != nil { - return fmt.Errorf("failed reading response data: %v", err) + return fmt.Errorf("failed reading response data: %w", err) } recvMsg := new(dns.Msg) err = recvMsg.Unpack(responseBuf) if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) + return fmt.Errorf("failed to parse DNS response: %w", err) } // TODO: Check if the response had no errors or TD bit set + fmt.Println(c.quicConn.ConnectionState().Used0RTT) for _, answer := range recvMsg.Answer { fmt.Println(answer.String()) } diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go index 99ebe2c..d314281 100644 --- a/internal/protocols/dot/dot.go +++ b/internal/protocols/dot/dot.go @@ -1,122 +1,161 @@ package dot import ( - "bytes" + "context" "crypto/tls" "encoding/binary" "fmt" "io" + "log" "net" "os" + "sync" + "time" - "github.com/afonsofrancof/sdns-perf/internal/protocols/do53" "github.com/miekg/dns" ) -type DoTClient struct { - tcpConn *net.TCPConn - tlsConn *tls.Conn - keyLogFile *os.File +type Config struct { + Host string + Port string + DNSSEC bool + Debug bool } -func New(target string) (*DoTClient, error) { +type Client struct { + config Config - tcpAddr, err := net.ResolveTCPAddr("tcp", target) + serverAddr *net.TCPAddr + + tcpConn *net.TCPConn + tlsConn *tls.Conn + tlsConfig *tls.Config + keyLogFile *os.File + + sendChannel chan *dns.Msg + + responseChannels map[uint16]chan *dns.Msg + responseMutex *sync.Mutex +} + +func New(config Config) (*Client, error) { + serverAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(config.Host, config.Port)) if err != nil { - return nil, fmt.Errorf("failed to resolve TCP address: %v", err) + return nil, fmt.Errorf("dot: failed to resolve TCP address %q: %w", config.Host, err) } - tcpConn, err := net.DialTCP("tcp", nil, tcpAddr) - if err != nil { - return nil, fmt.Errorf("failed to establish TCP connection: %v", err) - } - - keyLogFile, err := os.OpenFile( - "tls-key-log.txt", - os.O_APPEND|os.O_CREATE|os.O_WRONLY, - 0600, - ) - if err != nil { - return nil, fmt.Errorf("failed opening key log file: %v", err) + var keyLogFile *os.File + if config.Debug { + keyLogFile, err = os.OpenFile( + "tls-key-log.txt", + os.O_APPEND|os.O_CREATE|os.O_WRONLY, + 0600, + ) + if err != nil { + log.Printf("dot: failed opening TLS key log file: %v", err) + keyLogFile = nil + } } tlsConfig := &tls.Config{ - InsecureSkipVerify: true, + ServerName: serverAddr.IP.String(), MinVersion: tls.VersionTLS12, KeyLogWriter: keyLogFile, + ClientSessionCache: tls.NewLRUClientSessionCache(100), } - tlsConn := tls.Client(tcpConn, tlsConfig) - err = tlsConn.Handshake() - if err != nil { - return nil, fmt.Errorf("failed to execute the TLS handshake: %v", err) + client := &Client{ + config: config, + serverAddr: serverAddr, + tlsConfig: tlsConfig, + keyLogFile: keyLogFile, } - return &DoTClient{tcpConn: tcpConn, tlsConn: tlsConn, keyLogFile: keyLogFile}, nil + go client.receiveLoop() + + return client, nil } -func (c *DoTClient) Close() { - if c.tcpConn != nil { - c.tcpConn.Close() - } +func (c *Client) Close() { if c.tlsConn != nil { c.tlsConn.Close() + c.tlsConn = nil } + + if c.tcpConn != nil { + c.tcpConn.Close() + c.tcpConn = nil + } + if c.keyLogFile != nil { c.keyLogFile.Close() + c.keyLogFile = nil } } -func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error { +func (c *Client) receiveLoop() { - DNSMessage, err := do53.NewDNSMessage(domain, queryType) + lengthBuffer := make([]byte, 2) + buffer := make([]byte, dns.MaxMsgSize) + + for { + msgSize, err := io.ReadFull(c.tlsConn, lengthBuffer) + if err != nil { + log.Printf("doh: failed to read the DNS message's size: %s", err.Error()) + // FIX: HANDLE RECONNECTION + } + n, err := io.ReadFull(c.tlsConn, buffer[:msgSize]) + if err != nil { + log.Printf("doh: failed to read the DNS message: %s", err.Error()) + // FIX: HANDLE RECONNECTION + } + + recvMsg := new(dns.Msg) + err = recvMsg.Unpack(buffer[:n]) + if err != nil { + log.Printf("do53: failed to unpack DNS response: %s", err.Error()) + continue + } + + c.responseMutex.Lock() + respChan, ok := c.responseChannels[recvMsg.Id] + delete(c.responseChannels, recvMsg.Id) + c.responseMutex.Unlock() + + if ok { + respChan <- recvMsg + } else { + log.Printf("Receiver: Received DNS response for unknown or already processed msg ID: %v\n", recvMsg.Id) + } + + } + +} + +func (c *Client) connect(ctx context.Context) error { + tcpConn, err := net.DialTCP("tcp", nil, c.serverAddr) if err != nil { - return err + return fmt.Errorf("dot: failed to establish TCP connection: %w", err) } - var lengthPrefixedMessage bytes.Buffer - err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) + c.tcpConn.SetKeepAlive(true) + c.tcpConn.SetKeepAlivePeriod(1 * time.Minute) + + tlsConn := tls.Client(c.tcpConn, c.tlsConfig) + err = tlsConn.HandshakeContext(ctx) if err != nil { - return fmt.Errorf("failed to write message length: %v", err) - } - _, err = lengthPrefixedMessage.Write(DNSMessage) - if err != nil { - return fmt.Errorf("failed to write DNS message: %v", err) + c.tcpConn.Close() + c.tcpConn = nil + return fmt.Errorf("dot: failed to execute the TLS handshake: %w", err) } - _, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes()) - if err != nil { - return fmt.Errorf("failed writing TLS request: %v", err) - } + c.tlsConn = tlsConn - lengthBuf := make([]byte, 2) - _, err = io.ReadFull(c.tlsConn, lengthBuf) - if err != nil { - return fmt.Errorf("failed reading response length: %v", err) - } - - messageLength := binary.BigEndian.Uint16(lengthBuf) - if messageLength == 0 { - return fmt.Errorf("received zero-length message") - } - - responseBuf := make([]byte, messageLength) - _, err = io.ReadFull(c.tlsConn, responseBuf) - if err != nil { - return fmt.Errorf("failed reading TLS response: %v", err) - } - - recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBuf) - if err != nil { - return fmt.Errorf("failed to parse DNS response: %v", err) - } - - // TODO: Check if the response had no errors or TD bit set - - for _, answer := range recvMsg.Answer { - fmt.Println(answer.String()) - } + log.Println("dot: TCP/TLS connection established successfully.") return nil } + +func (c *Client) Query(domain string, queryType uint16) (*dns.Msg, error) { + //TODO +}