Made DoT use io.ReadFull to make sure it reads all the bytes
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -74,31 +75,39 @@ func (c *DoTClient) Query(domain, queryType, target string, dnssec bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var lengthPrefixedMessage bytes.Buffer
|
var lengthPrefixedMessage bytes.Buffer
|
||||||
binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage)))
|
err = binary.Write(&lengthPrefixedMessage, binary.BigEndian, uint16(len(DNSMessage)))
|
||||||
lengthPrefixedMessage.Write(DNSMessage)
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write message length: %v", err)
|
||||||
|
}
|
||||||
|
_, err = lengthPrefixedMessage.Write(DNSMessage)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write DNS message: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes())
|
_, err = c.tlsConn.Write(lengthPrefixedMessage.Bytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed writing TLS request: %v", err)
|
return fmt.Errorf("failed writing TLS request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the 2-byte length prefix
|
|
||||||
lengthBuf := make([]byte, 2)
|
lengthBuf := make([]byte, 2)
|
||||||
_, err = c.tlsConn.Read(lengthBuf)
|
_, err = io.ReadFull(c.tlsConn, lengthBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading response length: %v", err)
|
return fmt.Errorf("failed reading response length: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
messageLength := binary.BigEndian.Uint16(lengthBuf)
|
messageLength := binary.BigEndian.Uint16(lengthBuf)
|
||||||
|
if messageLength == 0 {
|
||||||
|
return fmt.Errorf("received zero-length message")
|
||||||
|
}
|
||||||
|
|
||||||
responseBuf := make([]byte, messageLength)
|
responseBuf := make([]byte, messageLength)
|
||||||
n, err := c.tlsConn.Read(responseBuf)
|
_, err = io.ReadFull(c.tlsConn, responseBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed reading TLS response: %v", err)
|
return fmt.Errorf("failed reading TLS response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
recvMsg := new(dns.Msg)
|
recvMsg := new(dns.Msg)
|
||||||
err = recvMsg.Unpack(responseBuf[:n])
|
err = recvMsg.Unpack(responseBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse DNS response: %v", err)
|
return fmt.Errorf("failed to parse DNS response: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user