Made DoT use io.ReadFull to make sure it reads all the bytes
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
@@ -74,31 +75,39 @@ func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error {
|
||||
}
|
||||
|
||||
var lengthPrefixedMessage bytes.Buffer
|
||||
binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage)))
|
||||
lengthPrefixedMessage.Write(DNSMessage)
|
||||
err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage)))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write message length: %v", err)
|
||||
}
|
||||
_, err = lengthPrefixedMessage.Write(DNSMessage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write DNS message: %v", err)
|
||||
}
|
||||
|
||||
_, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed writing TLS request: %v", err)
|
||||
}
|
||||
|
||||
// Read the 2-byte length prefix
|
||||
lengthBuf := make([]byte, 2)
|
||||
_, err = c.tlsConn.Read(lengthBuf)
|
||||
_, err = io.ReadFull(c.tlsConn, lengthBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading response length: %v", err)
|
||||
}
|
||||
|
||||
messageLength := binary.BigEndian.Uint16(lengthBuf)
|
||||
if messageLength == 0 {
|
||||
return fmt.Errorf("received zero-length message")
|
||||
}
|
||||
|
||||
responseBuf := make([]byte, messageLength)
|
||||
n, err := c.tlsConn.Read(responseBuf)
|
||||
_, err = io.ReadFull(c.tlsConn, responseBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading TLS response: %v", err)
|
||||
}
|
||||
|
||||
recvMsg := new(dns.Msg)
|
||||
err = recvMsg.Unpack(responseBuf[:n])
|
||||
err = recvMsg.Unpack(responseBuf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse DNS response: %v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user