mirror of
https://github.com/SamTherapy/dnscrypt.git
synced 2024-10-02 16:32:51 +00:00
220 lines
5.2 KiB
Go
220 lines
5.2 KiB
Go
package dnscrypt
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/golibs/log"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// TCPResponseWriter is the ResponseWriter implementation for TCP
|
|
type TCPResponseWriter struct {
|
|
tcpConn net.Conn
|
|
encrypt encryptionFunc
|
|
req *dns.Msg
|
|
query EncryptedQuery
|
|
}
|
|
|
|
// type check
|
|
var _ ResponseWriter = &TCPResponseWriter{}
|
|
|
|
// LocalAddr is the server socket local address
|
|
func (w *TCPResponseWriter) LocalAddr() net.Addr {
|
|
return w.tcpConn.LocalAddr()
|
|
}
|
|
|
|
// RemoteAddr is the client's address
|
|
func (w *TCPResponseWriter) RemoteAddr() net.Addr {
|
|
return w.tcpConn.RemoteAddr()
|
|
}
|
|
|
|
// WriteMsg writes DNS message to the client
|
|
func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
|
|
normalize("tcp", w.req, m)
|
|
|
|
res, err := w.encrypt(m, w.query)
|
|
if err != nil {
|
|
log.Tracef("Failed to encrypt the DNS query: %v", err)
|
|
return err
|
|
}
|
|
|
|
return writePrefixed(res, w.tcpConn)
|
|
}
|
|
|
|
// ServeTCP listens to TCP connections, queries are then processed by Server.Handler.
|
|
// It blocks the calling goroutine and to stop it you need to close the listener
|
|
// or call Server.Shutdown.
|
|
func (s *Server) ServeTCP(l net.Listener) error {
|
|
err := s.prepareServeTCP(l)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr())
|
|
|
|
// Tracks TCP connection handling goroutines
|
|
tcpWg := &sync.WaitGroup{}
|
|
defer s.cleanUpTCP(tcpWg, l)
|
|
|
|
// Track active goroutine
|
|
s.wg.Add(1)
|
|
|
|
// Serialize the cert right away and prepare it to be sent to the client
|
|
certTxt, err := s.getCertTXT()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for s.isStarted() {
|
|
conn, err := l.Accept()
|
|
|
|
// Check the error code and exit loop if necessary
|
|
if err != nil {
|
|
if !s.isStarted() {
|
|
// Stopped gracefully
|
|
break
|
|
}
|
|
var netErr net.Error
|
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
|
// Note that timeout errors will be here (i.e. hitting ReadDeadline)
|
|
continue
|
|
}
|
|
if isConnClosed(err) {
|
|
log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop")
|
|
} else {
|
|
log.Info("got error when reading from UDP listen: %s", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
// If we got here, the connection is alive
|
|
s.lock.Lock()
|
|
// Track the connection to allow unblocking reads on shutdown.
|
|
s.tcpConns[conn] = struct{}{}
|
|
s.lock.Unlock()
|
|
|
|
tcpWg.Add(1)
|
|
go func() {
|
|
// Ignore error here, it is most probably a legit one
|
|
// if not, it's written to the debug log
|
|
_ = s.handleTCPConnection(conn, certTxt)
|
|
|
|
// Clean up
|
|
_ = conn.Close()
|
|
s.lock.Lock()
|
|
delete(s.tcpConns, conn)
|
|
s.lock.Unlock()
|
|
tcpWg.Done()
|
|
}()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// prepareServeTCP prepares the server and listener to serving DNSCrypt
|
|
func (s *Server) prepareServeTCP(l net.Listener) error {
|
|
// Check that server is properly configured
|
|
if !s.validate() {
|
|
return ErrServerConfig
|
|
}
|
|
|
|
// Protect shutdown-related fields
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
s.initOnce.Do(s.init)
|
|
|
|
// Mark the server as started if needed
|
|
s.started = true
|
|
|
|
// Track an active TCP listener
|
|
s.tcpListeners[l] = struct{}{}
|
|
return nil
|
|
}
|
|
|
|
// cleanUpTCP waits until all TCP messages before cleaning up
|
|
func (s *Server) cleanUpTCP(tcpWg *sync.WaitGroup, l net.Listener) {
|
|
// Wait until all TCP connections are processed
|
|
tcpWg.Wait()
|
|
|
|
// Not using it anymore so can be removed from the active listeners
|
|
s.lock.Lock()
|
|
delete(s.tcpListeners, l)
|
|
s.lock.Unlock()
|
|
|
|
// The work is finished
|
|
s.wg.Done()
|
|
}
|
|
|
|
// handleTCPMsg handles a single TCP message. If this method returns error
|
|
// the connection will be closed
|
|
func (s *Server) handleTCPMsg(b []byte, conn net.Conn, certTxt string) error {
|
|
if len(b) < minDNSPacketSize {
|
|
// Ignore the packets that are too short
|
|
return ErrTooShort
|
|
}
|
|
|
|
// First of all, check for "ClientMagic" in the incoming query
|
|
if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
|
|
// If there's no ClientMagic in the packet, we assume this
|
|
// is a plain DNS query requesting the certificate data
|
|
reply, err := s.handleHandshake(b, certTxt)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process a plain DNS query: %w", err)
|
|
}
|
|
err = writePrefixed(reply, conn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write a response: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// If we got here, this is an encrypted DNSCrypt message
|
|
// We should decrypt it first to get the plain DNS query
|
|
m, q, err := s.decrypt(b)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decrypt incoming message: %w", err)
|
|
}
|
|
rw := &TCPResponseWriter{
|
|
tcpConn: conn,
|
|
encrypt: s.encrypt,
|
|
req: m,
|
|
query: q,
|
|
}
|
|
err = s.serveDNS(rw, m)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to process a DNS query: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// handleTCPConnection handles all queries that are coming to the
|
|
// specified TCP connection.
|
|
func (s *Server) handleTCPConnection(conn net.Conn, certTxt string) error {
|
|
timeout := defaultReadTimeout
|
|
|
|
for s.isStarted() {
|
|
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
|
|
|
b, err := readPrefixed(conn)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = s.handleTCPMsg(b, conn, certTxt)
|
|
if err != nil {
|
|
log.Debug("failed to process DNS query: %v", err)
|
|
return err
|
|
}
|
|
|
|
timeout = defaultTCPIdleTimeout
|
|
}
|
|
|
|
return nil
|
|
}
|