mirror of
https://github.com/SamTherapy/dnscrypt.git
synced 2024-10-02 16:32:51 +00:00
316 lines
8.3 KiB
Go
316 lines
8.3 KiB
Go
package dnscrypt
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/golibs/log"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// Timeouts and buffer sizes the server falls back to when reading queries.
const (
	// defaultReadTimeout is the default timeout for all reads.
	defaultReadTimeout = 2 * time.Second

	// defaultTCPIdleTimeout is the read timeout for TCP connections after the
	// first read: only the first read on a TCP connection uses
	// defaultReadTimeout, every later one uses this value.
	defaultTCPIdleTimeout = 8 * time.Second

	// defaultUDPSize is the default size of the UDP read buffer.  The release
	// notes for dnscrypt-proxy version 1.1.0-RC1 claim that this size was
	// chosen as the maximum one "for compatibility with some scary network
	// setups", and making it smaller seems to break things for some people.
	//
	// See also: https://github.com/AdguardTeam/AdGuardDNS/issues/188.
	defaultUDPSize = 1252
)
|
|
|
|
// longTimeAgo is a fixed instant far in the past (one second after the Unix
// epoch).  It is passed to SetDeadline/SetReadDeadline calls so that the
// deadline has already expired, which unblocks goroutines that are waiting
// on a read.
var longTimeAgo = time.Unix(1, 0)
|
|
|
|
// ServerDNSCrypt is an interface for a DNSCrypt server.
type ServerDNSCrypt interface {
	// ServeTCP listens to TCP connections, queries are then processed by
	// Server.Handler.  It blocks the calling goroutine and to stop it you
	// need to close the listener or call ServerDNSCrypt.Shutdown.
	ServeTCP(l net.Listener) error

	// ServeUDP listens to UDP connections, queries are then processed by
	// Server.Handler.  It blocks the calling goroutine and to stop it you
	// need to close the listener or call ServerDNSCrypt.Shutdown.
	ServeUDP(l *net.UDPConn) error

	// Shutdown tries to gracefully shut down the server.  It waits until all
	// connections are processed and only after that it leaves the method.
	// If a context deadline is specified, it will exit earlier.
	Shutdown(ctx context.Context) error
}
|
|
|
|
// Server is a simple DNSCrypt server implementation
|
|
type Server struct {
|
|
// ProviderName is a DNSCrypt provider name
|
|
ProviderName string
|
|
|
|
// ResolverCert contains resolver certificate.
|
|
ResolverCert *Cert
|
|
|
|
// UDPSize is the default buffer size to use to read incoming UDP messages.
|
|
// If not set it defaults to defaultUDPSize (1252 B).
|
|
UDPSize int
|
|
|
|
// Handler to invoke. If nil, uses DefaultHandler.
|
|
Handler Handler
|
|
|
|
// make sure init is called only once
|
|
initOnce sync.Once
|
|
|
|
// Shutdown handling
|
|
// --
|
|
lock sync.RWMutex // protects access to all the fields below
|
|
started bool
|
|
wg sync.WaitGroup // active workers (servers)
|
|
tcpListeners map[net.Listener]struct{} // track active TCP listeners
|
|
udpListeners map[*net.UDPConn]struct{} // track active UDP listeners
|
|
tcpConns map[net.Conn]struct{} // track active connections
|
|
}
|
|
|
|
// type check
|
|
var _ ServerDNSCrypt = &Server{}
|
|
|
|
// prepareShutdown - prepares the server to shutdown:
|
|
// unblocks reads from all connections related to this server
|
|
// marks the server as stopped
|
|
// if the server is not started, returns ErrServerNotStarted
|
|
func (s *Server) prepareShutdown() error {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
if !s.started {
|
|
log.Info("Server is not started")
|
|
return ErrServerNotStarted
|
|
}
|
|
|
|
s.started = false
|
|
|
|
// These listeners were passed to us from the outside so we cannot close
|
|
// them here - this is up to the calling code to do that. Instead of that,
|
|
// we call Set(Read)Deadline to unblock goroutines that are currently
|
|
// blocked on reading from those listeners.
|
|
// For tcpConns we would like to avoid closing them to be able to process
|
|
// queries before shutting everything down.
|
|
|
|
// Unblock reads for all active tcpConns
|
|
for conn := range s.tcpConns {
|
|
_ = conn.SetReadDeadline(longTimeAgo)
|
|
}
|
|
|
|
// Unblock reads for all active TCP listeners
|
|
for l := range s.tcpListeners {
|
|
switch v := l.(type) {
|
|
case *net.TCPListener:
|
|
_ = v.SetDeadline(longTimeAgo)
|
|
}
|
|
}
|
|
|
|
// Unblock reads for all active UDP listeners
|
|
for l := range s.udpListeners {
|
|
_ = l.SetReadDeadline(longTimeAgo)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Shutdown tries to gracefully shutdown the server. It waits until all
|
|
// connections are processed and only after that it leaves the method.
|
|
// If context deadline is specified, it will exit earlier.
|
|
func (s *Server) Shutdown(ctx context.Context) error {
|
|
log.Info("Shutting down the DNSCrypt server")
|
|
|
|
err := s.prepareShutdown()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Using this channel to wait until all goroutines finish their work
|
|
closed := make(chan struct{})
|
|
go func() {
|
|
s.wg.Wait()
|
|
log.Info("Serve goroutines finished their work")
|
|
close(closed)
|
|
}()
|
|
|
|
// Wait for either all goroutines finish their work
|
|
// Or for the context deadline
|
|
select {
|
|
case <-closed:
|
|
log.Info("DNSCrypt server has been stopped")
|
|
case <-ctx.Done():
|
|
log.Info("DNSCrypt server shutdown has timed out")
|
|
err = ctx.Err()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// init initializes (lazily) Server properties on startup
|
|
// this method is called from Server.ServeTCP and Server.ServeUDP
|
|
func (s *Server) init() {
|
|
s.tcpConns = map[net.Conn]struct{}{}
|
|
s.udpListeners = map[*net.UDPConn]struct{}{}
|
|
s.tcpListeners = map[net.Listener]struct{}{}
|
|
|
|
if s.UDPSize == 0 {
|
|
s.UDPSize = defaultUDPSize
|
|
}
|
|
}
|
|
|
|
// isStarted returns true if the server is processing queries right now
|
|
// it means that Server.ServeTCP and/or Server.ServeUDP have been called
|
|
func (s *Server) isStarted() bool {
|
|
s.lock.RLock()
|
|
started := s.started
|
|
s.lock.RUnlock()
|
|
return started
|
|
}
|
|
|
|
// serveDNS serves a DNS response
|
|
func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) error {
|
|
if r == nil || len(r.Question) != 1 || r.Response {
|
|
return ErrInvalidQuery
|
|
}
|
|
|
|
log.Tracef("Handling a DNS query: %s", r.Question[0].Name)
|
|
|
|
handler := s.Handler
|
|
if handler == nil {
|
|
handler = DefaultHandler
|
|
}
|
|
|
|
err := handler.ServeDNS(rw, r)
|
|
if err != nil {
|
|
log.Tracef("Error while handing a DNS query: %v", err)
|
|
|
|
reply := &dns.Msg{}
|
|
reply.SetRcode(r, dns.RcodeServerFailure)
|
|
_ = rw.WriteMsg(reply)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// encrypt encrypts DNSCrypt response
|
|
func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) {
|
|
r := EncryptedResponse{
|
|
EsVersion: q.EsVersion,
|
|
Nonce: q.Nonce,
|
|
}
|
|
packet, err := m.Pack()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sharedKey, err := computeSharedKey(q.EsVersion, &s.ResolverCert.ResolverSk, &q.ClientPk)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.Encrypt(packet, sharedKey)
|
|
}
|
|
|
|
// decrypt decrypts the incoming message and returns a DNS message to process
|
|
func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) {
|
|
q := EncryptedQuery{
|
|
EsVersion: s.ResolverCert.EsVersion,
|
|
ClientMagic: s.ResolverCert.ClientMagic,
|
|
}
|
|
msg, err := q.Decrypt(b, s.ResolverCert.ResolverSk)
|
|
if err != nil {
|
|
// Failed to decrypt, dropping it
|
|
return nil, q, err
|
|
}
|
|
|
|
r := new(dns.Msg)
|
|
err = r.Unpack(msg)
|
|
if err != nil {
|
|
// Invalid DNS message, ignore
|
|
return nil, q, err
|
|
}
|
|
|
|
return r, q, nil
|
|
}
|
|
|
|
// handleHandshake handles a TXT request that requests certificate data
|
|
func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) {
|
|
m := new(dns.Msg)
|
|
err := m.Unpack(b)
|
|
if err != nil {
|
|
// Not a handshake, just ignore it
|
|
return nil, err
|
|
}
|
|
|
|
if len(m.Question) != 1 || m.Response {
|
|
// Invalid query
|
|
return nil, ErrInvalidQuery
|
|
}
|
|
|
|
q := m.Question[0]
|
|
providerName := dns.Fqdn(s.ProviderName)
|
|
qName := strings.ToLower(q.Name) // important, may be random case
|
|
if q.Qtype != dns.TypeTXT || qName != providerName {
|
|
// Invalid provider name or type, doing nothing
|
|
return nil, ErrInvalidQuery
|
|
}
|
|
|
|
reply := new(dns.Msg)
|
|
reply.SetReply(m)
|
|
txt := &dns.TXT{
|
|
Hdr: dns.RR_Header{
|
|
Name: q.Name,
|
|
Rrtype: dns.TypeTXT,
|
|
Ttl: 60, // use 60 seconds by default, but it shouldn't matter
|
|
Class: dns.ClassINET,
|
|
},
|
|
Txt: []string{
|
|
certTxt,
|
|
},
|
|
}
|
|
reply.Answer = append(reply.Answer, txt)
|
|
|
|
// These bits are important for the old dnscrypt-proxy versions
|
|
reply.Authoritative = true
|
|
reply.RecursionAvailable = true
|
|
return reply.Pack()
|
|
}
|
|
|
|
// validate checks if the Server config is properly set
|
|
func (s *Server) validate() bool {
|
|
if s.ResolverCert == nil {
|
|
log.Error("ResolverCert must be set")
|
|
return false
|
|
}
|
|
|
|
if !s.ResolverCert.VerifyDate() {
|
|
log.Error("ResolverCert date is not valid")
|
|
return false
|
|
}
|
|
|
|
if s.ProviderName == "" {
|
|
log.Error("ProviderName must be set")
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// getCertTXT serializes the cert TXT record that are to be sent to the client
|
|
func (s *Server) getCertTXT() (string, error) {
|
|
certBuf, err := s.ResolverCert.Serialize()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
certTxt := packTxtString(certBuf)
|
|
return certTxt, nil
|
|
}
|