mirror of
https://github.com/SamTherapy/dnscrypt.git
synced 2024-12-21 16:50:42 +00:00
Added UDPSize properties to client and server
This commit is contained in:
parent
b5bcf754ca
commit
3416ca53bc
8 changed files with 167 additions and 15 deletions
32
client.go
32
client.go
|
@ -16,6 +16,10 @@ import (
|
|||
type Client struct {
|
||||
Net string // protocol (can be "udp" or "tcp", by default - "udp")
|
||||
Timeout time.Duration // read/write timeout
|
||||
|
||||
// UDPSize is the maximum size of a DNS response (or query) this client can
|
||||
// sent or receive. If not set, we use dns.MinMsgSize by default.
|
||||
UDPSize int
|
||||
}
|
||||
|
||||
// ResolverInfo contains DNSCrypt resolver information necessary for decryption/encryption
|
||||
|
@ -158,7 +162,11 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
|
|||
}
|
||||
|
||||
if proto == "udp" {
|
||||
response := make([]byte, maxQueryLen)
|
||||
bufSize := c.UDPSize
|
||||
if bufSize == 0 {
|
||||
bufSize = dns.MinMsgSize
|
||||
}
|
||||
response := make([]byte, bufSize)
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -182,7 +190,12 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.Encrypt(query, resolverInfo.SharedKey)
|
||||
b, err := q.Encrypt(query, resolverInfo.SharedKey)
|
||||
if len(b) > c.maxQuerySize() {
|
||||
return nil, ErrQueryTooLarge
|
||||
}
|
||||
|
||||
return b, err
|
||||
}
|
||||
|
||||
// decrypts decrypts a DNS message using a shared key from the resolver info
|
||||
|
@ -212,7 +225,8 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
|
|||
|
||||
query := new(dns.Msg)
|
||||
query.SetQuestion(providerName, dns.TypeTXT)
|
||||
client := dns.Client{Net: c.Net, UDPSize: uint16(maxQueryLen), Timeout: c.Timeout}
|
||||
// use 1252 as a UDPSize for this client to make sure the buffer is not too small
|
||||
client := dns.Client{Net: c.Net, UDPSize: uint16(1252), Timeout: c.Timeout}
|
||||
r, _, err := client.Exchange(query, stamp.ServerAddrStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -284,3 +298,15 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
|
|||
|
||||
return nil, certErr
|
||||
}
|
||||
|
||||
func (c *Client) maxQuerySize() int {
|
||||
if c.Net == "tcp" {
|
||||
return dns.MaxMsgSize
|
||||
}
|
||||
|
||||
if c.UDPSize > 0 {
|
||||
return c.UDPSize
|
||||
}
|
||||
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
|
|
@ -61,9 +61,6 @@ const (
|
|||
// Some servers do not work if padded length is less than 256. Example: Quad9
|
||||
minUDPQuestionSize = 256
|
||||
|
||||
// <max-query-len> is the maximum allowed query length
|
||||
maxQueryLen = 1252
|
||||
|
||||
// Minimum possible DNS packet size
|
||||
minDNSPacketSize = 12 + 5
|
||||
|
||||
|
|
|
@ -72,10 +72,6 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) (
|
|||
return nil, ErrEsVersion
|
||||
}
|
||||
|
||||
if len(query) > maxQueryLen {
|
||||
return nil, ErrQueryTooLarge
|
||||
}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -48,6 +48,10 @@ type Server struct {
|
|||
// ResolverCert contains resolver certificate.
|
||||
ResolverCert *Cert
|
||||
|
||||
// UDPSize is the default buffer size to use to read incoming UDP messages.
|
||||
// If not set it defaults to dns.MinMsgSize (512 B).
|
||||
UDPSize int
|
||||
|
||||
// Handler to invoke. If nil, uses DefaultHandler.
|
||||
Handler Handler
|
||||
|
||||
|
@ -148,6 +152,10 @@ func (s *Server) init() {
|
|||
s.tcpConns = map[net.Conn]struct{}{}
|
||||
s.udpListeners = map[*net.UDPConn]struct{}{}
|
||||
s.tcpListeners = map[net.Listener]struct{}{}
|
||||
|
||||
if s.UDPSize == 0 {
|
||||
s.UDPSize = dns.MinMsgSize
|
||||
}
|
||||
}
|
||||
|
||||
// isStarted returns true if the server is processing queries right now
|
||||
|
|
|
@ -35,7 +35,7 @@ func (w *TCPResponseWriter) RemoteAddr() net.Addr {
|
|||
|
||||
// WriteMsg writes DNS message to the client
|
||||
func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
m.Truncate(dnsSize("tcp", w.req))
|
||||
normalize("tcp", w.req, m)
|
||||
|
||||
res, err := w.encrypt(m, w.query)
|
||||
if err != nil {
|
||||
|
|
108
server_test.go
108
server_test.go
|
@ -56,6 +56,85 @@ func TestServer_ReadTimeout(t *testing.T) {
|
|||
testThisServerRespondMessages(t, "tcp", srv)
|
||||
}
|
||||
|
||||
func TestServer_UDPTruncateMessage(t *testing.T) {
|
||||
// Create a test server that returns large response which should be
|
||||
// truncated if sent over UDP
|
||||
srv := newTestServer(t, &testLargeMsgHandler{})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, srv.Close())
|
||||
})
|
||||
|
||||
// Create client and connect
|
||||
client := &Client{
|
||||
Timeout: 1 * time.Second,
|
||||
Net: "udp",
|
||||
}
|
||||
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
|
||||
stamp := dnsstamps.ServerStamp{
|
||||
ServerAddrStr: serverAddr,
|
||||
ServerPk: srv.resolverPk,
|
||||
ProviderName: srv.server.ProviderName,
|
||||
Proto: dnsstamps.StampProtoTypeDNSCrypt,
|
||||
}
|
||||
ri, err := client.DialStamp(stamp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ri)
|
||||
|
||||
// Send a test message and check that the response was truncated
|
||||
m := createTestMessage()
|
||||
res, err := client.Exchange(m, ri)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, dns.RcodeSuccess, res.Rcode)
|
||||
require.Len(t, res.Answer, 0)
|
||||
require.True(t, res.Truncated)
|
||||
}
|
||||
|
||||
func TestServer_UDPEDNS0_NoTruncate(t *testing.T) {
|
||||
// Create a test server that returns large response which should be
|
||||
// truncated if sent over UDP
|
||||
// However, when EDNS0 is set with the buffer large enough, there should
|
||||
// be no truncation
|
||||
srv := newTestServer(t, &testLargeMsgHandler{})
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, srv.Close())
|
||||
})
|
||||
|
||||
// Create client and connect
|
||||
client := &Client{
|
||||
Timeout: 1 * time.Second,
|
||||
Net: "udp",
|
||||
UDPSize: 7000, // make sure the client will be able to read the response
|
||||
}
|
||||
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
|
||||
stamp := dnsstamps.ServerStamp{
|
||||
ServerAddrStr: serverAddr,
|
||||
ServerPk: srv.resolverPk,
|
||||
ProviderName: srv.server.ProviderName,
|
||||
Proto: dnsstamps.StampProtoTypeDNSCrypt,
|
||||
}
|
||||
ri, err := client.DialStamp(stamp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ri)
|
||||
|
||||
// Send a test message with UDP buffer size large enough
|
||||
// and check that the response was NOT truncated
|
||||
m := createTestMessage()
|
||||
m.Extra = append(m.Extra, &dns.OPT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: ".",
|
||||
Rrtype: dns.TypeOPT,
|
||||
Class: 2000, // Set large enough UDPSize here
|
||||
},
|
||||
})
|
||||
res, err := client.Exchange(m, ri)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, res)
|
||||
require.Equal(t, dns.RcodeSuccess, res.Rcode)
|
||||
require.Len(t, res.Answer, 64)
|
||||
require.False(t, res.Truncated)
|
||||
}
|
||||
|
||||
func testServerServeCert(t *testing.T, network string) {
|
||||
srv := newTestServer(t, &testHandler{})
|
||||
t.Cleanup(func() {
|
||||
|
@ -193,9 +272,9 @@ type testHandler struct{}
|
|||
|
||||
// ServeDNS - implements Handler interface
|
||||
func (h *testHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
|
||||
// Google DNS
|
||||
res := new(dns.Msg)
|
||||
res.SetReply(r)
|
||||
|
||||
answer := new(dns.A)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
|
@ -203,7 +282,34 @@ func (h *testHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
|
|||
Ttl: 300,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
// First record is from Google DNS
|
||||
answer.A = net.IPv4(8, 8, 8, 8)
|
||||
res.Answer = append(res.Answer, answer)
|
||||
|
||||
return rw.WriteMsg(res)
|
||||
}
|
||||
|
||||
// testLargeMsgHandler is a handler that returns a huge response
|
||||
// used for testing messages truncation
|
||||
type testLargeMsgHandler struct{}
|
||||
|
||||
// ServeDNS - implements Handler interface
|
||||
func (h *testLargeMsgHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
|
||||
res := new(dns.Msg)
|
||||
res.SetReply(r)
|
||||
|
||||
for i := 0; i < 64; i++ {
|
||||
answer := new(dns.A)
|
||||
answer.Hdr = dns.RR_Header{
|
||||
Name: r.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Ttl: 300,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
answer.A = net.IPv4(127, 0, 0, byte(i))
|
||||
res.Answer = append(res.Answer, answer)
|
||||
}
|
||||
|
||||
res.Compress = true
|
||||
return rw.WriteMsg(res)
|
||||
}
|
||||
|
|
|
@ -20,7 +20,7 @@ type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error)
|
|||
type UDPResponseWriter struct {
|
||||
udpConn *net.UDPConn // UDP connection
|
||||
sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP)
|
||||
encrypt encryptionFunc // DNSCRypt encryption function
|
||||
encrypt encryptionFunc // DNSCrypt encryption function
|
||||
req *dns.Msg // DNS query that was processed
|
||||
query EncryptedQuery // DNSCrypt query properties
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ func (w *UDPResponseWriter) RemoteAddr() net.Addr {
|
|||
|
||||
// WriteMsg writes DNS message to the client
|
||||
func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
m.Truncate(dnsSize("udp", w.req))
|
||||
normalize("udp", w.req, m)
|
||||
|
||||
res, err := w.encrypt(m, w.query)
|
||||
if err != nil {
|
||||
|
@ -157,7 +157,7 @@ func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) {
|
|||
// readUDPMsg reads incoming UDP message
|
||||
func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) {
|
||||
_ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout))
|
||||
b := make([]byte, dns.MinMsgSize)
|
||||
b := make([]byte, s.UDPSize)
|
||||
n, sess, err := dns.ReadFromSessionUDP(l, b)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
|
19
util.go
19
util.go
|
@ -184,6 +184,25 @@ func unpackTxtString(s string) ([]byte, error) {
|
|||
return msg, nil
|
||||
}
|
||||
|
||||
// normalize truncates the DNS response if needed depending on the protocol
|
||||
func normalize(proto string, req *dns.Msg, res *dns.Msg) {
|
||||
size := dnsSize(proto, req)
|
||||
// DNSCrypt encryption adds a header to each message, we should
|
||||
// consider this when truncating a message.
|
||||
// 64 should cover all cases
|
||||
size = size - 64
|
||||
|
||||
// Truncate response message
|
||||
res.Truncate(size)
|
||||
|
||||
// In case of UDP it is safer to simply remove all response records
|
||||
// dns.Msg.Truncate method will not consider that we need a response
|
||||
// shorter than dns.MinMsgSize
|
||||
if res.Truncated && proto == "udp" {
|
||||
res.Answer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// dnsSize returns if buffer size *advertised* in the requests OPT record.
|
||||
// Or when the request was over TCP, we return the maximum allowed size of 64K.
|
||||
func dnsSize(proto string, r *dns.Msg) int {
|
||||
|
|
Loading…
Reference in a new issue