From 3416ca53bc59d41bc708c540e1612d8f2d23e6b3 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 13 Jul 2021 14:07:12 +0300 Subject: [PATCH] Added UDPSize properties to client and server --- client.go | 32 ++++++++++++-- constants.go | 3 -- encrypted_query.go | 4 -- server.go | 8 ++++ server_tcp.go | 2 +- server_test.go | 108 ++++++++++++++++++++++++++++++++++++++++++++- server_udp.go | 6 +-- util.go | 19 ++++++++ 8 files changed, 167 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index ebf1574..8d89911 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,10 @@ import ( type Client struct { Net string // protocol (can be "udp" or "tcp", by default - "udp") Timeout time.Duration // read/write timeout + + // UDPSize is the maximum size of a DNS response (or query) this client can + // sent or receive. If not set, we use dns.MinMsgSize by default. + UDPSize int } // ResolverInfo contains DNSCrypt resolver information necessary for decryption/encryption @@ -158,7 +162,11 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) { } if proto == "udp" { - response := make([]byte, maxQueryLen) + bufSize := c.UDPSize + if bufSize == 0 { + bufSize = dns.MinMsgSize + } + response := make([]byte, bufSize) n, err := conn.Read(response) if err != nil { return nil, err @@ -182,7 +190,12 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) if err != nil { return nil, err } - return q.Encrypt(query, resolverInfo.SharedKey) + b, err := q.Encrypt(query, resolverInfo.SharedKey) + if len(b) > c.maxQuerySize() { + return nil, ErrQueryTooLarge + } + + return b, err } // decrypts decrypts a DNS message using a shared key from the resolver info @@ -212,7 +225,8 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) { query := new(dns.Msg) query.SetQuestion(providerName, dns.TypeTXT) - client := dns.Client{Net: c.Net, UDPSize: uint16(maxQueryLen), Timeout: c.Timeout} + // use 1252 as a UDPSize for this client to make sure the buffer is not too small + client := dns.Client{Net: c.Net, UDPSize: uint16(1252), Timeout: c.Timeout} r, _, err := client.Exchange(query, stamp.ServerAddrStr) if err != nil { return nil, err @@ -284,3 +298,15 @@ func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) { return nil, certErr } + +func (c *Client) maxQuerySize() int { + if c.Net == "tcp" { + return dns.MaxMsgSize + } + + if c.UDPSize > 0 { + return c.UDPSize + } + + return dns.MinMsgSize +} diff --git a/constants.go b/constants.go index 2411d78..aa61f1e 100644 --- a/constants.go +++ b/constants.go @@ -61,9 +61,6 @@ const ( // Some servers do not work if padded length is less than 256. Example: Quad9 minUDPQuestionSize = 256 - // is the maximum allowed query length - maxQueryLen = 1252 - // Minimum possible DNS packet size minDNSPacketSize = 12 + 5 diff --git a/encrypted_query.go b/encrypted_query.go index f64ba8f..b927e02 100644 --- a/encrypted_query.go +++ b/encrypted_query.go @@ -72,10 +72,6 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) ( return nil, ErrEsVersion } - if len(query) > maxQueryLen { - return nil, ErrQueryTooLarge - } - return query, nil } diff --git a/server.go b/server.go index 02c4269..1120999 100644 --- a/server.go +++ b/server.go @@ -48,6 +48,10 @@ type Server struct { // ResolverCert contains resolver certificate. ResolverCert *Cert + // UDPSize is the default buffer size to use to read incoming UDP messages. + // If not set it defaults to dns.MinMsgSize (512 B). + UDPSize int + // Handler to invoke. If nil, uses DefaultHandler. Handler Handler @@ -148,6 +152,10 @@ func (s *Server) init() { s.tcpConns = map[net.Conn]struct{}{} s.udpListeners = map[*net.UDPConn]struct{}{} s.tcpListeners = map[net.Listener]struct{}{} + + if s.UDPSize == 0 { + s.UDPSize = dns.MinMsgSize + } } // isStarted returns true if the server is processing queries right now diff --git a/server_tcp.go b/server_tcp.go index 84f9a8f..f8a1c5a 100644 --- a/server_tcp.go +++ b/server_tcp.go @@ -35,7 +35,7 @@ func (w *TCPResponseWriter) RemoteAddr() net.Addr { // WriteMsg writes DNS message to the client func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error { - m.Truncate(dnsSize("tcp", w.req)) + normalize("tcp", w.req, m) res, err := w.encrypt(m, w.query) if err != nil { diff --git a/server_test.go b/server_test.go index de9cf90..c8796a8 100644 --- a/server_test.go +++ b/server_test.go @@ -56,6 +56,85 @@ func TestServer_ReadTimeout(t *testing.T) { testThisServerRespondMessages(t, "tcp", srv) } +func TestServer_UDPTruncateMessage(t *testing.T) { + // Create a test server that returns large response which should be + // truncated if sent over UDP + srv := newTestServer(t, &testLargeMsgHandler{}) + t.Cleanup(func() { + require.NoError(t, srv.Close()) + }) + + // Create client and connect + client := &Client{ + Timeout: 1 * time.Second, + Net: "udp", + } + serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port) + stamp := dnsstamps.ServerStamp{ + ServerAddrStr: serverAddr, + ServerPk: srv.resolverPk, + ProviderName: srv.server.ProviderName, + Proto: dnsstamps.StampProtoTypeDNSCrypt, + } + ri, err := client.DialStamp(stamp) + require.NoError(t, err) + require.NotNil(t, ri) + + // Send a test message and check that the response was truncated + m := createTestMessage() + res, err := client.Exchange(m, ri) + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, dns.RcodeSuccess, res.Rcode) + require.Len(t, res.Answer, 0) + require.True(t, res.Truncated) +} + +func TestServer_UDPEDNS0_NoTruncate(t *testing.T) { + // Create a test server that returns large response which should be + // truncated if sent over UDP + // However, when EDNS0 is set with the buffer large enough, there should + // be no truncation + srv := newTestServer(t, &testLargeMsgHandler{}) + t.Cleanup(func() { + require.NoError(t, srv.Close()) + }) + + // Create client and connect + client := &Client{ + Timeout: 1 * time.Second, + Net: "udp", + UDPSize: 7000, // make sure the client will be able to read the response + } + serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port) + stamp := dnsstamps.ServerStamp{ + ServerAddrStr: serverAddr, + ServerPk: srv.resolverPk, + ProviderName: srv.server.ProviderName, + Proto: dnsstamps.StampProtoTypeDNSCrypt, + } + ri, err := client.DialStamp(stamp) + require.NoError(t, err) + require.NotNil(t, ri) + + // Send a test message with UDP buffer size large enough + // and check that the response was NOT truncated + m := createTestMessage() + m.Extra = append(m.Extra, &dns.OPT{ + Hdr: dns.RR_Header{ + Name: ".", + Rrtype: dns.TypeOPT, + Class: 2000, // Set large enough UDPSize here + }, + }) + res, err := client.Exchange(m, ri) + require.NoError(t, err) + require.NotNil(t, res) + require.Equal(t, dns.RcodeSuccess, res.Rcode) + require.Len(t, res.Answer, 64) + require.False(t, res.Truncated) +} + func testServerServeCert(t *testing.T, network string) { srv := newTestServer(t, &testHandler{}) t.Cleanup(func() { @@ -193,9 +272,9 @@ type testHandler struct{} // ServeDNS - implements Handler interface func (h *testHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error { - // Google DNS res := new(dns.Msg) res.SetReply(r) + answer := new(dns.A) answer.Hdr = dns.RR_Header{ Name: r.Question[0].Name, @@ -203,7 +282,34 @@ func (h *testHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error { Ttl: 300, Class: dns.ClassINET, } + // First record is from Google DNS answer.A = net.IPv4(8, 8, 8, 8) res.Answer = append(res.Answer, answer) + + return rw.WriteMsg(res) +} + +// testLargeMsgHandler is a handler that returns a huge response +// used for testing messages truncation +type testLargeMsgHandler struct{} + +// ServeDNS - implements Handler interface +func (h *testLargeMsgHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error { + res := new(dns.Msg) + res.SetReply(r) + + for i := 0; i < 64; i++ { + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: r.Question[0].Name, + Rrtype: dns.TypeA, + Ttl: 300, + Class: dns.ClassINET, + } + answer.A = net.IPv4(127, 0, 0, byte(i)) + res.Answer = append(res.Answer, answer) + } + + res.Compress = true return rw.WriteMsg(res) } diff --git a/server_udp.go b/server_udp.go index 51ba232..edf936b 100644 --- a/server_udp.go +++ b/server_udp.go @@ -20,7 +20,7 @@ type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error) type UDPResponseWriter struct { udpConn *net.UDPConn // UDP connection sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP) - encrypt encryptionFunc // DNSCRypt encryption function + encrypt encryptionFunc // DNSCrypt encryption function req *dns.Msg // DNS query that was processed query EncryptedQuery // DNSCrypt query properties } @@ -40,7 +40,7 @@ func (w *UDPResponseWriter) RemoteAddr() net.Addr { // WriteMsg writes DNS message to the client func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error { - m.Truncate(dnsSize("udp", w.req)) + normalize("udp", w.req, m) res, err := w.encrypt(m, w.query) if err != nil { @@ -157,7 +157,7 @@ func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) { // readUDPMsg reads incoming UDP message func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) { _ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout)) - b := make([]byte, dns.MinMsgSize) + b := make([]byte, s.UDPSize) n, sess, err := dns.ReadFromSessionUDP(l, b) if err != nil { return nil, nil, err diff --git a/util.go b/util.go index 0653b4e..403a6db 100644 --- a/util.go +++ b/util.go @@ -184,6 +184,25 @@ func unpackTxtString(s string) ([]byte, error) { return msg, nil } +// normalize truncates the DNS response if needed depending on the protocol +func normalize(proto string, req *dns.Msg, res *dns.Msg) { + size := dnsSize(proto, req) + // DNSCrypt encryption adds a header to each message, we should + // consider this when truncating a message. + // 64 should cover all cases + size = size - 64 + + // Truncate response message + res.Truncate(size) + + // In case of UDP it is safer to simply remove all response records + // dns.Msg.Truncate method will not consider that we need a response + // shorter than dns.MinMsgSize + if res.Truncated && proto == "udp" { + res.Answer = nil + } +} + // dnsSize returns if buffer size *advertised* in the requests OPT record. // Or when the request was over TCP, we return the maximum allowed size of 64K. func dnsSize(proto string, r *dns.Msg) int {