// Scraped repository page header (not part of the Go source):
//   counters: 1, 0; forks: 0
//   mirror of https://github.com/SamTherapy/dnscrypt.git, synced 2024-07-02 21:56:06 +00:00
//   file: dnscrypt/server.go, 2021-08-31 20:33:03 +03:00, 316 lines, 8.3 KiB, Go

package dnscrypt
import (
"context"
"net"
"strings"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
const (
	// defaultReadTimeout is the timeout applied to reads.  For TCP it is only
	// used for the first read on a connection; later reads switch to
	// defaultTCPIdleTimeout.
	defaultReadTimeout = 2 * time.Second

	// defaultTCPIdleTimeout is the read timeout used for every TCP read after
	// the first one.
	defaultTCPIdleTimeout = 8 * time.Second

	// defaultUDPSize is the default size of the UDP read buffer.  The release
	// notes for dnscrypt-proxy version 1.1.0-RC1 claim that this size was
	// chosen as the maximum one "for compatibility with some scary network
	// setups", and making it smaller seems to break things for some people.
	//
	// See also: https://github.com/AdguardTeam/AdGuardDNS/issues/188.
	defaultUDPSize = 1252
)

// longTimeAgo is a deadline far in the past.  Passing it to SetReadDeadline /
// SetDeadline makes blocked reads return immediately.
var longTimeAgo = time.Unix(1, 0)
// ServerDNSCrypt is an interface for a DNSCrypt server.
type ServerDNSCrypt interface {
	// ServeTCP listens to TCP connections; queries are then processed by
	// Server.Handler.  It blocks the calling goroutine.  To stop it, close
	// the listener or call ServerDNSCrypt.Shutdown.
	ServeTCP(l net.Listener) error

	// ServeUDP listens to UDP connections; queries are then processed by
	// Server.Handler.  It blocks the calling goroutine.  To stop it, close
	// the listener or call ServerDNSCrypt.Shutdown.
	ServeUDP(l *net.UDPConn) error

	// Shutdown tries to gracefully shut down the server.  It waits until all
	// connections are processed and only then returns.  If ctx is done
	// before that happens (canceled or its deadline expired), Shutdown stops
	// waiting and returns early; the Server implementation returns ctx.Err()
	// in that case.
	Shutdown(ctx context.Context) error
}
// Server is a simple DNSCrypt server implementation.
//
// The exported fields are configuration.  The unexported fields hold the
// runtime state used for lazy initialization and graceful shutdown.
type Server struct {
	// ProviderName is the DNSCrypt provider name.  Certificate requests are
	// TXT queries for this name.
	ProviderName string

	// ResolverCert is the resolver certificate.
	ResolverCert *Cert

	// UDPSize is the buffer size used to read incoming UDP messages.  A zero
	// value is replaced by defaultUDPSize (1252 B) on initialization.
	UDPSize int

	// Handler is invoked for every decrypted query.  When nil, DefaultHandler
	// is used instead.
	Handler Handler

	// initOnce guarantees that init runs only once.
	initOnce sync.Once

	// Shutdown bookkeeping.

	// lock guards the fields that follow it.
	lock sync.RWMutex

	// started reports whether the server is currently serving.
	started bool

	// wg counts the active serving goroutines (workers).
	wg sync.WaitGroup

	// tcpListeners is the set of active TCP listeners.
	tcpListeners map[net.Listener]struct{}

	// udpListeners is the set of active UDP listeners.
	udpListeners map[*net.UDPConn]struct{}

	// tcpConns is the set of active TCP connections.
	tcpConns map[net.Conn]struct{}
}
// Compile-time assertion that *Server implements ServerDNSCrypt.
var _ ServerDNSCrypt = (*Server)(nil)
// prepareShutdown prepares the server for shutdown.  It marks the server as
// stopped and unblocks every goroutine that is currently blocked reading from
// a TCP connection, a TCP listener, or a UDP listener tracked by this server.
// If the server is not started, it returns ErrServerNotStarted.
func (s *Server) prepareShutdown() error {
	s.lock.Lock()
	defer s.lock.Unlock()

	if !s.started {
		log.Info("Server is not started")
		return ErrServerNotStarted
	}
	s.started = false

	// The listeners were passed to us from the outside, so we cannot close
	// them here; that is up to the calling code.  Instead, we set a deadline
	// in the past to unblock goroutines that are currently blocked on them.
	// The TCP connections are not closed either, so that queries already in
	// flight can still be processed before everything shuts down.
	//
	// Deadline errors are ignored on purpose: this is a best-effort unblock.

	// Unblock reads for all active TCP connections.
	for conn := range s.tcpConns {
		_ = conn.SetReadDeadline(longTimeAgo)
	}

	// Unblock Accept for all active TCP listeners.  net.Listener has no
	// deadline method, so accept any listener that exposes SetDeadline.  This
	// covers *net.TCPListener, *net.UnixListener, and wrappers that forward
	// the call.  Matching only *net.TCPListener would leave Accept blocked on
	// other listener types until the Shutdown context expires.
	for l := range s.tcpListeners {
		if d, ok := l.(interface{ SetDeadline(time.Time) error }); ok {
			_ = d.SetDeadline(longTimeAgo)
		}
	}

	// Unblock reads for all active UDP listeners.
	for l := range s.udpListeners {
		_ = l.SetReadDeadline(longTimeAgo)
	}

	return nil
}
// Shutdown gracefully stops the server.  It first unblocks every goroutine
// serving this server, then waits for all of them to finish.  If ctx is done
// first, it stops waiting and returns ctx.Err().  If the server is not
// running, it returns the error from prepareShutdown (ErrServerNotStarted).
func (s *Server) Shutdown(ctx context.Context) error {
	log.Info("Shutting down the DNSCrypt server")

	if err := s.prepareShutdown(); err != nil {
		return err
	}

	// done is closed once every goroutine tracked by s.wg has returned.  If
	// ctx wins the race below, this waiter keeps running in the background
	// until the goroutines exit.
	done := make(chan struct{})
	go func() {
		defer close(done)
		s.wg.Wait()
		log.Info("Serve goroutines finished their work")
	}()

	select {
	case <-ctx.Done():
		log.Info("DNSCrypt server shutdown has timed out")
		return ctx.Err()
	case <-done:
		log.Info("DNSCrypt server has been stopped")
		return nil
	}
}
// init lazily sets up the Server's runtime state on startup.  It allocates
// fresh tracking sets for connections and listeners, and falls back to
// defaultUDPSize when UDPSize is unset.  Because it replaces the sets, it is
// meant to run only once (see initOnce).  It is called from Server.ServeTCP
// and Server.ServeUDP.
func (s *Server) init() {
	s.tcpConns = make(map[net.Conn]struct{})
	s.udpListeners = make(map[*net.UDPConn]struct{})
	s.tcpListeners = make(map[net.Listener]struct{})

	if s.UDPSize != 0 {
		return
	}
	s.UDPSize = defaultUDPSize
}
// isStarted reports whether the server is currently processing queries, that
// is, whether Server.ServeTCP and/or Server.ServeUDP have been called and
// Shutdown has not yet run.  It reads the flag under the read lock.
func (s *Server) isStarted() bool {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.started
}
// serveDNS validates the query r and passes it to the configured handler
// (DefaultHandler when s.Handler is nil).
//
// It returns ErrInvalidQuery when r is nil, is a response, or does not
// contain exactly one question.  Handler errors are not propagated.  Instead,
// a SERVFAIL reply is written to rw, and serveDNS returns nil for every valid
// query.
func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) error {
	if r == nil || len(r.Question) != 1 || r.Response {
		return ErrInvalidQuery
	}

	log.Tracef("Handling a DNS query: %s", r.Question[0].Name)

	handler := s.Handler
	if handler == nil {
		handler = DefaultHandler
	}

	if err := handler.ServeDNS(rw, r); err != nil {
		log.Tracef("Error while handling a DNS query: %v", err)

		reply := &dns.Msg{}
		reply.SetRcode(r, dns.RcodeServerFailure)
		if werr := rw.WriteMsg(reply); werr != nil {
			// Nothing more can be sent to this client.  Record the failure
			// instead of dropping it silently.
			log.Tracef("Failed to write SERVFAIL response: %v", werr)
		}
	}

	return nil
}
// encrypt packs the DNS message m and encrypts it as the DNSCrypt response to
// query q.  The shared key is computed from the resolver's secret key and the
// client's public key carried in q.  The response reuses q's ES version and
// nonce.
func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) {
	plain, err := m.Pack()
	if err != nil {
		return nil, err
	}

	key, err := computeSharedKey(q.EsVersion, &s.ResolverCert.ResolverSk, &q.ClientPk)
	if err != nil {
		return nil, err
	}

	resp := EncryptedResponse{EsVersion: q.EsVersion, Nonce: q.Nonce}
	return resp.Encrypt(plain, key)
}
// decrypt decrypts the incoming DNSCrypt packet b with the resolver's secret
// key and unpacks the plaintext into a DNS message.  The EncryptedQuery
// (seeded with the certificate's ES version and client magic) is returned
// even when an error is returned.  An error means either that decryption
// failed or that the plaintext is not a valid DNS message.
func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) {
	cert := s.ResolverCert
	q := EncryptedQuery{
		EsVersion:   cert.EsVersion,
		ClientMagic: cert.ClientMagic,
	}

	plain, err := q.Decrypt(b, cert.ResolverSk)
	if err != nil {
		// Could not decrypt; the caller drops the packet.
		return nil, q, err
	}

	m := &dns.Msg{}
	if err = m.Unpack(plain); err != nil {
		// Decrypted, but not a valid DNS message.
		return nil, q, err
	}

	return m, q, nil
}
// handleHandshake handles a TXT query that requests certificate data.  If b
// is a single-question TXT query for the provider name, it returns the packed
// reply carrying certTxt as the TXT record.  Otherwise it returns an error:
// the Unpack error for unparseable input, or ErrInvalidQuery for other
// queries.
func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) {
	m := new(dns.Msg)
	if err := m.Unpack(b); err != nil {
		// Not a handshake, just ignore it.
		return nil, err
	}

	if len(m.Question) != 1 || m.Response {
		// Invalid query.
		return nil, ErrInvalidQuery
	}

	q := m.Question[0]
	// DNS names are case-insensitive and clients may send the name in random
	// case.  Fold both sides: lower-casing only the query name would never
	// match a ProviderName configured with upper-case letters.
	if q.Qtype != dns.TypeTXT || !strings.EqualFold(q.Name, dns.Fqdn(s.ProviderName)) {
		// Invalid provider name or type, doing nothing.
		return nil, ErrInvalidQuery
	}

	reply := new(dns.Msg)
	reply.SetReply(m)
	txt := &dns.TXT{
		Hdr: dns.RR_Header{
			Name:   q.Name,
			Rrtype: dns.TypeTXT,
			Ttl:    60, // use 60 seconds by default, but it shouldn't matter
			Class:  dns.ClassINET,
		},
		Txt: []string{
			certTxt,
		},
	}
	reply.Answer = append(reply.Answer, txt)

	// These bits are important for the old dnscrypt-proxy versions.
	reply.Authoritative = true
	reply.RecursionAvailable = true

	return reply.Pack()
}
// validate reports whether the Server configuration is usable.  ResolverCert
// must be set and pass VerifyDate, and ProviderName must be non-empty.  The
// first failing check is logged as an error and the result is false.
func (s *Server) validate() bool {
	switch {
	case s.ResolverCert == nil:
		log.Error("ResolverCert must be set")
	case !s.ResolverCert.VerifyDate():
		log.Error("ResolverCert date is not valid")
	case s.ProviderName == "":
		log.Error("ProviderName must be set")
	default:
		return true
	}

	return false
}
// getCertTXT serializes ResolverCert and packs the bytes with packTxtString,
// producing the certificate TXT string that is sent to clients.
func (s *Server) getCertTXT() (string, error) {
	raw, err := s.ResolverCert.Serialize()
	if err != nil {
		return "", err
	}

	return packTxtString(raw), nil
}