1
0
Fork 0
mirror of https://github.com/SamTherapy/dnscrypt.git synced 2024-12-21 16:50:42 +00:00

Graceful shutdown of the DNSCrypt server (#6)

Graceful shutdown of the DNSCrypt server

This PR implements Server.Shutdown(ctx context.Context) method that allows
to shut down the DNSCrypt server gracefully.

Some additional changes that were inadvertently made while doing that:
1. Added benchmark tests
2. Started using dns.ReadFromSessionUDP / dns.WriteToSessionUDP instead of implementing it by ourselves
3. Generally improved tests
4. Added depguard 
5. Improved comments overall in the code
This commit is contained in:
Andrey Meshkov 2021-03-19 15:42:48 +03:00 committed by GitHub
parent d0ae1d198d
commit 5ddb58f703
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 744 additions and 384 deletions

View file

@ -20,6 +20,12 @@ linters-settings:
min-complexity: 20
lll:
line-length: 200
depguard:
list-type: blacklist
include-go-root: false
packages:
- golang.org/x/net/context # we use context
- log # we use github.com/AdguardTeam/golibs/log
linters:
enable:
@ -41,6 +47,7 @@ linters:
- misspell
- stylecheck
- unconvert
- depguard
disable-all: true
fast: true

30
cert.go
View file

@ -8,10 +8,10 @@ import (
"time"
)
// Cert - DNSCrypt server certificate
// Cert is a DNSCrypt server certificate
// See ResolverConfig for more info on how to create one
type Cert struct {
// Serial - a 4 byte serial number in big-endian format. If more than
// Serial is a 4 byte serial number in big-endian format. If more than
// one certificates are valid, the client must prefer the certificate
// with a higher serial number.
Serial uint32
@ -22,36 +22,36 @@ type Cert struct {
// For X25519-XChacha20Poly1305, <es-version> must be 0x00 0x02.
EsVersion CryptoConstruction
// Signature - a 64-byte signature of (<resolver-pk> <client-magic>
// Signature is a 64-byte signature of (<resolver-pk> <client-magic>
// <serial> <ts-start> <ts-end> <extensions>) using the Ed25519 algorithm and the
// provider secret key. Ed25519 must be used in this version of the
// protocol.
Signature [ed25519.SignatureSize]byte
// ResolverPk - the resolver's short-term public key, which is 32 bytes when using X25519.
// ResolverPk is the resolver's short-term public key, which is 32 bytes when using X25519.
// This key is used to encrypt/decrypt DNS queries
ResolverPk [keySize]byte
// ResolverSk - the resolver's short-term private key, which is 32 bytes when using X25519.
// ResolverSk is the resolver's short-term private key, which is 32 bytes when using X25519.
// Note that it's only used in the server implementation and never serialized/deserialized.
// This key is used to encrypt/decrypt DNS queries
ResolverSk [keySize]byte
// ClientMagic - the first 8 bytes of a client query that is to be built
// ClientMagic is the first 8 bytes of a client query that is to be built
// using the information from this certificate. It may be a truncated
// public key. Two valid certificates cannot share the same <client-magic>.
ClientMagic [clientMagicSize]byte
// NotAfter - the date the certificate is valid from, as a big-endian
// NotAfter is the date the certificate is valid from, as a big-endian
// 4-byte unsigned Unix timestamp.
NotBefore uint32
// NotAfter - the date the certificate is valid until (inclusive), as a
// NotAfter is the date the certificate is valid until (inclusive), as a
// big-endian 4-byte unsigned Unix timestamp.
NotAfter uint32
}
// Serialize - serializes the cert to bytes
// Serialize serializes the cert to bytes
// <cert> ::= <cert-magic> <es-version> <protocol-minor-version> <signature>
// <resolver-pk> <client-magic> <serial> <ts-start> <ts-end>
// <extensions>
@ -86,7 +86,7 @@ func (c *Cert) Serialize() ([]byte, error) {
return b, nil
}
// Deserialize - deserializes certificate from a byte array
// Deserialize deserializes certificate from a byte array
// <cert> ::= <cert-magic> <es-version> <protocol-minor-version> <signature>
// <resolver-pk> <client-magic> <serial> <ts-start> <ts-end>
// <extensions>
@ -127,7 +127,7 @@ func (c *Cert) Deserialize(b []byte) error {
return nil
}
// VerifyDate - checks that cert is valid at this moment
// VerifyDate checks that the cert is valid at this moment
func (c *Cert) VerifyDate() bool {
if c.NotBefore >= c.NotAfter {
return false
@ -139,14 +139,14 @@ func (c *Cert) VerifyDate() bool {
return true
}
// VerifySignature - checks if the cert is properly signed with the specified signature
// VerifySignature checks if the cert is properly signed with the specified signature
func (c *Cert) VerifySignature(publicKey ed25519.PublicKey) bool {
b := make([]byte, 52)
c.writeSigned(b)
return ed25519.Verify(publicKey, b, c.Signature[:])
}
// Sign - creates cert.Signature
// Sign creates cert.Signature
func (c *Cert) Sign(privateKey ed25519.PrivateKey) {
b := make([]byte, 52)
c.writeSigned(b)
@ -154,14 +154,14 @@ func (c *Cert) Sign(privateKey ed25519.PrivateKey) {
copy(c.Signature[:64], signature[:64])
}
// String - Cert's string representation
// String Cert's string representation
func (c *Cert) String() string {
return fmt.Sprintf("Certificate Serial=%d NotBefore=%s NotAfter=%s EsVersion=%s",
c.Serial, time.Unix(int64(c.NotBefore), 0).String(),
time.Unix(int64(c.NotAfter), 0).String(), c.EsVersion.String())
}
// writeSigned - writes (<resolver-pk> <client-magic> <serial> <ts-start> <ts-end> <extensions>)
// writeSigned writes (<resolver-pk> <client-magic> <serial> <ts-start> <ts-end> <extensions>)
func (c *Cert) writeSigned(dst []byte) {
// <resolver-pk>
copy(dst[:32], c.ResolverPk[:keySize])

View file

@ -4,6 +4,7 @@ import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"io/ioutil"
"testing"
"time"
@ -21,13 +22,13 @@ func TestCertSerialize(t *testing.T) {
// serialize
b, err := cert.Serialize()
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, 124, len(b))
// check that we can deserialize it
cert2 := Cert{}
err = cert2.Deserialize(b)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, cert.Serial, cert2.Serial)
assert.Equal(t, cert.NotBefore, cert2.NotBefore)
assert.Equal(t, cert.NotAfter, cert2.NotAfter)
@ -39,12 +40,15 @@ func TestCertSerialize(t *testing.T) {
func TestCertDeserialize(t *testing.T) {
// dig -t txt 2.dnscrypt-cert.opendns.com. -p 443 @208.67.220.220
b, err := unpackTxtString("DNSC\\000\\001\\000\\000\\200\\226E:H\\156\\203%\\134\\218\\127]\\168\\239\\027u\\011$\\191\\008\\239\\176F\\133\\017\\171\\161\\219\\154\\142i\\164\\010\\239\\017f\\168dS\\210f\\197\\194\\169\\171w\\2499\\1891\\155<\\130\\218@/\\155\\023v\\153#d\\024\\004\\136\\180\\228K5\\233d\\180\\144\\189\\218\\186\\232%\\162K\\004\\021\\160\\139\\225\\157}\\219\\135\\163<\\215~\\223\\142/qc78aWoo]\\221\\184`]\\221\\184`_\\190\\235\\224")
assert.Nil(t, err)
certBytes, err := ioutil.ReadFile("testdata/dnscrypt-cert.opendns.txt")
assert.NoError(t, err)
b, err := unpackTxtString(string(certBytes))
assert.NoError(t, err)
cert := &Cert{}
err = cert.Deserialize(b)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, uint32(1574811744), cert.Serial)
assert.Equal(t, XSalsa20Poly1305, cert.EsVersion)
assert.Equal(t, uint32(1574811744), cert.NotBefore)
@ -69,7 +73,7 @@ func generateValidCert(t *testing.T) (*Cert, ed25519.PublicKey, ed25519.PrivateK
// generate private key
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
assert.Nil(t, err)
assert.NoError(t, err)
// sign the data
cert.Sign(privateKey)

View file

@ -12,7 +12,7 @@ import (
"github.com/miekg/dns"
)
// Client - DNSCrypt resolver client
// Client is a DNSCrypt resolver client
type Client struct {
Net string // protocol (can be "udp" or "tcp", by default - "udp")
Timeout time.Duration // read/write timeout
@ -124,7 +124,7 @@ func (c *Client) ExchangeConn(conn net.Conn, m *dns.Msg, resolverInfo *ResolverI
return res, nil
}
// writeQuery - writes query to the network connection
// writeQuery writes query to the network connection
// depending on the protocol we may write a 2-byte prefix or not
func (c *Client) writeQuery(conn net.Conn, query []byte) error {
var err error
@ -145,7 +145,7 @@ func (c *Client) writeQuery(conn net.Conn, query []byte) error {
return err
}
// readResponse - reads response from the network connection
// readResponse reads response from the network connection
// depending on the protocol, we may read a 2-byte prefix or not
func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
if c.Timeout > 0 {
@ -171,7 +171,7 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
return readPrefixed(conn)
}
// encrypt - encrypts a DNS message using shared key from the resolver info
// encrypt encrypts a DNS message using shared key from the resolver info
func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) {
q := EncryptedQuery{
EsVersion: resolverInfo.ResolverCert.EsVersion,
@ -185,7 +185,7 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error)
return q.Encrypt(query, resolverInfo.SharedKey)
}
// decrypts - decrypts a DNS message using shared key from the resolver info
// decrypts decrypts a DNS message using a shared key from the resolver info
func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) {
dr := EncryptedResponse{
EsVersion: resolverInfo.ResolverCert.EsVersion,
@ -203,7 +203,7 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error)
return res, nil
}
// fetchCert - loads DNSCrypt cert from the specified server
// fetchCert loads DNSCrypt cert from the specified server
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
providerName := stamp.ProviderName
if !strings.HasSuffix(providerName, ".") {

View file

@ -63,7 +63,7 @@ func TestTimeoutOnDialExchange(t *testing.T) {
client := Client{Timeout: 300 * time.Millisecond}
serverInfo, err := client.Dial(stampStr)
assert.Nil(t, err)
assert.NoError(t, err)
// Point it to an IP where there's no DNSCrypt server
serverInfo.ServerAddress = "8.8.8.8:5443"
@ -105,12 +105,15 @@ func TestFetchCertPublicResolvers(t *testing.T) {
for _, test := range stamps {
stamp, err := dnsstamps.NewServerStampFromString(test.stampStr)
assert.Nil(t, err)
assert.NoError(t, err)
t.Run(stamp.ProviderName, func(t *testing.T) {
c := &Client{Net: "udp"}
c := &Client{
Net: "udp",
Timeout: time.Second * 5,
}
resolverInfo, err := c.DialStamp(stamp)
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotNil(t, resolverInfo)
assert.True(t, resolverInfo.ResolverCert.VerifyDate())
assert.True(t, resolverInfo.ResolverCert.VerifySignature(stamp.ServerPk))
@ -146,7 +149,7 @@ func TestExchangePublicResolvers(t *testing.T) {
for _, test := range stamps {
stamp, err := dnsstamps.NewServerStampFromString(test.stampStr)
assert.Nil(t, err)
assert.NoError(t, err)
t.Run(stamp.ProviderName, func(t *testing.T) {
checkDNSCryptServer(t, test.stampStr, "udp")
@ -158,12 +161,12 @@ func TestExchangePublicResolvers(t *testing.T) {
func checkDNSCryptServer(t *testing.T, stampStr string, network string) {
client := Client{Net: network, Timeout: 10 * time.Second}
resolverInfo, err := client.Dial(stampStr)
assert.Nil(t, err)
assert.NoError(t, err)
req := createTestMessage()
reply, err := client.Exchange(req, resolverInfo)
assert.Nil(t, err)
assert.NoError(t, err)
assertTestMessageResponse(t, reply)
}
@ -177,7 +180,7 @@ func createTestMessage() *dns.Msg {
return &req
}
func assertTestMessageResponse(t *testing.T, reply *dns.Msg) {
func assertTestMessageResponse(t assert.TestingT, reply *dns.Msg) {
assert.NotNil(t, reply)
assert.Equal(t, 1, len(reply.Answer))
a, ok := reply.Answer[0].(*dns.A)

View file

@ -12,7 +12,7 @@ import (
"gopkg.in/yaml.v3"
)
// ConvertWrapperArgs - "convert-dnscrypt-wrapper" command arguments
// ConvertWrapperArgs is the "convert-dnscrypt-wrapper" command arguments structure
type ConvertWrapperArgs struct {
PrivateKeyFile string `short:"p" long:"private-key" description:"Path to the DNSCrypt resolver private key file that is used for signing certificates. Param is required." required:"true"`
ResolverSkFile string `short:"r" long:"resolver-secret" description:"Path to the Short-term privacy key file for encrypting/decrypting DNS queries. If not specified, resolver_secret and resolver_public will be randomly generated."`
@ -21,7 +21,7 @@ type ConvertWrapperArgs struct {
CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"`
}
// convertWrapper - generates DNSCrypt configuration from both dnscrypt and server private keys
// convertWrapper generates DNSCrypt configuration from both dnscrypt and server private keys
func convertWrapper(args ConvertWrapperArgs) {
log.Info("Generating configuration for %s", args.ProviderName)
@ -71,7 +71,7 @@ func convertWrapper(args ConvertWrapperArgs) {
}
}
// validateRc - verifies that the certificate is correctly
// validateRc verifies that the certificate is correctly
// created and validated for this resolver config. if rc valid returns nil.
func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error {
cert, err := rc.CreateCert()
@ -90,7 +90,7 @@ func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error {
return nil
}
// getResolverPk - calculates public key from private key
// getResolverPk calculates public key from private key
func getResolverPk(private ed25519.PrivateKey) ed25519.PublicKey {
resolverSk := [32]byte{}
resolverPk := [32]byte{}

View file

@ -8,7 +8,7 @@ import (
"gopkg.in/yaml.v3"
)
// GenerateArgs - "generate" command arguments
// GenerateArgs is the "generate" command arguments structure
type GenerateArgs struct {
ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt provider name. Param is required." required:"true"`
Out string `short:"o" long:"out" description:"Path to the resulting config file. Param is required." required:"true"`
@ -16,7 +16,7 @@ type GenerateArgs struct {
CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"`
}
// generate - generates DNSCrypt server configuration
// generate generates a DNSCrypt server configuration
func generate(args GenerateArgs) {
log.Info("Generating configuration for %s", args.ProviderName)

View file

@ -12,7 +12,7 @@ import (
"github.com/miekg/dns"
)
// LookupStampArgs - "lookup-stamp" command arguments
// LookupStampArgs is the "lookup-stamp" command arguments structure
type LookupStampArgs struct {
Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"`
Stamp string `short:"s" long:"stamp" description:"DNSCrypt resolver stamp. Param is required." required:"true"`
@ -20,7 +20,7 @@ type LookupStampArgs struct {
Type string `short:"t" long:"type" description:"DNS query type" default:"A"`
}
// LookupArgs - "lookup" command arguments
// LookupArgs is the "lookup" command arguments structure
type LookupArgs struct {
Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"`
ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt resolver provider name. Param is required." required:"true"`
@ -30,7 +30,7 @@ type LookupArgs struct {
Type string `short:"t" long:"type" description:"DNS query type" default:"A"`
}
// LookupResult - lookup result that contains the cert info and the query response
// LookupResult is the lookup result that contains the cert info and the query response
type LookupResult struct {
Certificate struct {
Serial uint32 `json:"serial"`
@ -42,7 +42,7 @@ type LookupResult struct {
Reply *dns.Msg `json:"reply"`
}
// lookup - performs a DNS lookup, prints DNSCrypt info and lookup results
// lookup performs a DNS lookup, prints DNSCrypt info and lookup results
func lookup(args LookupArgs) {
serverPk, err := dnscrypt.HexDecodeKey(args.PublicKey)
if err != nil {
@ -64,7 +64,7 @@ func lookup(args LookupArgs) {
})
}
// lookupStamp - performs a DNS lookup, prints DNSCrypt cert info and lookup results
// lookupStamp performs a DNS lookup, prints DNSCrypt cert info and lookup results
func lookupStamp(args LookupStampArgs) {
c := &dnscrypt.Client{
Net: args.Network,

View file

@ -4,7 +4,6 @@ import (
"os"
"github.com/AdguardTeam/golibs/log"
goFlags "github.com/jessevdk/go-flags"
)

View file

@ -13,7 +13,7 @@ import (
"gopkg.in/yaml.v3"
)
// ServerArgs - "server" command arguments
// ServerArgs is the "server" command arguments
type ServerArgs struct {
Config string `short:"c" long:"config" description:"Path to the DNSCrypt configuration file. Param is required." required:"true"`
Forward string `short:"f" long:"forward" description:"Forwards DNS queries to the specified address" default:"94.140.14.140:53"`
@ -21,7 +21,7 @@ type ServerArgs struct {
ListenPorts []int `short:"p" long:"port" description:"Listening ports" default:"443"`
}
// server - runs a DNSCrypt server
// server runs a DNSCrypt server
func server(args ServerArgs) {
log.Info("Starting DNSCrypt server")
@ -72,7 +72,7 @@ func server(args ServerArgs) {
}
}
// createListeners - creates listeners for our server
// createListeners creates listeners for our server
func createListeners(args ServerArgs) (tcp []net.Listener, udp []*net.UDPConn) {
for _, addr := range args.ListenAddrs {
ip := net.ParseIP(addr)
@ -101,7 +101,10 @@ type forwardHandler struct {
addr string
}
// ServeDNS - implements Handler interface
// type check
var _ dnscrypt.Handler = &forwardHandler{}
// ServeDNS implements Handler interface
func (f *forwardHandler) ServeDNS(rw dnscrypt.ResponseWriter, r *dns.Msg) error {
res, err := dns.Exchange(r, f.addr)
if err != nil {

View file

@ -1,52 +1,58 @@
package dnscrypt
import "errors"
// Error represents a dnscrypt error.
type Error string
var (
// ErrTooShort - DNS query is shorter than possible
ErrTooShort = errors.New("DNSCrypt message is too short")
func (e Error) Error() string { return "dnscrypt: " + string(e) }
// ErrQueryTooLarge - DNS query is larger than max allowed size
ErrQueryTooLarge = errors.New("DNSCrypt query is too large")
const (
// ErrTooShort means that the DNS query is shorter than possible
ErrTooShort = Error("message is too short")
// ErrEsVersion - cert contains unsupported es-version
ErrEsVersion = errors.New("unsupported es-version")
// ErrQueryTooLarge means that the DNS query is larger than max allowed size
ErrQueryTooLarge = Error("DNSCrypt query is too large")
// ErrInvalidDate - cert is not valid for the current time
ErrInvalidDate = errors.New("cert has invalid ts-start or ts-end")
// ErrEsVersion means that the cert contains unsupported es-version
ErrEsVersion = Error("unsupported es-version")
// ErrInvalidCertSignature - cert has invalid signature
ErrInvalidCertSignature = errors.New("cert has invalid signature")
// ErrInvalidDate means that the cert is not valid for the current time
ErrInvalidDate = Error("cert has invalid ts-start or ts-end")
// ErrInvalidQuery - failed to decrypt a DNSCrypt query
ErrInvalidQuery = errors.New("DNSCrypt query is invalid and cannot be decrypted")
// ErrInvalidCertSignature means that the cert has invalid signature
ErrInvalidCertSignature = Error("cert has invalid signature")
// ErrInvalidClientMagic - client-magic does not match
ErrInvalidClientMagic = errors.New("DNSCrypt query contains invalid client magic")
// ErrInvalidQuery means that it failed to decrypt a DNSCrypt query
ErrInvalidQuery = Error("DNSCrypt query is invalid and cannot be decrypted")
// ErrInvalidResolverMagic - server-magic does not match
ErrInvalidResolverMagic = errors.New("DNSCrypt response contains invalid resolver magic")
// ErrInvalidClientMagic means that client-magic does not match
ErrInvalidClientMagic = Error("DNSCrypt query contains invalid client magic")
// ErrInvalidResponse - failed to decrypt a DNSCrypt response
ErrInvalidResponse = errors.New("DNSCrypt response is invalid and cannot be decrypted")
// ErrInvalidResolverMagic means that server-magic does not match
ErrInvalidResolverMagic = Error("DNSCrypt response contains invalid resolver magic")
// ErrInvalidPadding - failed to unpad a query
ErrInvalidPadding = errors.New("invalid padding")
// ErrInvalidResponse means that it failed to decrypt a DNSCrypt response
ErrInvalidResponse = Error("DNSCrypt response is invalid and cannot be decrypted")
// ErrInvalidDNSStamp - invalid DNS stamp
ErrInvalidDNSStamp = errors.New("invalid DNS stamp")
// ErrInvalidPadding means that it failed to unpad a query
ErrInvalidPadding = Error("invalid padding")
// ErrFailedToFetchCert - failed to fetch DNSCrypt certificate
ErrFailedToFetchCert = errors.New("failed to fetch DNSCrypt certificate")
// ErrInvalidDNSStamp means an invalid DNS stamp
ErrInvalidDNSStamp = Error("invalid DNS stamp")
// ErrCertTooShort - failed to deserialize cert, too short
ErrCertTooShort = errors.New("cert is too short")
// ErrFailedToFetchCert means that it failed to fetch DNSCrypt certificate
ErrFailedToFetchCert = Error("failed to fetch DNSCrypt certificate")
// ErrCertMagic - invalid cert magic
ErrCertMagic = errors.New("invalid cert magic")
// ErrCertTooShort means that it failed to deserialize cert, too short
ErrCertTooShort = Error("cert is too short")
// ErrServerConfig - failed to start the DNSCrypt server - invalid configuration
ErrServerConfig = errors.New("invalid server configuration")
// ErrCertMagic means an invalid cert magic
ErrCertMagic = Error("invalid cert magic")
// ErrServerConfig means that it failed to start the DNSCrypt server - invalid configuration
ErrServerConfig = Error("invalid server configuration")
// ErrServerNotStarted is returned if there's nothing to shutdown
ErrServerNotStarted = Error("server is not started")
)
const (
@ -55,7 +61,7 @@ const (
// Some servers do not work if padded length is less than 256. Example: Quad9
minUDPQuestionSize = 256
// <max-query-len> - maximum allowed query length
// <max-query-len> is the maximum allowed query length
maxQueryLen = 1252
// Minimum possible DNS packet size
@ -68,7 +74,7 @@ const (
// size of the shared key used to encrypt/decrypt messages
sharedKeySize = 32
// ClientMagic - the first 8 bytes of a client query that is to be built
// ClientMagic is the first 8 bytes of a client query that is to be built
// using the information from this certificate. It may be a truncated
// public key. Two valid certificates cannot share the same <client-magic>.
clientMagicSize = 8
@ -82,10 +88,10 @@ const (
)
var (
// certMagic - bytes sequence that must be in the beginning of the serialized cert
// certMagic is a bytes sequence that must be in the beginning of the serialized cert
certMagic = [4]byte{0x44, 0x4e, 0x53, 0x43}
// resolverMagic - byte sequence that must be in the beginning of every response
// resolverMagic is a byte sequence that must be in the beginning of every response
resolverMagic = []byte{0x72, 0x36, 0x66, 0x6e, 0x76, 0x57, 0x6a, 0x38}
)

View file

@ -10,19 +10,19 @@ import (
"golang.org/x/crypto/nacl/secretbox"
)
// EncryptedQuery - a structure for encrypting and decrypting client queries
// EncryptedQuery is a structure for encrypting and decrypting client queries
//
// <dnscrypt-query> ::= <client-magic> <client-pk> <client-nonce> <encrypted-query>
// <encrypted-query> ::= AE(<shared-key> <client-nonce> <client-nonce-pad>, <client-query> <client-query-pad>)
type EncryptedQuery struct {
// EsVersion - encryption to use
// EsVersion is the encryption to use
EsVersion CryptoConstruction
// ClientMagic - a 8 byte identifier for the resolver certificate
// ClientMagic is a 8 byte identifier for the resolver certificate
// chosen by the client.
ClientMagic [clientMagicSize]byte
// ClientPk - the client's public key
// ClientPk is the client's public key
ClientPk [keySize]byte
// With a 24 bytes nonce, a question sent by a DNSCrypt client must be
@ -36,7 +36,7 @@ type EncryptedQuery struct {
Nonce [nonceSize]byte
}
// Encrypt - encrypts the specified DNS query, returns encrypted data ready to be sent.
// Encrypt encrypts the specified DNS query, returns encrypted data ready to be sent.
//
// Note that this method will generate a random nonce automatically.
//
@ -79,7 +79,7 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) (
return query, nil
}
// Decrypt - decrypts the client query, returns decrypted DNS packet.
// Decrypt decrypts the client query, returns decrypted DNS packet.
//
// Please note, that before calling this method the following fields must be set:
// * ClientMagic -- to verify the query

View file

@ -23,7 +23,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction)
// Generate client shared key
clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey)
assert.Nil(t, err)
assert.NoError(t, err)
clientMagic := [clientMagicSize]byte{}
_, _ = rand.Read(clientMagic[:])
@ -40,7 +40,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction)
// Encrypt it
encrypted, err := q1.Encrypt(packet, clientSharedKey)
assert.Nil(t, err)
assert.NoError(t, err)
// Now let's try decrypting it
q2 := EncryptedQuery{
@ -50,7 +50,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction)
// Decrypt it
decrypted, err := q2.Decrypt(encrypted, serverSecretKey)
assert.Nil(t, err)
assert.NoError(t, err)
// Check that packet is the same
assert.True(t, bytes.Equal(packet, decrypted))

View file

@ -10,12 +10,12 @@ import (
"golang.org/x/crypto/nacl/secretbox"
)
// EncryptedResponse - structure for encrypting/decrypting server responses
// EncryptedResponse is a structure for encrypting/decrypting server responses
//
// <dnscrypt-response> ::= <resolver-magic> <nonce> <encrypted-response>
// <encrypted-response> ::= AE(<shared-key>, <nonce>, <resolver-response> <resolver-response-pad>)
type EncryptedResponse struct {
// EsVersion - encryption to use
// EsVersion is the encryption to use
EsVersion CryptoConstruction
// Nonce - <nonce> ::= <client-nonce> <resolver-nonce>
@ -23,7 +23,7 @@ type EncryptedResponse struct {
Nonce [nonceSize]byte
}
// Encrypt - encrypts the server response
// Encrypt encrypts the server response
//
// EsVersion must be set.
// Nonce needs to be set to "client-nonce".
@ -57,7 +57,7 @@ func (r *EncryptedResponse) Encrypt(packet []byte, sharedKey [sharedKeySize]byte
return response, nil
}
// Decrypt - decrypts the server response
// Decrypt decrypts the server response
//
// EsVersion must be set.
func (r *EncryptedResponse) Decrypt(response []byte, sharedKey [sharedKeySize]byte) ([]byte, error) {

View file

@ -24,11 +24,11 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi
// Generate client shared key
clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey)
assert.Nil(t, err)
assert.NoError(t, err)
// Generate server shared key
serverSharedKey, err := computeSharedKey(esVersion, &serverSecretKey, &clientPublicKey)
assert.Nil(t, err)
assert.NoError(t, err)
r1 := &EncryptedResponse{
EsVersion: esVersion,
@ -42,7 +42,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi
// Encrypt it
encrypted, err := r1.Encrypt(packet, serverSharedKey)
assert.Nil(t, err)
assert.NoError(t, err)
// Now let's try decrypting it
r2 := &EncryptedResponse{
@ -51,7 +51,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi
// Decrypt it
decrypted, err := r2.Decrypt(encrypted, clientSharedKey)
assert.Nil(t, err)
assert.NoError(t, err)
// Check that packet is the same
assert.True(t, bytes.Equal(packet, decrypted))

View file

@ -14,36 +14,37 @@ import (
const dnsCryptV2Prefix = "2.dnscrypt-cert."
// ResolverConfig - DNSCrypt resolver configuration
// ResolverConfig is the DNSCrypt resolver configuration
type ResolverConfig struct {
// DNSCrypt provider name
ProviderName string `yaml:"provider_name"`
// PublicKey - DNSCrypt resolver public key
// PublicKey is the DNSCrypt resolver public key
PublicKey string `yaml:"public_key"`
// PrivateKey - DNSCrypt resolver private key
// PrivateKey is the DNSCrypt resolver private key
// The main and only purpose of this key is to sign the certificate
PrivateKey string `yaml:"private_key"`
// ResolverSk - hex-encoded short-term private key.
// ResolverSk is a hex-encoded short-term private key.
// This key is used to encrypt/decrypt DNS queries.
// If not set, we'll generate a new random ResolverSk and ResolverPk.
ResolverSk string `yaml:"resolver_secret"`
// ResolverPk - hex-encoded short-term public key corresponding to ResolverSk.
// ResolverPk is a hex-encoded short-term public key corresponding to ResolverSk.
// This key is used to encrypt/decrypt DNS queries.
ResolverPk string `yaml:"resolver_public"`
// EsVersion - crypto to use in this resolver
// EsVersion is the crypto to use in this resolver
EsVersion CryptoConstruction `yaml:"es_version"`
// CertificateTTL - time-to-live for the certificate that is generated using this ResolverConfig.
// CertificateTTL is the time-to-live value for the certificate that is
// generated using this ResolverConfig.
// If not set, we'll use 1 year by default.
CertificateTTL time.Duration `yaml:"certificate_ttl"`
}
// CreateCert - generates a signed Cert to be used by Server
// CreateCert generates a signed Cert to be used by Server
func (rc *ResolverConfig) CreateCert() (*Cert, error) {
log.Printf("Creating signed DNSCrypt certificate")
@ -98,7 +99,7 @@ func (rc *ResolverConfig) CreateCert() (*Cert, error) {
return cert, nil
}
// CreateStamp - generates a DNS stamp for this resolver
// CreateStamp generates a DNS stamp for this resolver
func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error) {
stamp := dnsstamps.ServerStamp{
ProviderName: rc.ProviderName,
@ -115,7 +116,7 @@ func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error
return stamp, nil
}
// GenerateResolverConfig - generates resolver configuration for a given provider name.
// GenerateResolverConfig generates resolver configuration for a given provider name.
// providerName is mandatory. If needed, "2.dnscrypt-cert." prefix is added to it.
// privateKey is optional. If not set, it will be generated automatically.
func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey) (ResolverConfig, error) {
@ -145,18 +146,18 @@ func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey)
return rc, nil
}
// HexEncodeKey - encodes a byte slice to a hex-encoded string.
// HexEncodeKey encodes a byte slice to a hex-encoded string.
func HexEncodeKey(b []byte) string {
return strings.ToUpper(hex.EncodeToString(b))
}
// HexDecodeKey - decodes a hex-encoded string with (optional) colons
// HexDecodeKey decodes a hex-encoded string with (optional) colons
// to a byte array.
func HexDecodeKey(str string) ([]byte, error) {
return hex.DecodeString(strings.ReplaceAll(str, ":", ""))
}
// generateRandomKeyPair - generates a random key-pair
// generateRandomKeyPair generates a random key-pair
func generateRandomKeyPair() (privateKey [keySize]byte, publicKey [keySize]byte) {
privateKey = [keySize]byte{}
publicKey = [keySize]byte{}

View file

@ -15,24 +15,24 @@ func TestHexEncodeKey(t *testing.T) {
func TestHexDecodeKey(t *testing.T) {
b, err := HexDecodeKey("01:02:03:04")
assert.Nil(t, err)
assert.NoError(t, err)
assert.True(t, bytes.Equal(b, []byte{1, 2, 3, 4}))
}
func TestGenerateResolverConfig(t *testing.T) {
rc, err := GenerateResolverConfig("example.org", nil)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, "2.dnscrypt-cert.example.org", rc.ProviderName)
assert.Equal(t, ed25519.PrivateKeySize*2, len(rc.PrivateKey))
assert.Equal(t, keySize*2, len(rc.ResolverSk))
assert.Equal(t, keySize*2, len(rc.ResolverPk))
cert, err := rc.CreateCert()
assert.Nil(t, err)
assert.NoError(t, err)
assert.True(t, cert.VerifyDate())
publicKey, err := HexDecodeKey(rc.PublicKey)
assert.Nil(t, err)
assert.NoError(t, err)
assert.True(t, cert.VerifySignature(publicKey))
}

2
go.mod
View file

@ -6,7 +6,7 @@ require (
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635
github.com/ameshkov/dnsstamps v1.0.1
github.com/jessevdk/go-flags v1.4.0
github.com/miekg/dns v1.1.29
github.com/miekg/dns v1.1.40
github.com/stretchr/testify v1.6.1
golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e

2
go.sum
View file

@ -14,6 +14,8 @@ github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGAR
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/miekg/dns v1.1.29 h1:xHBEhR+t5RzcFJjBLJlax2daXOrTYtr9z4WdKEfWFzg=
github.com/miekg/dns v1.1.29/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/miekg/dns v1.1.40 h1:pyyPFfGMnciYUk/mXpKkVmeMQjfXqt3FAJ2hy7tPiLA=
github.com/miekg/dns v1.1.40/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View file

@ -14,14 +14,14 @@ type Handler interface {
ServeDNS(rw ResponseWriter, r *dns.Msg) error
}
// ResponseWriter - interface that needs to be implemented for different protocols
// ResponseWriter is the interface that needs to be implemented for different protocols
type ResponseWriter interface {
LocalAddr() net.Addr // LocalAddr - local socket address
RemoteAddr() net.Addr // RemoteAddr - remote client socket address
WriteMsg(m *dns.Msg) error // WriteMsg - writes response message to the client
}
// DefaultHandler - default Handler implementation
// DefaultHandler is the default Handler implementation
// that is used by Server if custom handler is not configured
var DefaultHandler Handler = &defaultHandler{
udpClient: &dns.Client{
@ -41,7 +41,7 @@ type defaultHandler struct {
addr string
}
// ServeDNS - implements Handler interface
// ServeDNS implements Handler interface
func (h *defaultHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
// Google DNS
res, _, err := h.udpClient.Exchange(r, h.addr)

174
server.go
View file

@ -1,27 +1,167 @@
package dnscrypt
import (
"context"
"net"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// Server - a simple DNSCrypt server implementation
// default read timeout for all reads
const defaultReadTimeout = 2 * time.Second
// in case of TCP we only use defaultReadTimeout for the first read
// then we start using defaultTCPIdleTimeout
const defaultTCPIdleTimeout = 8 * time.Second
// helper struct that is used in several SetReadDeadline calls
var longTimeAgo = time.Unix(1, 0)
// ServerDNSCrypt is an interface for a DNSCrypt server
type ServerDNSCrypt interface {
// ServeTCP listens to TCP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener
// or call ServerDNSCrypt.Shutdown.
ServeTCP(l net.Listener) error
// ServeUDP listens to UDP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener
// or call ServerDNSCrypt.Shutdown.
ServeUDP(l *net.UDPConn) error
// Shutdown tries to gracefully shutdown the server. It waits until all
// connections are processed and only after that it leaves the method.
// If context deadline is specified, it will exit earlier
// or call ServerDNSCrypt.Shutdown.
Shutdown(ctx context.Context) error
}
// Server is a simple DNSCrypt server implementation
type Server struct {
// ProviderName - DNSCrypt provider name
// ProviderName is a DNSCrypt provider name
ProviderName string
// ResolverCert - contains resolver certificate.
// ResolverCert contains resolver certificate.
ResolverCert *Cert
// Handler to invoke. If nil, uses DefaultHandler.
Handler Handler
// make sure init is called only once
initOnce sync.Once
// Shutdown handling
// --
lock sync.RWMutex // protects access to all the fields below
started bool
wg sync.WaitGroup // active workers (servers)
tcpListeners map[net.Listener]struct{} // track active TCP listeners
udpListeners map[*net.UDPConn]struct{} // track active UDP listeners
tcpConns map[net.Conn]struct{} // track active connections
}
// serveDNS - serves DNS response
func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) {
// type check
var _ ServerDNSCrypt = &Server{}
// prepareShutdown - prepares the server to shutdown:
// unblocks reads from all connections related to this server
// marks the server as stopped
// if the server is not started, returns ErrServerNotStarted
func (s *Server) prepareShutdown() error {
s.lock.Lock()
defer s.lock.Unlock()
if !s.started {
log.Info("Server is not started")
return ErrServerNotStarted
}
s.started = false
// These listeners were passed to us from the outside so we cannot close
// them here - this is up to the calling code to do that. Instead of that,
// we call Set(Read)Deadline to unblock goroutines that are currently
// blocked on reading from those listeners.
// For tcpConns we would like to avoid closing them to be able to process
// queries before shutting everything down.
// Unblock reads for all active tcpConns
for conn := range s.tcpConns {
_ = conn.SetReadDeadline(longTimeAgo)
}
// Unblock reads for all active TCP listeners
for l := range s.tcpListeners {
switch v := l.(type) {
case *net.TCPListener:
_ = v.SetDeadline(longTimeAgo)
}
}
// Unblock reads for all active UDP listeners
for l := range s.udpListeners {
_ = l.SetReadDeadline(longTimeAgo)
}
return nil
}
// Shutdown tries to gracefully shutdown the server. It waits until all
// connections are processed and only after that it leaves the method.
// If context deadline is specified, it will exit earlier.
func (s *Server) Shutdown(ctx context.Context) error {
log.Info("Shutting down the DNSCrypt server")
err := s.prepareShutdown()
if err != nil {
return err
}
// Using this channel to wait until all goroutines finish their work
closed := make(chan struct{})
go func() {
s.wg.Wait()
log.Info("Serve goroutines finished their work")
close(closed)
}()
// Wait for either all goroutines finish their work
// Or for the context deadline
select {
case <-closed:
log.Info("DNSCrypt server has been stopped")
case <-ctx.Done():
log.Info("DNSCrypt server shutdown has timed out")
err = ctx.Err()
}
return err
}
// init initializes (lazily) Server properties on startup
// this method is called from Server.ServeTCP and Server.ServeUDP
func (s *Server) init() {
s.tcpConns = map[net.Conn]struct{}{}
s.udpListeners = map[*net.UDPConn]struct{}{}
s.tcpListeners = map[net.Listener]struct{}{}
}
// isStarted returns true if the server is processing queries right now
// it means that Server.ServeTCP and/or Server.ServeUDP have been called
func (s *Server) isStarted() bool {
s.lock.RLock()
started := s.started
s.lock.RUnlock()
return started
}
// serveDNS serves DNS response
func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) error {
if r == nil || len(r.Question) != 1 || r.Response {
log.Tracef("Invalid query: %v", r)
return
return ErrInvalidQuery
}
log.Tracef("Handling a DNS query: %s", r.Question[0].Name)
@ -39,9 +179,11 @@ func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) {
reply.SetRcode(r, dns.RcodeServerFailure)
_ = rw.WriteMsg(reply)
}
return nil
}
// encrypt - encrypts DNSCrypt response
// encrypt encrypts DNSCrypt response
func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) {
r := EncryptedResponse{
EsVersion: q.EsVersion,
@ -60,7 +202,7 @@ func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) {
return r.Encrypt(packet, sharedKey)
}
// decrypt - decrypts the incoming message and returns a DNS message to process
// decrypt decrypts the incoming message and returns a DNS message to process
func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) {
q := EncryptedQuery{
EsVersion: s.ResolverCert.EsVersion,
@ -82,7 +224,7 @@ func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) {
return r, q, nil
}
// handleHandshake - handles a TXT request that requests certificate data
// handleHandshake handles a TXT request that requests certificate data
func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) {
m := new(dns.Msg)
err := m.Unpack(b)
@ -120,7 +262,7 @@ func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) {
return reply.Pack()
}
// validate - checks if the Server config is properly set
// validate checks if the Server config is properly set
func (s *Server) validate() bool {
if s.ResolverCert == nil {
log.Error("ResolverCert must be set")
@ -139,3 +281,13 @@ func (s *Server) validate() bool {
return true
}
// getCertTXT serializes the cert TXT record that are to be sent to the client
func (s *Server) getCertTXT() (string, error) {
certBuf, err := s.ResolverCert.Serialize()
if err != nil {
return "", err
}
certTxt := packTxtString(certBuf)
return certTxt, nil
}

60
server_bench_test.go Normal file
View file

@ -0,0 +1,60 @@
package dnscrypt
import (
"fmt"
"net"
"testing"
"time"
"github.com/ameshkov/dnsstamps"
"github.com/stretchr/testify/require"
)
func BenchmarkServeUDP(b *testing.B) {
benchmarkServe(b, "udp")
}
func BenchmarkServeTCP(b *testing.B) {
benchmarkServe(b, "tcp")
}
func benchmarkServe(b *testing.B, network string) {
srv := newTestServer(b, &testHandler{})
b.Cleanup(func() {
err := srv.Close()
require.NoError(b, err)
})
client := &Client{
Timeout: 1 * time.Second,
Net: network,
}
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
if network == "tcp" {
serverAddr = fmt.Sprintf("127.0.0.1:%d", srv.TCPAddr().Port)
}
stamp := dnsstamps.ServerStamp{
ServerAddrStr: serverAddr,
ServerPk: srv.resolverPk,
ProviderName: srv.server.ProviderName,
Proto: dnsstamps.StampProtoTypeDNSCrypt,
}
ri, err := client.DialStamp(stamp)
require.NoError(b, err)
require.NotNil(b, ri)
conn, err := net.Dial(network, stamp.ServerAddrStr)
require.NoError(b, err)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
m := createTestMessage()
res, err := client.ExchangeConn(conn, m, ri)
require.NoError(b, err)
assertTestMessageResponse(b, res)
}
b.StopTimer()
}

View file

@ -2,13 +2,17 @@ package dnscrypt
import (
"bytes"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// TCPResponseWriter - ResponseWriter implementation for TCP
// TCPResponseWriter is the ResponseWriter implementation for TCP
type TCPResponseWriter struct {
tcpConn net.Conn
encrypt encryptionFunc
@ -19,17 +23,17 @@ type TCPResponseWriter struct {
// type check
var _ ResponseWriter = &TCPResponseWriter{}
// LocalAddr - server socket local address
// LocalAddr is the server socket local address
func (w *TCPResponseWriter) LocalAddr() net.Addr {
return w.tcpConn.LocalAddr()
}
// RemoteAddr - client's address
// RemoteAddr is the client's address
func (w *TCPResponseWriter) RemoteAddr() net.Addr {
return w.tcpConn.RemoteAddr()
}
// WriteMsg - writes DNS message to the client
// WriteMsg writes DNS message to the client
func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
m.Truncate(dnsSize("tcp", w.req))
@ -42,33 +46,44 @@ func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
return writePrefixed(res, w.tcpConn)
}
// ServeTCP - listens to TCP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener.
// ServeTCP listens to TCP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener
// or call Server.Shutdown.
func (s *Server) ServeTCP(l net.Listener) error {
// Check that server is properly configured
if !s.validate() {
return ErrServerConfig
}
// Serialize the cert right away and prepare it to be sent to the client
certBuf, err := s.ResolverCert.Serialize()
err := s.prepareServeTCP(l)
if err != nil {
return err
}
certTxt := packTxtString(certBuf)
log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr().String())
log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr())
for {
// Tracks TCP connection handling goroutines
tcpWg := &sync.WaitGroup{}
defer s.cleanUpTCP(tcpWg, l)
// Track active goroutine
s.wg.Add(1)
// Serialize the cert right away and prepare it to be sent to the client
certTxt, err := s.getCertTXT()
if err != nil {
return err
}
for s.isStarted() {
conn, err := l.Accept()
if err == nil {
go func() {
_ = s.handleTCPConnection(conn, certTxt)
_ = conn.Close()
}()
}
// Check the error code and exit loop if necessary
if err != nil {
if !s.isStarted() {
// Stopped gracefully
break
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Temporary() {
// Note that timeout errors will be here (i.e. hitting ReadDeadline)
continue
}
if isConnClosed(err) {
log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop")
} else {
@ -76,47 +91,129 @@ func (s *Server) ServeTCP(l net.Listener) error {
}
break
}
// If we got here, the connection is alive
s.lock.Lock()
// Track the connection to allow unblocking reads on shutdown.
s.tcpConns[conn] = struct{}{}
s.lock.Unlock()
tcpWg.Add(1)
go func() {
// Ignore error here, it is most probably a legit one
// if not, it's written to the debug log
_ = s.handleTCPConnection(conn, certTxt)
// Clean up
_ = conn.Close()
s.lock.Lock()
delete(s.tcpConns, conn)
s.lock.Unlock()
tcpWg.Done()
}()
}
return nil
}
// prepareServeTCP prepares the server and listener to serving DNSCrypt
func (s *Server) prepareServeTCP(l net.Listener) error {
// Check that server is properly configured
if !s.validate() {
return ErrServerConfig
}
// Protect shutdown-related fields
s.lock.Lock()
defer s.lock.Unlock()
s.initOnce.Do(s.init)
// Mark the server as started if needed
s.started = true
// Track an active TCP listener
s.tcpListeners[l] = struct{}{}
return nil
}
// cleanUpTCP waits until all TCP messages before cleaning up
func (s *Server) cleanUpTCP(tcpWg *sync.WaitGroup, l net.Listener) {
// Wait until all TCP connections are processed
tcpWg.Wait()
// Not using it anymore so can be removed from the active listeners
s.lock.Lock()
delete(s.tcpListeners, l)
s.lock.Unlock()
// The work is finished
s.wg.Done()
}
// handleTCPMsg handles a single TCP message. If this method returns error
// the connection will be closed
func (s *Server) handleTCPMsg(b []byte, conn net.Conn, certTxt string) error {
if len(b) < minDNSPacketSize {
// Ignore the packets that are too short
return ErrTooShort
}
// First of all, check for "ClientMagic" in the incoming query
if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
// If there's no ClientMagic in the packet, we assume this
// is a plain DNS query requesting the certificate data
reply, err := s.handleHandshake(b, certTxt)
if err != nil {
return fmt.Errorf("failed to process a plain DNS query: %w", err)
}
err = writePrefixed(reply, conn)
if err != nil {
return fmt.Errorf("failed to write a response: %w", err)
}
return nil
}
// If we got here, this is an encrypted DNSCrypt message
// We should decrypt it first to get the plain DNS query
m, q, err := s.decrypt(b)
if err != nil {
return fmt.Errorf("failed to decrypt incoming message: %w", err)
}
rw := &TCPResponseWriter{
tcpConn: conn,
encrypt: s.encrypt,
req: m,
query: q,
}
err = s.serveDNS(rw, m)
if err != nil {
return fmt.Errorf("failed to process a DNS query: %w", err)
}
return nil
}
// handleTCPConnection handles all queries that are coming to the
// specified TCP connection.
func (s *Server) handleTCPConnection(conn net.Conn, certTxt string) error {
for {
timeout := defaultReadTimeout
for s.isStarted() {
_ = conn.SetReadDeadline(time.Now().Add(timeout))
b, err := readPrefixed(conn)
if err != nil {
return err
}
if len(b) < minDNSPacketSize {
// Ignore the packets that are too short
return ErrTooShort
err = s.handleTCPMsg(b, conn, certTxt)
if err != nil {
log.Debug("failed to process DNS query: %v", err)
return err
}
if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
// This is an encrypted message, we should decrypt it
m, q, err := s.decrypt(b)
if err != nil {
log.Tracef("failed to decrypt incoming message: %v", err)
return err
}
rw := &TCPResponseWriter{
tcpConn: conn,
encrypt: s.encrypt,
req: m,
query: q,
}
s.serveDNS(rw, m)
} else {
// Most likely this a DNS message requesting the certificate
reply, err := s.handleHandshake(b, certTxt)
if err != nil {
log.Tracef("Failed to process a plain DNS query: %v", err)
return err
}
err = writePrefixed(reply, conn)
if err != nil {
return err
}
}
timeout = defaultTCPIdleTimeout
}
return nil
}

View file

@ -2,9 +2,11 @@ package dnscrypt
import (
"bytes"
"context"
"crypto/ed25519"
"fmt"
"net"
"runtime"
"testing"
"time"
@ -13,25 +15,52 @@ import (
"github.com/stretchr/testify/assert"
)
func TestServerUDPServeCert(t *testing.T) {
func TestServer_Shutdown(t *testing.T) {
n := runtime.GOMAXPROCS(1)
t.Cleanup(func() {
runtime.GOMAXPROCS(n)
})
srv := newTestServer(t, &testHandler{})
// Serve* methods are called in different goroutines
// give them at least a moment to actually start the server
time.Sleep(10 * time.Millisecond)
assert.NoError(t, srv.Close())
}
func TestServer_UDPServeCert(t *testing.T) {
testServerServeCert(t, "udp")
}
func TestServerTCPServeCert(t *testing.T) {
func TestServer_TCPServeCert(t *testing.T) {
testServerServeCert(t, "tcp")
}
func TestServerUDPRespondMessages(t *testing.T) {
func TestServer_UDPRespondMessages(t *testing.T) {
testServerRespondMessages(t, "udp")
}
func TestServerTCPRespondMessages(t *testing.T) {
func TestServer_TCPRespondMessages(t *testing.T) {
testServerRespondMessages(t, "tcp")
}
func TestServer_ReadTimeout(t *testing.T) {
srv := newTestServer(t, &testHandler{})
t.Cleanup(func() {
assert.NoError(t, srv.Close())
})
// Sleep for "defaultReadTimeout" before trying to shutdown the server
// The point is to make sure readTimeout is properly handled by
// the "Serve*" goroutines and they don't finish their work unexpectedly
time.Sleep(defaultReadTimeout)
testThisServerRespondMessages(t, "udp", srv)
testThisServerRespondMessages(t, "tcp", srv)
}
func testServerServeCert(t *testing.T, network string) {
srv := newTestServer(t, &testHandler{})
defer srv.Close()
t.Cleanup(func() {
assert.NoError(t, srv.Close())
})
client := &Client{
Net: network,
@ -50,7 +79,7 @@ func testServerServeCert(t *testing.T, network string) {
Proto: dnsstamps.StampProtoTypeDNSCrypt,
}
ri, err := client.DialStamp(stamp)
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotNil(t, ri)
assert.Equal(t, ri.ProviderName, srv.server.ProviderName)
@ -65,8 +94,13 @@ func testServerServeCert(t *testing.T, network string) {
func testServerRespondMessages(t *testing.T, network string) {
srv := newTestServer(t, &testHandler{})
defer srv.Close()
t.Cleanup(func() {
assert.NoError(t, srv.Close())
})
testThisServerRespondMessages(t, network, srv)
}
func testThisServerRespondMessages(t *testing.T, network string, srv *testServer) {
client := &Client{
Timeout: 1 * time.Second,
Net: network,
@ -84,16 +118,16 @@ func testServerRespondMessages(t *testing.T, network string) {
Proto: dnsstamps.StampProtoTypeDNSCrypt,
}
ri, err := client.DialStamp(stamp)
assert.Nil(t, err)
assert.NoError(t, err)
assert.NotNil(t, ri)
conn, err := net.Dial(network, stamp.ServerAddrStr)
assert.Nil(t, err)
assert.NoError(t, err)
for i := 0; i < 10; i++ {
m := createTestMessage()
res, err := client.ExchangeConn(conn, m, ri)
assert.Nil(t, err)
assert.NoError(t, err)
assertTestMessageResponse(t, res)
}
}
@ -114,16 +148,22 @@ func (s *testServer) UDPAddr() *net.UDPAddr {
return s.udpConn.LocalAddr().(*net.UDPAddr)
}
func (s *testServer) Close() {
func (s *testServer) Close() error {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
defer cancel()
err := s.server.Shutdown(ctx)
_ = s.udpConn.Close()
_ = s.tcpListen.Close()
return err
}
func newTestServer(t *testing.T, handler Handler) *testServer {
func newTestServer(t assert.TestingT, handler Handler) *testServer {
rc, err := GenerateResolverConfig("example.org", nil)
assert.Nil(t, err)
assert.NoError(t, err)
cert, err := rc.CreateCert()
assert.Nil(t, err)
assert.NoError(t, err)
s := &Server{
ProviderName: rc.ProviderName,
@ -132,7 +172,7 @@ func newTestServer(t *testing.T, handler Handler) *testServer {
}
privateKey, err := HexDecodeKey(rc.PrivateKey)
assert.Nil(t, err)
assert.NoError(t, err)
publicKey := ed25519.PrivateKey(privateKey).Public().(ed25519.PublicKey)
srv := &testServer{
server: s,
@ -140,9 +180,9 @@ func newTestServer(t *testing.T, handler Handler) *testServer {
}
srv.tcpListen, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 0})
assert.Nil(t, err)
assert.NoError(t, err)
srv.udpConn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
assert.Nil(t, err)
assert.NoError(t, err)
go s.ServeUDP(srv.udpConn)
go s.ServeTCP(srv.tcpListen)

View file

@ -2,38 +2,43 @@ package dnscrypt
import (
"bytes"
"errors"
"net"
"runtime"
"sync"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error)
// UDPResponseWriter - ResponseWriter implementation for UDP
// UDPResponseWriter is the ResponseWriter implementation for UDP
type UDPResponseWriter struct {
udpConn *net.UDPConn // UDP connection
remoteAddr *net.UDPAddr // Remote peer address
localIP net.IP // Local IP (that was used to accept the remote connection)
encrypt encryptionFunc // DNSCRypt encryption function
req *dns.Msg // DNS query that was processed
query EncryptedQuery // DNSCrypt query properties
udpConn *net.UDPConn // UDP connection
sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP)
encrypt encryptionFunc // DNSCRypt encryption function
req *dns.Msg // DNS query that was processed
query EncryptedQuery // DNSCrypt query properties
}
// type check
var _ ResponseWriter = &UDPResponseWriter{}
// LocalAddr - server socket local address
// LocalAddr is the server socket local address
func (w *UDPResponseWriter) LocalAddr() net.Addr {
return w.udpConn.LocalAddr()
}
// RemoteAddr - client's address
// RemoteAddr is the client's address
func (w *UDPResponseWriter) RemoteAddr() net.Addr {
return w.remoteAddr
return w.udpConn.RemoteAddr()
}
// WriteMsg - writes DNS message to the client
// WriteMsg writes DNS message to the client
func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error {
m.Truncate(dnsSize("udp", w.req))
@ -42,83 +47,176 @@ func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error {
log.Tracef("Failed to encrypt the DNS query: %v", err)
return err
}
_, err = dns.WriteToSessionUDP(w.udpConn, res, w.sess)
return err
}
// ServeUDP listens to UDP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener
// or call Server.Shutdown.
func (s *Server) ServeUDP(l *net.UDPConn) error {
err := s.prepareServeUDP(l)
if err != nil {
return err
}
// Tracks UDP handling goroutines
udpWg := &sync.WaitGroup{}
defer s.cleanUpUDP(udpWg, l)
// Track active goroutine
s.wg.Add(1)
log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr())
// Serialize the cert right away and prepare it to be sent to the client
certTxt, err := s.getCertTXT()
if err != nil {
return err
}
for s.isStarted() {
b, sess, err := s.readUDPMsg(l)
// Check the error code and exit loop if necessary
if err != nil {
if !s.isStarted() {
// Stopped gracefully
return nil
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Temporary() {
// Note that timeout errors will be here (i.e. hitting ReadDeadline)
continue
}
if isConnClosed(err) {
log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop")
} else {
log.Info("got error when reading from UDP listen: %s", err)
}
return err
}
if len(b) < minDNSPacketSize {
// Ignore the packets that are too short
continue
}
udpWg.Add(1)
go func() {
s.serveUDPMsg(b, certTxt, sess, l)
udpWg.Done()
}()
}
_, _ = udpWrite(res, w.udpConn, w.remoteAddr, w.localIP)
return nil
}
// ServeUDP - listens to UDP connections, queries are then processed by Server.Handler.
// It blocks the calling goroutine and to stop it you need to close the listener.
func (s *Server) ServeUDP(l *net.UDPConn) error {
// prepareServeUDP prepares the server and listener to serving DNSCrypt
func (s *Server) prepareServeUDP(l *net.UDPConn) error {
// Check that server is properly configured
if !s.validate() {
return ErrServerConfig
}
// set UDP options to allow receiving OOB data
err := udpSetOptions(l)
err := setUDPSocketOptions(l)
if err != nil {
return err
}
// Buffer to read incoming messages
b := make([]byte, dns.MaxMsgSize)
// Protect shutdown-related fields
s.lock.Lock()
defer s.lock.Unlock()
s.initOnce.Do(s.init)
// Serialize the cert right away and prepare it to be sent to the client
certBuf, err := s.ResolverCert.Serialize()
// Mark the server as started.
// Note that we don't check if it was started before as
// Serve* methods can be called multiple times.
s.started = true
// Track an active UDP listener
s.udpListeners[l] = struct{}{}
return err
}
// cleanUpUDP waits until all UDP messages before cleaning up
func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) {
// Wait until UDP messages are processed
udpWg.Wait()
// Not using it anymore so can be removed from the active listeners
s.lock.Lock()
delete(s.udpListeners, l)
s.lock.Unlock()
// The work is finished
s.wg.Done()
}
// readUDPMsg reads incoming UDP message
func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) {
_ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout))
b := make([]byte, dns.MinMsgSize)
n, sess, err := dns.ReadFromSessionUDP(l, b)
if err != nil {
return err
return nil, nil, err
}
certTxt := packTxtString(certBuf)
// Init oobSize - it will be used later when reading and writing UDP messages
oobSize := udpGetOOBSize()
log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr().String())
for {
n, localIP, addr, err := udpRead(l, b, oobSize)
if n < minDNSPacketSize {
// Ignore the packets that are too short
continue
}
if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
// This is an encrypted message, we should decrypt it
m, q, err := s.decrypt(b[:n])
if err == nil {
rw := &UDPResponseWriter{
udpConn: l,
remoteAddr: addr,
localIP: localIP,
encrypt: s.encrypt,
req: m,
query: q,
}
go s.serveDNS(rw, m)
} else {
log.Tracef("Failed to decrypt incoming message len=%d: %v", n, err)
}
} else {
// Most likely this a DNS message requesting the certificate
reply, err := s.handleHandshake(b, certTxt)
if err != nil {
log.Tracef("Failed to process a plain DNS query: %v", err)
}
if err == nil {
_, _ = l.WriteTo(reply, addr)
}
}
return b[:n], sess, err
}
// serveUDPMsg handles incoming DNS message
func (s *Server) serveUDPMsg(b []byte, certTxt string, sess *dns.SessionUDP, l *net.UDPConn) {
// First of all, check for "ClientMagic" in the incoming query
if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
// If there's no ClientMagic in the packet, we assume this
// is a plain DNS query requesting the certificate data
reply, err := s.handleHandshake(b, certTxt)
if err != nil {
if isConnClosed(err) {
log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop")
} else {
log.Info("got error when reading from UDP listen: %s", err)
}
break
log.Tracef("failed to process a plain DNS query: %v", err)
}
if err == nil {
// Ignore errors, we don't care and can't handle them anyway
_, _ = dns.WriteToSessionUDP(l, reply, sess)
}
return
}
// If we got here, this is an encrypted DNSCrypt message
// We should decrypt it first to get the plain DNS query
m, q, err := s.decrypt(b)
if err == nil {
rw := &UDPResponseWriter{
udpConn: l,
sess: sess,
encrypt: s.encrypt,
req: m,
query: q,
}
err = s.serveDNS(rw, m)
if err != nil {
log.Tracef("failed to process a DNS query: %v", err)
}
} else {
log.Tracef("failed to decrypt incoming message len=%d: %v", len(b), err)
}
}
// setUDPSocketOptions method is necessary to be able to use dns.ReadFromSessionUDP / dns.WriteToSessionUDP
func setUDPSocketOptions(conn *net.UDPConn) error {
if runtime.GOOS == "windows" {
return nil
}
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
// Try enabling receiving of ECN and packet info for both IP versions.
// We expect at least one of those syscalls to succeed.
err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
if err6 != nil && err4 != nil {
return err4
}
return nil
}

1
testdata/dnscrypt-cert.opendns.txt vendored Normal file
View file

@ -0,0 +1 @@
DNSC\000\001\000\000\200\226E:H\156\203%\134\218\127]\168\239\027u\011$\191\008\239\176F\133\017\171\161\219\154\142i\164\010\239\017f\168dS\210f\197\194\169\171w\2499\1891\155<\130\218@/\155\023v\153#d\024\004\136\180\228K5\233d\180\144\189\218\186\232%\162K\004\021\160\139\225\157}\219\135\163<\215~\223\142/qc78aWoo]\221\184`]\221\184`_\190\235\224

View file

@ -1,83 +0,0 @@
// +build aix darwin dragonfly linux netbsd openbsd solaris freebsd
package dnscrypt
import (
"fmt"
"net"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// udpGetOOBSize - get max. size of received OOB data
// It will then be used in the ReadMsgUDP function
func udpGetOOBSize() int {
oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)
oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)
if len(oob4) > len(oob6) {
return len(oob4)
}
return len(oob6)
}
// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data
func udpSetOptions(c *net.UDPConn) error {
err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
if err6 != nil && err4 != nil {
return fmt.Errorf("failed to call SetControlMessage: ipv4: %v ipv6: %v", err4, err6)
}
return nil
}
// udpRead - receive payload and OOB data from the UDP socket
func udpRead(c *net.UDPConn, buf []byte, udpOOBSize int) (int, net.IP, *net.UDPAddr, error) {
var oobn int
oob := make([]byte, udpOOBSize)
var err error
var n int
var remoteAddr *net.UDPAddr
n, oobn, _, remoteAddr, err = c.ReadMsgUDP(buf, oob)
if err != nil {
return -1, nil, nil, err
}
localIP := udpGetDstFromOOB(oob[:oobn])
return n, localIP, remoteAddr, nil
}
// udpWrite - writes to the UDP socket and sets local IP to OOB data
func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, localIP net.IP) (int, error) {
n, _, err := conn.WriteMsgUDP(bytes, udpMakeOOBWithSrc(localIP), remoteAddr)
return n, err
}
// udpGetDstFromOOB - get destination IP from OOB data
func udpGetDstFromOOB(oob []byte) net.IP {
cm6 := &ipv6.ControlMessage{}
if cm6.Parse(oob) == nil && cm6.Dst != nil {
return cm6.Dst
}
cm4 := &ipv4.ControlMessage{}
if cm4.Parse(oob) == nil && cm4.Dst != nil {
return cm4.Dst
}
return nil
}
// udpMakeOOBWithSrc - make OOB data with a specified source IP
func udpMakeOOBWithSrc(ip net.IP) []byte {
if ip.To4() == nil {
cm := &ipv6.ControlMessage{}
cm.Src = ip
return cm.Marshal()
}
cm := &ipv4.ControlMessage{}
cm.Src = ip
return cm.Marshal()
}

View file

@ -1,30 +0,0 @@
package dnscrypt
import "net"
// udpGetOOBSize - get max. size of received OOB data
// Does nothing on Windows
func udpGetOOBSize() int {
return 0
}
// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data
// Does nothing on Windows
func udpSetOptions(c *net.UDPConn) error {
return nil
}
// udpRead - receive payload from the UDP socket
func udpRead(c *net.UDPConn, buf []byte, _ int) (int, net.IP, *net.UDPAddr, error) {
n, addr, err := c.ReadFrom(buf)
var udpAddr *net.UDPAddr
if addr != nil {
udpAddr = addr.(*net.UDPAddr)
}
return n, nil, udpAddr, err
}
// udpWrite - writes to the UDP socket
func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ net.IP) (int, error) {
return conn.WriteTo(bytes, remoteAddr)
}