1
0
Fork 0
mirror of https://github.com/SamTherapy/dnscrypt.git synced 2024-10-02 16:32:51 +00:00
dnscrypt/server.go
Andrey Meshkov 5ddb58f703
Graceful shutdown of the DNSCrypt server (#6)
Graceful shutdown of the DNSCrypt server

This PR implements Server.Shutdown(ctx context.Context) method that allows
to shut down the DNSCrypt server gracefully.

Some additional changes that were inadvertently made while doing that:
1. Added benchmark tests
2. Started using dns.ReadFromSessionUDP / dns.WriteToSessionUDP instead of implementing it by ourselves
3. Generally improved tests
4. Added depguard 
5. Improved comments overall in the code
2021-03-19 15:42:48 +03:00

294 lines
7.5 KiB
Go

package dnscrypt
import (
	"context"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/AdguardTeam/golibs/log"
	"github.com/miekg/dns"
)
// Read timeouts applied to incoming connections.
const (
	// defaultReadTimeout is the default timeout for all reads.
	defaultReadTimeout = 2 * time.Second

	// defaultTCPIdleTimeout is the timeout used on a TCP connection after the
	// first read; only the first read uses defaultReadTimeout.
	defaultTCPIdleTimeout = 8 * time.Second
)

// longTimeAgo is a moment far in the past. It is passed to the
// Set(Read)Deadline calls to unblock reads that are currently pending.
var longTimeAgo = time.Unix(1, 0)
// ServerDNSCrypt is an interface for a DNSCrypt server.
type ServerDNSCrypt interface {
// ServeTCP listens to TCP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener
// or call ServerDNSCrypt.Shutdown.
ServeTCP(l net.Listener) error
// ServeUDP listens to UDP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener
// or call ServerDNSCrypt.Shutdown.
ServeUDP(l *net.UDPConn) error
// Shutdown tries to gracefully shut down the server. It waits until all
// connections are processed and only after that it leaves the method.
// If a context deadline is specified, it may exit earlier.
Shutdown(ctx context.Context) error
}
// Server is a simple DNSCrypt server implementation.
type Server struct {
// ProviderName is a DNSCrypt provider name. It must be non-empty (see
// validate); certificate TXT queries are matched against its
// fully-qualified form (see handleHandshake).
ProviderName string
// ResolverCert contains resolver certificate. It must be set and pass
// VerifyDate (see validate); encrypt and decrypt read its fields.
ResolverCert *Cert
// Handler to invoke. If nil, uses DefaultHandler.
Handler Handler
// make sure init is called only once
initOnce sync.Once
// Shutdown handling
// --
lock sync.RWMutex // protects access to all the fields below
started bool // reported by isStarted; reset to false by prepareShutdown
wg sync.WaitGroup // active workers (servers)
tcpListeners map[net.Listener]struct{} // track active TCP listeners
udpListeners map[*net.UDPConn]struct{} // track active UDP listeners
tcpConns map[net.Conn]struct{} // track active connections
}
// Compile-time check that *Server implements ServerDNSCrypt.
var _ ServerDNSCrypt = (*Server)(nil)
// prepareShutdown gets the server ready for a shutdown: it marks the server
// as stopped and unblocks the reads pending on every tracked connection and
// listener. It returns ErrServerNotStarted if the server is not started.
func (s *Server) prepareShutdown() error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if !s.started {
		log.Info("Server is not started")
		return ErrServerNotStarted
	}
	s.started = false

	// The listeners were handed to us by the caller, so closing them is the
	// caller's job. A deadline in the past is set instead, which wakes up the
	// goroutines blocked on reading from them. The TCP connections are not
	// closed either so that in-flight queries can still be processed.

	// Wake up readers of the active TCP connections.
	for c := range s.tcpConns {
		_ = c.SetReadDeadline(longTimeAgo)
	}

	// Wake up accept loops of the active TCP listeners. Only *net.TCPListener
	// exposes SetDeadline; other listener kinds are left alone.
	for ln := range s.tcpListeners {
		if tcpLn, ok := ln.(*net.TCPListener); ok {
			_ = tcpLn.SetDeadline(longTimeAgo)
		}
	}

	// Wake up readers of the active UDP listeners.
	for udpConn := range s.udpListeners {
		_ = udpConn.SetReadDeadline(longTimeAgo)
	}

	return nil
}
// Shutdown tries to gracefully shut down the server. It returns once all
// serve goroutines have finished, or earlier with ctx.Err() if ctx is done
// first. If the server is not started, the error from prepareShutdown is
// returned right away.
func (s *Server) Shutdown(ctx context.Context) error {
	log.Info("Shutting down the DNSCrypt server")

	if err := s.prepareShutdown(); err != nil {
		return err
	}

	// done is closed as soon as every worker tracked by s.wg has returned.
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		log.Info("Serve goroutines finished their work")
		close(done)
	}()

	// Whichever comes first: workers finishing or the context being done.
	select {
	case <-ctx.Done():
		log.Info("DNSCrypt server shutdown has timed out")
		return ctx.Err()
	case <-done:
		log.Info("DNSCrypt server has been stopped")
		return nil
	}
}
// init allocates the maps that track active listeners and connections.
// It is meant to be called lazily from Server.ServeTCP and Server.ServeUDP.
func (s *Server) init() {
	s.tcpListeners = make(map[net.Listener]struct{})
	s.udpListeners = make(map[*net.UDPConn]struct{})
	s.tcpConns = make(map[net.Conn]struct{})
}
// isStarted reports whether the server is processing queries right now,
// i.e. Server.ServeTCP and/or Server.ServeUDP have been called. The flag is
// read under the read lock.
func (s *Server) isStarted() bool {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.started
}
// serveDNS passes the query r to the configured handler (DefaultHandler if
// Server.Handler is nil) which writes the response via rw.
//
// It returns ErrInvalidQuery if r is nil, does not contain exactly one
// question, or is itself a response. A handler failure is not propagated:
// it is trace-logged and a SERVFAIL reply is sent on a best-effort basis,
// so nil is returned in that case as well.
func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) error {
	if r == nil || len(r.Question) != 1 || r.Response {
		return ErrInvalidQuery
	}

	log.Tracef("Handling a DNS query: %s", r.Question[0].Name)

	handler := s.Handler
	if handler == nil {
		handler = DefaultHandler
	}

	if err := handler.ServeDNS(rw, r); err != nil {
		log.Tracef("Error while handling a DNS query: %v", err)

		reply := &dns.Msg{}
		reply.SetRcode(r, dns.RcodeServerFailure)

		// Writing the SERVFAIL reply stays best-effort, but a failure is
		// logged rather than silently dropped.
		if werr := rw.WriteMsg(reply); werr != nil {
			log.Tracef("Failed to write the SERVFAIL response: %v", werr)
		}
	}

	return nil
}
// encrypt packs the DNS message m and encrypts it as a response to the
// encrypted query q, reusing the query's EsVersion and Nonce. The shared key
// is computed from the resolver secret key and the client's public key.
func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) {
	packed, packErr := m.Pack()
	if packErr != nil {
		return nil, packErr
	}

	key, keyErr := computeSharedKey(q.EsVersion, &s.ResolverCert.ResolverSk, &q.ClientPk)
	if keyErr != nil {
		return nil, keyErr
	}

	resp := EncryptedResponse{EsVersion: q.EsVersion, Nonce: q.Nonce}

	return resp.Encrypt(packed, key)
}
// decrypt decrypts the incoming packet b and unpacks it into a DNS message.
// The EncryptedQuery used for decryption is returned in every case; the
// message is nil whenever an error is returned.
func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) {
	cert := s.ResolverCert
	query := EncryptedQuery{
		EsVersion:   cert.EsVersion,
		ClientMagic: cert.ClientMagic,
	}

	plain, err := query.Decrypt(b, cert.ResolverSk)
	if err != nil {
		// Decryption failed, the packet is to be dropped.
		return nil, query, err
	}

	msg := &dns.Msg{}
	if err = msg.Unpack(plain); err != nil {
		// The payload is not a valid DNS message.
		return nil, query, err
	}

	return msg, query, nil
}
// handleHandshake handles a plain (unencrypted) TXT query that requests the
// certificate data. b is the raw DNS packet and certTxt is the serialized
// certificate to put into the TXT answer.
//
// It returns the packed reply, or an error when b is not a DNS message or
// ErrInvalidQuery when it is not a single-question TXT query for the
// provider name.
func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) {
	m := new(dns.Msg)
	if err := m.Unpack(b); err != nil {
		// Not a handshake, just ignore it.
		return nil, err
	}

	if len(m.Question) != 1 || m.Response {
		// Invalid query.
		return nil, ErrInvalidQuery
	}

	q := m.Question[0]
	providerName := dns.Fqdn(s.ProviderName)

	// DNS names are case-insensitive and the client may send the name in a
	// random case, so the comparison must ignore case.
	if q.Qtype != dns.TypeTXT || !strings.EqualFold(q.Name, providerName) {
		// Invalid provider name or type, doing nothing.
		return nil, ErrInvalidQuery
	}

	reply := new(dns.Msg)
	reply.SetReply(m)
	reply.Answer = append(reply.Answer, &dns.TXT{
		Hdr: dns.RR_Header{
			// Echo the name exactly as the client sent it.
			Name:   q.Name,
			Rrtype: dns.TypeTXT,
			Ttl:    60, // use 60 seconds by default, but it shouldn't matter
			Class:  dns.ClassINET,
		},
		Txt: []string{
			certTxt,
		},
	})

	return reply.Pack()
}
// validate reports whether the Server configuration is usable: ResolverCert
// must be set and pass VerifyDate, and ProviderName must be non-empty.
// The first failed check is logged as an error.
func (s *Server) validate() bool {
	switch {
	case s.ResolverCert == nil:
		log.Error("ResolverCert must be set")
	case !s.ResolverCert.VerifyDate():
		log.Error("ResolverCert date is not valid")
	case s.ProviderName == "":
		log.Error("ProviderName must be set")
	default:
		return true
	}

	return false
}
// getCertTXT serializes the resolver certificate and packs it into the TXT
// string that is to be sent to the client.
func (s *Server) getCertTXT() (string, error) {
	serialized, err := s.ResolverCert.Serialize()
	if err != nil {
		return "", err
	}

	return packTxtString(serialized), nil
}