Made DoT use io.ReadFull to make sure it reads all the bytes

This commit is contained in:
2025-03-01 16:45:35 +00:00
parent a2fb149e7a
commit 606309f4b1

View File

@@ -5,6 +5,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
@@ -74,31 +75,39 @@ func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error {
} }
var lengthPrefixedMessage bytes.Buffer var lengthPrefixedMessage bytes.Buffer
binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage))) err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage)))
lengthPrefixedMessage.Write(DNSMessage) if err != nil {
return fmt.Errorf("failed to write message length: %v", err)
}
_, err = lengthPrefixedMessage.Write(DNSMessage)
if err != nil {
return fmt.Errorf("failed to write DNS message: %v", err)
}
_, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes()) _, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes())
if err != nil { if err != nil {
return fmt.Errorf("failed writing TLS request: %v", err) return fmt.Errorf("failed writing TLS request: %v", err)
} }
// Read the 2-byte length prefix
lengthBuf := make([]byte, 2) lengthBuf := make([]byte, 2)
_, err = c.tlsConn.Read(lengthBuf) _, err = io.ReadFull(c.tlsConn, lengthBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed reading response length: %v", err) return fmt.Errorf("failed reading response length: %v", err)
} }
messageLength := binary.BigEndian.Uint16(lengthBuf) messageLength := binary.BigEndian.Uint16(lengthBuf)
if messageLength == 0 {
return fmt.Errorf("received zero-length message")
}
responseBuf := make([]byte, messageLength) responseBuf := make([]byte, messageLength)
n, err := c.tlsConn.Read(responseBuf) _, err = io.ReadFull(c.tlsConn, responseBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed reading TLS response: %v", err) return fmt.Errorf("failed reading TLS response: %v", err)
} }
recvMsg := new(dns.Msg) recvMsg := new(dns.Msg)
err = recvMsg.Unpack(responseBuf[:n]) err = recvMsg.Unpack(responseBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse DNS response: %v", err) return fmt.Errorf("failed to parse DNS response: %v", err)
} }