diff --git a/internal/protocols/dot/dot.go b/internal/protocols/dot/dot.go index 3b2d14d..99ebe2c 100644 --- a/internal/protocols/dot/dot.go +++ b/internal/protocols/dot/dot.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "encoding/binary" "fmt" + "io" "net" "os" @@ -74,31 +75,39 @@ func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error { } var lengthPrefixedMessage bytes.Buffer - binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) - lengthPrefixedMessage.Write(DNSMessage) + err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) + if err != nil { + return fmt.Errorf("failed to write message length: %v", err) + } + _, err = lengthPrefixedMessage.Write(DNSMessage) + if err != nil { + return fmt.Errorf("failed to write DNS message: %v", err) + } _, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes()) if err != nil { return fmt.Errorf("failed writing TLS request: %v", err) } - // Read the 2-byte length prefix lengthBuf := make([]byte, 2) - _, err = c.tlsConn.Read(lengthBuf) + _, err = io.ReadFull(c.tlsConn, lengthBuf) if err != nil { return fmt.Errorf("failed reading response length: %v", err) } messageLength := binary.BigEndian.Uint16(lengthBuf) + if messageLength == 0 { + return fmt.Errorf("received zero-length message") + } responseBuf := make([]byte, messageLength) - n, err := c.tlsConn.Read(responseBuf) + _, err = io.ReadFull(c.tlsConn, responseBuf) if err != nil { return fmt.Errorf("failed reading TLS response: %v", err) } recvMsg := new(dns.Msg) - err = recvMsg.Unpack(responseBuf[:n]) + err = recvMsg.Unpack(responseBuf) if err != nil { return fmt.Errorf("failed to parse DNS response: %v", err) }