mirror of
https://github.com/SamTherapy/dnscrypt.git
synced 2024-12-21 16:50:42 +00:00
Graceful shutdown of the DNSCrypt server (#6)
Graceful shutdown of the DNSCrypt server This PR implements Server.Shutdown(ctx context.Context) method that allows to shut down the DNSCrypt server gracefully. Some additional changes that were inadvertently made while doing that: 1. Added benchmark tests 2. Started using dns.ReadFromSessionUDP / dns.WriteToSessionUDP instead of implementing it by ourselves 3. Generally improved tests 4. Added depguard 5. Improved comments overall in the code
This commit is contained in:
parent
d0ae1d198d
commit
5ddb58f703
28 changed files with 744 additions and 384 deletions
|
@ -20,6 +20,12 @@ linters-settings:
|
|||
min-complexity: 20
|
||||
lll:
|
||||
line-length: 200
|
||||
depguard:
|
||||
list-type: blacklist
|
||||
include-go-root: false
|
||||
packages:
|
||||
- golang.org/x/net/context # we use context
|
||||
- log # we use github.com/AdguardTeam/golibs/log
|
||||
|
||||
linters:
|
||||
enable:
|
||||
|
@ -41,6 +47,7 @@ linters:
|
|||
- misspell
|
||||
- stylecheck
|
||||
- unconvert
|
||||
- depguard
|
||||
disable-all: true
|
||||
fast: true
|
||||
|
||||
|
|
30
cert.go
30
cert.go
|
@ -8,10 +8,10 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// Cert - DNSCrypt server certificate
|
||||
// Cert is a DNSCrypt server certificate
|
||||
// See ResolverConfig for more info on how to create one
|
||||
type Cert struct {
|
||||
// Serial - a 4 byte serial number in big-endian format. If more than
|
||||
// Serial is a 4 byte serial number in big-endian format. If more than
|
||||
// one certificates are valid, the client must prefer the certificate
|
||||
// with a higher serial number.
|
||||
Serial uint32
|
||||
|
@ -22,36 +22,36 @@ type Cert struct {
|
|||
// For X25519-XChacha20Poly1305, <es-version> must be 0x00 0x02.
|
||||
EsVersion CryptoConstruction
|
||||
|
||||
// Signature - a 64-byte signature of (<resolver-pk> <client-magic>
|
||||
// Signature is a 64-byte signature of (<resolver-pk> <client-magic>
|
||||
// <serial> <ts-start> <ts-end> <extensions>) using the Ed25519 algorithm and the
|
||||
// provider secret key. Ed25519 must be used in this version of the
|
||||
// protocol.
|
||||
Signature [ed25519.SignatureSize]byte
|
||||
|
||||
// ResolverPk - the resolver's short-term public key, which is 32 bytes when using X25519.
|
||||
// ResolverPk is the resolver's short-term public key, which is 32 bytes when using X25519.
|
||||
// This key is used to encrypt/decrypt DNS queries
|
||||
ResolverPk [keySize]byte
|
||||
|
||||
// ResolverSk - the resolver's short-term private key, which is 32 bytes when using X25519.
|
||||
// ResolverSk is the resolver's short-term private key, which is 32 bytes when using X25519.
|
||||
// Note that it's only used in the server implementation and never serialized/deserialized.
|
||||
// This key is used to encrypt/decrypt DNS queries
|
||||
ResolverSk [keySize]byte
|
||||
|
||||
// ClientMagic - the first 8 bytes of a client query that is to be built
|
||||
// ClientMagic is the first 8 bytes of a client query that is to be built
|
||||
// using the information from this certificate. It may be a truncated
|
||||
// public key. Two valid certificates cannot share the same <client-magic>.
|
||||
ClientMagic [clientMagicSize]byte
|
||||
|
||||
// NotAfter - the date the certificate is valid from, as a big-endian
|
||||
// NotAfter is the date the certificate is valid from, as a big-endian
|
||||
// 4-byte unsigned Unix timestamp.
|
||||
NotBefore uint32
|
||||
|
||||
// NotAfter - the date the certificate is valid until (inclusive), as a
|
||||
// NotAfter is the date the certificate is valid until (inclusive), as a
|
||||
// big-endian 4-byte unsigned Unix timestamp.
|
||||
NotAfter uint32
|
||||
}
|
||||
|
||||
// Serialize - serializes the cert to bytes
|
||||
// Serialize serializes the cert to bytes
|
||||
// <cert> ::= <cert-magic> <es-version> <protocol-minor-version> <signature>
|
||||
// <resolver-pk> <client-magic> <serial> <ts-start> <ts-end>
|
||||
// <extensions>
|
||||
|
@ -86,7 +86,7 @@ func (c *Cert) Serialize() ([]byte, error) {
|
|||
return b, nil
|
||||
}
|
||||
|
||||
// Deserialize - deserializes certificate from a byte array
|
||||
// Deserialize deserializes certificate from a byte array
|
||||
// <cert> ::= <cert-magic> <es-version> <protocol-minor-version> <signature>
|
||||
// <resolver-pk> <client-magic> <serial> <ts-start> <ts-end>
|
||||
// <extensions>
|
||||
|
@ -127,7 +127,7 @@ func (c *Cert) Deserialize(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// VerifyDate - checks that cert is valid at this moment
|
||||
// VerifyDate checks that the cert is valid at this moment
|
||||
func (c *Cert) VerifyDate() bool {
|
||||
if c.NotBefore >= c.NotAfter {
|
||||
return false
|
||||
|
@ -139,14 +139,14 @@ func (c *Cert) VerifyDate() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// VerifySignature - checks if the cert is properly signed with the specified signature
|
||||
// VerifySignature checks if the cert is properly signed with the specified signature
|
||||
func (c *Cert) VerifySignature(publicKey ed25519.PublicKey) bool {
|
||||
b := make([]byte, 52)
|
||||
c.writeSigned(b)
|
||||
return ed25519.Verify(publicKey, b, c.Signature[:])
|
||||
}
|
||||
|
||||
// Sign - creates cert.Signature
|
||||
// Sign creates cert.Signature
|
||||
func (c *Cert) Sign(privateKey ed25519.PrivateKey) {
|
||||
b := make([]byte, 52)
|
||||
c.writeSigned(b)
|
||||
|
@ -154,14 +154,14 @@ func (c *Cert) Sign(privateKey ed25519.PrivateKey) {
|
|||
copy(c.Signature[:64], signature[:64])
|
||||
}
|
||||
|
||||
// String - Cert's string representation
|
||||
// String Cert's string representation
|
||||
func (c *Cert) String() string {
|
||||
return fmt.Sprintf("Certificate Serial=%d NotBefore=%s NotAfter=%s EsVersion=%s",
|
||||
c.Serial, time.Unix(int64(c.NotBefore), 0).String(),
|
||||
time.Unix(int64(c.NotAfter), 0).String(), c.EsVersion.String())
|
||||
}
|
||||
|
||||
// writeSigned - writes (<resolver-pk> <client-magic> <serial> <ts-start> <ts-end> <extensions>)
|
||||
// writeSigned writes (<resolver-pk> <client-magic> <serial> <ts-start> <ts-end> <extensions>)
|
||||
func (c *Cert) writeSigned(dst []byte) {
|
||||
// <resolver-pk>
|
||||
copy(dst[:32], c.ResolverPk[:keySize])
|
||||
|
|
16
cert_test.go
16
cert_test.go
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -21,13 +22,13 @@ func TestCertSerialize(t *testing.T) {
|
|||
|
||||
// serialize
|
||||
b, err := cert.Serialize()
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 124, len(b))
|
||||
|
||||
// check that we can deserialize it
|
||||
cert2 := Cert{}
|
||||
err = cert2.Deserialize(b)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cert.Serial, cert2.Serial)
|
||||
assert.Equal(t, cert.NotBefore, cert2.NotBefore)
|
||||
assert.Equal(t, cert.NotAfter, cert2.NotAfter)
|
||||
|
@ -39,12 +40,15 @@ func TestCertSerialize(t *testing.T) {
|
|||
|
||||
func TestCertDeserialize(t *testing.T) {
|
||||
// dig -t txt 2.dnscrypt-cert.opendns.com. -p 443 @208.67.220.220
|
||||
b, err := unpackTxtString("DNSC\\000\\001\\000\\000\\200\\226E:H\\156\\203%\\134\\218\\127]\\168\\239\\027u\\011$\\191\\008\\239\\176F\\133\\017\\171\\161\\219\\154\\142i\\164\\010\\239\\017f\\168dS\\210f\\197\\194\\169\\171w\\2499\\1891\\155<\\130\\218@/\\155\\023v\\153#d\\024\\004\\136\\180\\228K5\\233d\\180\\144\\189\\218\\186\\232%\\162K\\004\\021\\160\\139\\225\\157}\\219\\135\\163<\\215~\\223\\142/qc78aWoo]\\221\\184`]\\221\\184`_\\190\\235\\224")
|
||||
assert.Nil(t, err)
|
||||
certBytes, err := ioutil.ReadFile("testdata/dnscrypt-cert.opendns.txt")
|
||||
assert.NoError(t, err)
|
||||
|
||||
b, err := unpackTxtString(string(certBytes))
|
||||
assert.NoError(t, err)
|
||||
|
||||
cert := &Cert{}
|
||||
err = cert.Deserialize(b)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, uint32(1574811744), cert.Serial)
|
||||
assert.Equal(t, XSalsa20Poly1305, cert.EsVersion)
|
||||
assert.Equal(t, uint32(1574811744), cert.NotBefore)
|
||||
|
@ -69,7 +73,7 @@ func generateValidCert(t *testing.T) (*Cert, ed25519.PublicKey, ed25519.PrivateK
|
|||
|
||||
// generate private key
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// sign the data
|
||||
cert.Sign(privateKey)
|
||||
|
|
12
client.go
12
client.go
|
@ -12,7 +12,7 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Client - DNSCrypt resolver client
|
||||
// Client is a DNSCrypt resolver client
|
||||
type Client struct {
|
||||
Net string // protocol (can be "udp" or "tcp", by default - "udp")
|
||||
Timeout time.Duration // read/write timeout
|
||||
|
@ -124,7 +124,7 @@ func (c *Client) ExchangeConn(conn net.Conn, m *dns.Msg, resolverInfo *ResolverI
|
|||
return res, nil
|
||||
}
|
||||
|
||||
// writeQuery - writes query to the network connection
|
||||
// writeQuery writes query to the network connection
|
||||
// depending on the protocol we may write a 2-byte prefix or not
|
||||
func (c *Client) writeQuery(conn net.Conn, query []byte) error {
|
||||
var err error
|
||||
|
@ -145,7 +145,7 @@ func (c *Client) writeQuery(conn net.Conn, query []byte) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// readResponse - reads response from the network connection
|
||||
// readResponse reads response from the network connection
|
||||
// depending on the protocol, we may read a 2-byte prefix or not
|
||||
func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
|
||||
if c.Timeout > 0 {
|
||||
|
@ -171,7 +171,7 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) {
|
|||
return readPrefixed(conn)
|
||||
}
|
||||
|
||||
// encrypt - encrypts a DNS message using shared key from the resolver info
|
||||
// encrypt encrypts a DNS message using shared key from the resolver info
|
||||
func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) {
|
||||
q := EncryptedQuery{
|
||||
EsVersion: resolverInfo.ResolverCert.EsVersion,
|
||||
|
@ -185,7 +185,7 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error)
|
|||
return q.Encrypt(query, resolverInfo.SharedKey)
|
||||
}
|
||||
|
||||
// decrypts - decrypts a DNS message using shared key from the resolver info
|
||||
// decrypts decrypts a DNS message using a shared key from the resolver info
|
||||
func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) {
|
||||
dr := EncryptedResponse{
|
||||
EsVersion: resolverInfo.ResolverCert.EsVersion,
|
||||
|
@ -203,7 +203,7 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error)
|
|||
return res, nil
|
||||
}
|
||||
|
||||
// fetchCert - loads DNSCrypt cert from the specified server
|
||||
// fetchCert loads DNSCrypt cert from the specified server
|
||||
func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) {
|
||||
providerName := stamp.ProviderName
|
||||
if !strings.HasSuffix(providerName, ".") {
|
||||
|
|
|
@ -63,7 +63,7 @@ func TestTimeoutOnDialExchange(t *testing.T) {
|
|||
client := Client{Timeout: 300 * time.Millisecond}
|
||||
|
||||
serverInfo, err := client.Dial(stampStr)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Point it to an IP where there's no DNSCrypt server
|
||||
serverInfo.ServerAddress = "8.8.8.8:5443"
|
||||
|
@ -105,12 +105,15 @@ func TestFetchCertPublicResolvers(t *testing.T) {
|
|||
|
||||
for _, test := range stamps {
|
||||
stamp, err := dnsstamps.NewServerStampFromString(test.stampStr)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run(stamp.ProviderName, func(t *testing.T) {
|
||||
c := &Client{Net: "udp"}
|
||||
c := &Client{
|
||||
Net: "udp",
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
resolverInfo, err := c.DialStamp(stamp)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resolverInfo)
|
||||
assert.True(t, resolverInfo.ResolverCert.VerifyDate())
|
||||
assert.True(t, resolverInfo.ResolverCert.VerifySignature(stamp.ServerPk))
|
||||
|
@ -146,7 +149,7 @@ func TestExchangePublicResolvers(t *testing.T) {
|
|||
|
||||
for _, test := range stamps {
|
||||
stamp, err := dnsstamps.NewServerStampFromString(test.stampStr)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run(stamp.ProviderName, func(t *testing.T) {
|
||||
checkDNSCryptServer(t, test.stampStr, "udp")
|
||||
|
@ -158,12 +161,12 @@ func TestExchangePublicResolvers(t *testing.T) {
|
|||
func checkDNSCryptServer(t *testing.T, stampStr string, network string) {
|
||||
client := Client{Net: network, Timeout: 10 * time.Second}
|
||||
resolverInfo, err := client.Dial(stampStr)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := createTestMessage()
|
||||
|
||||
reply, err := client.Exchange(req, resolverInfo)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assertTestMessageResponse(t, reply)
|
||||
}
|
||||
|
||||
|
@ -177,7 +180,7 @@ func createTestMessage() *dns.Msg {
|
|||
return &req
|
||||
}
|
||||
|
||||
func assertTestMessageResponse(t *testing.T, reply *dns.Msg) {
|
||||
func assertTestMessageResponse(t assert.TestingT, reply *dns.Msg) {
|
||||
assert.NotNil(t, reply)
|
||||
assert.Equal(t, 1, len(reply.Answer))
|
||||
a, ok := reply.Answer[0].(*dns.A)
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConvertWrapperArgs - "convert-dnscrypt-wrapper" command arguments
|
||||
// ConvertWrapperArgs is the "convert-dnscrypt-wrapper" command arguments structure
|
||||
type ConvertWrapperArgs struct {
|
||||
PrivateKeyFile string `short:"p" long:"private-key" description:"Path to the DNSCrypt resolver private key file that is used for signing certificates. Param is required." required:"true"`
|
||||
ResolverSkFile string `short:"r" long:"resolver-secret" description:"Path to the Short-term privacy key file for encrypting/decrypting DNS queries. If not specified, resolver_secret and resolver_public will be randomly generated."`
|
||||
|
@ -21,7 +21,7 @@ type ConvertWrapperArgs struct {
|
|||
CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"`
|
||||
}
|
||||
|
||||
// convertWrapper - generates DNSCrypt configuration from both dnscrypt and server private keys
|
||||
// convertWrapper generates DNSCrypt configuration from both dnscrypt and server private keys
|
||||
func convertWrapper(args ConvertWrapperArgs) {
|
||||
|
||||
log.Info("Generating configuration for %s", args.ProviderName)
|
||||
|
@ -71,7 +71,7 @@ func convertWrapper(args ConvertWrapperArgs) {
|
|||
}
|
||||
}
|
||||
|
||||
// validateRc - verifies that the certificate is correctly
|
||||
// validateRc verifies that the certificate is correctly
|
||||
// created and validated for this resolver config. if rc valid returns nil.
|
||||
func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error {
|
||||
cert, err := rc.CreateCert()
|
||||
|
@ -90,7 +90,7 @@ func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// getResolverPk - calculates public key from private key
|
||||
// getResolverPk calculates public key from private key
|
||||
func getResolverPk(private ed25519.PrivateKey) ed25519.PublicKey {
|
||||
resolverSk := [32]byte{}
|
||||
resolverPk := [32]byte{}
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// GenerateArgs - "generate" command arguments
|
||||
// GenerateArgs is the "generate" command arguments structure
|
||||
type GenerateArgs struct {
|
||||
ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt provider name. Param is required." required:"true"`
|
||||
Out string `short:"o" long:"out" description:"Path to the resulting config file. Param is required." required:"true"`
|
||||
|
@ -16,7 +16,7 @@ type GenerateArgs struct {
|
|||
CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"`
|
||||
}
|
||||
|
||||
// generate - generates DNSCrypt server configuration
|
||||
// generate generates a DNSCrypt server configuration
|
||||
func generate(args GenerateArgs) {
|
||||
log.Info("Generating configuration for %s", args.ProviderName)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// LookupStampArgs - "lookup-stamp" command arguments
|
||||
// LookupStampArgs is the "lookup-stamp" command arguments structure
|
||||
type LookupStampArgs struct {
|
||||
Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"`
|
||||
Stamp string `short:"s" long:"stamp" description:"DNSCrypt resolver stamp. Param is required." required:"true"`
|
||||
|
@ -20,7 +20,7 @@ type LookupStampArgs struct {
|
|||
Type string `short:"t" long:"type" description:"DNS query type" default:"A"`
|
||||
}
|
||||
|
||||
// LookupArgs - "lookup" command arguments
|
||||
// LookupArgs is the "lookup" command arguments structure
|
||||
type LookupArgs struct {
|
||||
Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"`
|
||||
ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt resolver provider name. Param is required." required:"true"`
|
||||
|
@ -30,7 +30,7 @@ type LookupArgs struct {
|
|||
Type string `short:"t" long:"type" description:"DNS query type" default:"A"`
|
||||
}
|
||||
|
||||
// LookupResult - lookup result that contains the cert info and the query response
|
||||
// LookupResult is the lookup result that contains the cert info and the query response
|
||||
type LookupResult struct {
|
||||
Certificate struct {
|
||||
Serial uint32 `json:"serial"`
|
||||
|
@ -42,7 +42,7 @@ type LookupResult struct {
|
|||
Reply *dns.Msg `json:"reply"`
|
||||
}
|
||||
|
||||
// lookup - performs a DNS lookup, prints DNSCrypt info and lookup results
|
||||
// lookup performs a DNS lookup, prints DNSCrypt info and lookup results
|
||||
func lookup(args LookupArgs) {
|
||||
serverPk, err := dnscrypt.HexDecodeKey(args.PublicKey)
|
||||
if err != nil {
|
||||
|
@ -64,7 +64,7 @@ func lookup(args LookupArgs) {
|
|||
})
|
||||
}
|
||||
|
||||
// lookupStamp - performs a DNS lookup, prints DNSCrypt cert info and lookup results
|
||||
// lookupStamp performs a DNS lookup, prints DNSCrypt cert info and lookup results
|
||||
func lookupStamp(args LookupStampArgs) {
|
||||
c := &dnscrypt.Client{
|
||||
Net: args.Network,
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
|
||||
goFlags "github.com/jessevdk/go-flags"
|
||||
)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import (
|
|||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ServerArgs - "server" command arguments
|
||||
// ServerArgs is the "server" command arguments
|
||||
type ServerArgs struct {
|
||||
Config string `short:"c" long:"config" description:"Path to the DNSCrypt configuration file. Param is required." required:"true"`
|
||||
Forward string `short:"f" long:"forward" description:"Forwards DNS queries to the specified address" default:"94.140.14.140:53"`
|
||||
|
@ -21,7 +21,7 @@ type ServerArgs struct {
|
|||
ListenPorts []int `short:"p" long:"port" description:"Listening ports" default:"443"`
|
||||
}
|
||||
|
||||
// server - runs a DNSCrypt server
|
||||
// server runs a DNSCrypt server
|
||||
func server(args ServerArgs) {
|
||||
log.Info("Starting DNSCrypt server")
|
||||
|
||||
|
@ -72,7 +72,7 @@ func server(args ServerArgs) {
|
|||
}
|
||||
}
|
||||
|
||||
// createListeners - creates listeners for our server
|
||||
// createListeners creates listeners for our server
|
||||
func createListeners(args ServerArgs) (tcp []net.Listener, udp []*net.UDPConn) {
|
||||
for _, addr := range args.ListenAddrs {
|
||||
ip := net.ParseIP(addr)
|
||||
|
@ -101,7 +101,10 @@ type forwardHandler struct {
|
|||
addr string
|
||||
}
|
||||
|
||||
// ServeDNS - implements Handler interface
|
||||
// type check
|
||||
var _ dnscrypt.Handler = &forwardHandler{}
|
||||
|
||||
// ServeDNS implements Handler interface
|
||||
func (f *forwardHandler) ServeDNS(rw dnscrypt.ResponseWriter, r *dns.Msg) error {
|
||||
res, err := dns.Exchange(r, f.addr)
|
||||
if err != nil {
|
||||
|
|
78
constants.go
78
constants.go
|
@ -1,52 +1,58 @@
|
|||
package dnscrypt
|
||||
|
||||
import "errors"
|
||||
// Error represents a dnscrypt error.
|
||||
type Error string
|
||||
|
||||
var (
|
||||
// ErrTooShort - DNS query is shorter than possible
|
||||
ErrTooShort = errors.New("DNSCrypt message is too short")
|
||||
func (e Error) Error() string { return "dnscrypt: " + string(e) }
|
||||
|
||||
// ErrQueryTooLarge - DNS query is larger than max allowed size
|
||||
ErrQueryTooLarge = errors.New("DNSCrypt query is too large")
|
||||
const (
|
||||
// ErrTooShort means that the DNS query is shorter than possible
|
||||
ErrTooShort = Error("message is too short")
|
||||
|
||||
// ErrEsVersion - cert contains unsupported es-version
|
||||
ErrEsVersion = errors.New("unsupported es-version")
|
||||
// ErrQueryTooLarge means that the DNS query is larger than max allowed size
|
||||
ErrQueryTooLarge = Error("DNSCrypt query is too large")
|
||||
|
||||
// ErrInvalidDate - cert is not valid for the current time
|
||||
ErrInvalidDate = errors.New("cert has invalid ts-start or ts-end")
|
||||
// ErrEsVersion means that the cert contains unsupported es-version
|
||||
ErrEsVersion = Error("unsupported es-version")
|
||||
|
||||
// ErrInvalidCertSignature - cert has invalid signature
|
||||
ErrInvalidCertSignature = errors.New("cert has invalid signature")
|
||||
// ErrInvalidDate means that the cert is not valid for the current time
|
||||
ErrInvalidDate = Error("cert has invalid ts-start or ts-end")
|
||||
|
||||
// ErrInvalidQuery - failed to decrypt a DNSCrypt query
|
||||
ErrInvalidQuery = errors.New("DNSCrypt query is invalid and cannot be decrypted")
|
||||
// ErrInvalidCertSignature means that the cert has invalid signature
|
||||
ErrInvalidCertSignature = Error("cert has invalid signature")
|
||||
|
||||
// ErrInvalidClientMagic - client-magic does not match
|
||||
ErrInvalidClientMagic = errors.New("DNSCrypt query contains invalid client magic")
|
||||
// ErrInvalidQuery means that it failed to decrypt a DNSCrypt query
|
||||
ErrInvalidQuery = Error("DNSCrypt query is invalid and cannot be decrypted")
|
||||
|
||||
// ErrInvalidResolverMagic - server-magic does not match
|
||||
ErrInvalidResolverMagic = errors.New("DNSCrypt response contains invalid resolver magic")
|
||||
// ErrInvalidClientMagic means that client-magic does not match
|
||||
ErrInvalidClientMagic = Error("DNSCrypt query contains invalid client magic")
|
||||
|
||||
// ErrInvalidResponse - failed to decrypt a DNSCrypt response
|
||||
ErrInvalidResponse = errors.New("DNSCrypt response is invalid and cannot be decrypted")
|
||||
// ErrInvalidResolverMagic means that server-magic does not match
|
||||
ErrInvalidResolverMagic = Error("DNSCrypt response contains invalid resolver magic")
|
||||
|
||||
// ErrInvalidPadding - failed to unpad a query
|
||||
ErrInvalidPadding = errors.New("invalid padding")
|
||||
// ErrInvalidResponse means that it failed to decrypt a DNSCrypt response
|
||||
ErrInvalidResponse = Error("DNSCrypt response is invalid and cannot be decrypted")
|
||||
|
||||
// ErrInvalidDNSStamp - invalid DNS stamp
|
||||
ErrInvalidDNSStamp = errors.New("invalid DNS stamp")
|
||||
// ErrInvalidPadding means that it failed to unpad a query
|
||||
ErrInvalidPadding = Error("invalid padding")
|
||||
|
||||
// ErrFailedToFetchCert - failed to fetch DNSCrypt certificate
|
||||
ErrFailedToFetchCert = errors.New("failed to fetch DNSCrypt certificate")
|
||||
// ErrInvalidDNSStamp means an invalid DNS stamp
|
||||
ErrInvalidDNSStamp = Error("invalid DNS stamp")
|
||||
|
||||
// ErrCertTooShort - failed to deserialize cert, too short
|
||||
ErrCertTooShort = errors.New("cert is too short")
|
||||
// ErrFailedToFetchCert means that it failed to fetch DNSCrypt certificate
|
||||
ErrFailedToFetchCert = Error("failed to fetch DNSCrypt certificate")
|
||||
|
||||
// ErrCertMagic - invalid cert magic
|
||||
ErrCertMagic = errors.New("invalid cert magic")
|
||||
// ErrCertTooShort means that it failed to deserialize cert, too short
|
||||
ErrCertTooShort = Error("cert is too short")
|
||||
|
||||
// ErrServerConfig - failed to start the DNSCrypt server - invalid configuration
|
||||
ErrServerConfig = errors.New("invalid server configuration")
|
||||
// ErrCertMagic means an invalid cert magic
|
||||
ErrCertMagic = Error("invalid cert magic")
|
||||
|
||||
// ErrServerConfig means that it failed to start the DNSCrypt server - invalid configuration
|
||||
ErrServerConfig = Error("invalid server configuration")
|
||||
|
||||
// ErrServerNotStarted is returned if there's nothing to shutdown
|
||||
ErrServerNotStarted = Error("server is not started")
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -55,7 +61,7 @@ const (
|
|||
// Some servers do not work if padded length is less than 256. Example: Quad9
|
||||
minUDPQuestionSize = 256
|
||||
|
||||
// <max-query-len> - maximum allowed query length
|
||||
// <max-query-len> is the maximum allowed query length
|
||||
maxQueryLen = 1252
|
||||
|
||||
// Minimum possible DNS packet size
|
||||
|
@ -68,7 +74,7 @@ const (
|
|||
// size of the shared key used to encrypt/decrypt messages
|
||||
sharedKeySize = 32
|
||||
|
||||
// ClientMagic - the first 8 bytes of a client query that is to be built
|
||||
// ClientMagic is the first 8 bytes of a client query that is to be built
|
||||
// using the information from this certificate. It may be a truncated
|
||||
// public key. Two valid certificates cannot share the same <client-magic>.
|
||||
clientMagicSize = 8
|
||||
|
@ -82,10 +88,10 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
// certMagic - bytes sequence that must be in the beginning of the serialized cert
|
||||
// certMagic is a bytes sequence that must be in the beginning of the serialized cert
|
||||
certMagic = [4]byte{0x44, 0x4e, 0x53, 0x43}
|
||||
|
||||
// resolverMagic - byte sequence that must be in the beginning of every response
|
||||
// resolverMagic is a byte sequence that must be in the beginning of every response
|
||||
resolverMagic = []byte{0x72, 0x36, 0x66, 0x6e, 0x76, 0x57, 0x6a, 0x38}
|
||||
)
|
||||
|
||||
|
|
|
@ -10,19 +10,19 @@ import (
|
|||
"golang.org/x/crypto/nacl/secretbox"
|
||||
)
|
||||
|
||||
// EncryptedQuery - a structure for encrypting and decrypting client queries
|
||||
// EncryptedQuery is a structure for encrypting and decrypting client queries
|
||||
//
|
||||
// <dnscrypt-query> ::= <client-magic> <client-pk> <client-nonce> <encrypted-query>
|
||||
// <encrypted-query> ::= AE(<shared-key> <client-nonce> <client-nonce-pad>, <client-query> <client-query-pad>)
|
||||
type EncryptedQuery struct {
|
||||
// EsVersion - encryption to use
|
||||
// EsVersion is the encryption to use
|
||||
EsVersion CryptoConstruction
|
||||
|
||||
// ClientMagic - a 8 byte identifier for the resolver certificate
|
||||
// ClientMagic is a 8 byte identifier for the resolver certificate
|
||||
// chosen by the client.
|
||||
ClientMagic [clientMagicSize]byte
|
||||
|
||||
// ClientPk - the client's public key
|
||||
// ClientPk is the client's public key
|
||||
ClientPk [keySize]byte
|
||||
|
||||
// With a 24 bytes nonce, a question sent by a DNSCrypt client must be
|
||||
|
@ -36,7 +36,7 @@ type EncryptedQuery struct {
|
|||
Nonce [nonceSize]byte
|
||||
}
|
||||
|
||||
// Encrypt - encrypts the specified DNS query, returns encrypted data ready to be sent.
|
||||
// Encrypt encrypts the specified DNS query, returns encrypted data ready to be sent.
|
||||
//
|
||||
// Note that this method will generate a random nonce automatically.
|
||||
//
|
||||
|
@ -79,7 +79,7 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) (
|
|||
return query, nil
|
||||
}
|
||||
|
||||
// Decrypt - decrypts the client query, returns decrypted DNS packet.
|
||||
// Decrypt decrypts the client query, returns decrypted DNS packet.
|
||||
//
|
||||
// Please note, that before calling this method the following fields must be set:
|
||||
// * ClientMagic -- to verify the query
|
||||
|
|
|
@ -23,7 +23,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction)
|
|||
|
||||
// Generate client shared key
|
||||
clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
clientMagic := [clientMagicSize]byte{}
|
||||
_, _ = rand.Read(clientMagic[:])
|
||||
|
@ -40,7 +40,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction)
|
|||
|
||||
// Encrypt it
|
||||
encrypted, err := q1.Encrypt(packet, clientSharedKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Now let's try decrypting it
|
||||
q2 := EncryptedQuery{
|
||||
|
@ -50,7 +50,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction)
|
|||
|
||||
// Decrypt it
|
||||
decrypted, err := q2.Decrypt(encrypted, serverSecretKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check that packet is the same
|
||||
assert.True(t, bytes.Equal(packet, decrypted))
|
||||
|
|
|
@ -10,12 +10,12 @@ import (
|
|||
"golang.org/x/crypto/nacl/secretbox"
|
||||
)
|
||||
|
||||
// EncryptedResponse - structure for encrypting/decrypting server responses
|
||||
// EncryptedResponse is a structure for encrypting/decrypting server responses
|
||||
//
|
||||
// <dnscrypt-response> ::= <resolver-magic> <nonce> <encrypted-response>
|
||||
// <encrypted-response> ::= AE(<shared-key>, <nonce>, <resolver-response> <resolver-response-pad>)
|
||||
type EncryptedResponse struct {
|
||||
// EsVersion - encryption to use
|
||||
// EsVersion is the encryption to use
|
||||
EsVersion CryptoConstruction
|
||||
|
||||
// Nonce - <nonce> ::= <client-nonce> <resolver-nonce>
|
||||
|
@ -23,7 +23,7 @@ type EncryptedResponse struct {
|
|||
Nonce [nonceSize]byte
|
||||
}
|
||||
|
||||
// Encrypt - encrypts the server response
|
||||
// Encrypt encrypts the server response
|
||||
//
|
||||
// EsVersion must be set.
|
||||
// Nonce needs to be set to "client-nonce".
|
||||
|
@ -57,7 +57,7 @@ func (r *EncryptedResponse) Encrypt(packet []byte, sharedKey [sharedKeySize]byte
|
|||
return response, nil
|
||||
}
|
||||
|
||||
// Decrypt - decrypts the server response
|
||||
// Decrypt decrypts the server response
|
||||
//
|
||||
// EsVersion must be set.
|
||||
func (r *EncryptedResponse) Decrypt(response []byte, sharedKey [sharedKeySize]byte) ([]byte, error) {
|
||||
|
|
|
@ -24,11 +24,11 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi
|
|||
|
||||
// Generate client shared key
|
||||
clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Generate server shared key
|
||||
serverSharedKey, err := computeSharedKey(esVersion, &serverSecretKey, &clientPublicKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
r1 := &EncryptedResponse{
|
||||
EsVersion: esVersion,
|
||||
|
@ -42,7 +42,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi
|
|||
|
||||
// Encrypt it
|
||||
encrypted, err := r1.Encrypt(packet, serverSharedKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Now let's try decrypting it
|
||||
r2 := &EncryptedResponse{
|
||||
|
@ -51,7 +51,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi
|
|||
|
||||
// Decrypt it
|
||||
decrypted, err := r2.Decrypt(encrypted, clientSharedKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check that packet is the same
|
||||
assert.True(t, bytes.Equal(packet, decrypted))
|
||||
|
|
27
generate.go
27
generate.go
|
@ -14,36 +14,37 @@ import (
|
|||
|
||||
const dnsCryptV2Prefix = "2.dnscrypt-cert."
|
||||
|
||||
// ResolverConfig - DNSCrypt resolver configuration
|
||||
// ResolverConfig is the DNSCrypt resolver configuration
|
||||
type ResolverConfig struct {
|
||||
// DNSCrypt provider name
|
||||
ProviderName string `yaml:"provider_name"`
|
||||
|
||||
// PublicKey - DNSCrypt resolver public key
|
||||
// PublicKey is the DNSCrypt resolver public key
|
||||
PublicKey string `yaml:"public_key"`
|
||||
|
||||
// PrivateKey - DNSCrypt resolver private key
|
||||
// PrivateKey is the DNSCrypt resolver private key
|
||||
// The main and only purpose of this key is to sign the certificate
|
||||
PrivateKey string `yaml:"private_key"`
|
||||
|
||||
// ResolverSk - hex-encoded short-term private key.
|
||||
// ResolverSk is a hex-encoded short-term private key.
|
||||
// This key is used to encrypt/decrypt DNS queries.
|
||||
// If not set, we'll generate a new random ResolverSk and ResolverPk.
|
||||
ResolverSk string `yaml:"resolver_secret"`
|
||||
|
||||
// ResolverPk - hex-encoded short-term public key corresponding to ResolverSk.
|
||||
// ResolverPk is a hex-encoded short-term public key corresponding to ResolverSk.
|
||||
// This key is used to encrypt/decrypt DNS queries.
|
||||
ResolverPk string `yaml:"resolver_public"`
|
||||
|
||||
// EsVersion - crypto to use in this resolver
|
||||
// EsVersion is the crypto to use in this resolver
|
||||
EsVersion CryptoConstruction `yaml:"es_version"`
|
||||
|
||||
// CertificateTTL - time-to-live for the certificate that is generated using this ResolverConfig.
|
||||
// CertificateTTL is the time-to-live value for the certificate that is
|
||||
// generated using this ResolverConfig.
|
||||
// If not set, we'll use 1 year by default.
|
||||
CertificateTTL time.Duration `yaml:"certificate_ttl"`
|
||||
}
|
||||
|
||||
// CreateCert - generates a signed Cert to be used by Server
|
||||
// CreateCert generates a signed Cert to be used by Server
|
||||
func (rc *ResolverConfig) CreateCert() (*Cert, error) {
|
||||
log.Printf("Creating signed DNSCrypt certificate")
|
||||
|
||||
|
@ -98,7 +99,7 @@ func (rc *ResolverConfig) CreateCert() (*Cert, error) {
|
|||
return cert, nil
|
||||
}
|
||||
|
||||
// CreateStamp - generates a DNS stamp for this resolver
|
||||
// CreateStamp generates a DNS stamp for this resolver
|
||||
func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error) {
|
||||
stamp := dnsstamps.ServerStamp{
|
||||
ProviderName: rc.ProviderName,
|
||||
|
@ -115,7 +116,7 @@ func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error
|
|||
return stamp, nil
|
||||
}
|
||||
|
||||
// GenerateResolverConfig - generates resolver configuration for a given provider name.
|
||||
// GenerateResolverConfig generates resolver configuration for a given provider name.
|
||||
// providerName is mandatory. If needed, "2.dnscrypt-cert." prefix is added to it.
|
||||
// privateKey is optional. If not set, it will be generated automatically.
|
||||
func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey) (ResolverConfig, error) {
|
||||
|
@ -145,18 +146,18 @@ func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey)
|
|||
return rc, nil
|
||||
}
|
||||
|
||||
// HexEncodeKey - encodes a byte slice to a hex-encoded string.
|
||||
// HexEncodeKey encodes a byte slice to a hex-encoded string.
|
||||
func HexEncodeKey(b []byte) string {
|
||||
return strings.ToUpper(hex.EncodeToString(b))
|
||||
}
|
||||
|
||||
// HexDecodeKey - decodes a hex-encoded string with (optional) colons
|
||||
// HexDecodeKey decodes a hex-encoded string with (optional) colons
|
||||
// to a byte array.
|
||||
func HexDecodeKey(str string) ([]byte, error) {
|
||||
return hex.DecodeString(strings.ReplaceAll(str, ":", ""))
|
||||
}
|
||||
|
||||
// generateRandomKeyPair - generates a random key-pair
|
||||
// generateRandomKeyPair generates a random key-pair
|
||||
func generateRandomKeyPair() (privateKey [keySize]byte, publicKey [keySize]byte) {
|
||||
privateKey = [keySize]byte{}
|
||||
publicKey = [keySize]byte{}
|
||||
|
|
|
@ -15,24 +15,24 @@ func TestHexEncodeKey(t *testing.T) {
|
|||
|
||||
func TestHexDecodeKey(t *testing.T) {
|
||||
b, err := HexDecodeKey("01:02:03:04")
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, bytes.Equal(b, []byte{1, 2, 3, 4}))
|
||||
}
|
||||
|
||||
func TestGenerateResolverConfig(t *testing.T) {
|
||||
rc, err := GenerateResolverConfig("example.org", nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2.dnscrypt-cert.example.org", rc.ProviderName)
|
||||
assert.Equal(t, ed25519.PrivateKeySize*2, len(rc.PrivateKey))
|
||||
assert.Equal(t, keySize*2, len(rc.ResolverSk))
|
||||
assert.Equal(t, keySize*2, len(rc.ResolverPk))
|
||||
|
||||
cert, err := rc.CreateCert()
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, cert.VerifyDate())
|
||||
|
||||
publicKey, err := HexDecodeKey(rc.PublicKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, cert.VerifySignature(publicKey))
|
||||
}
|
||||
|
|
2
go.mod
2
go.mod
|
@ -6,7 +6,7 @@ require (
|
|||
github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635
|
||||
github.com/ameshkov/dnsstamps v1.0.1
|
||||
github.com/jessevdk/go-flags v1.4.0
|
||||
github.com/miekg/dns v1.1.29
|
||||
github.com/miekg/dns v1.1.40
|
||||
github.com/stretchr/testify v1.6.1
|
||||
golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e
|
||||
|
|
2
go.sum
2
go.sum
|
@ -14,6 +14,8 @@ github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGAR
|
|||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/miekg/dns v1.1.29 h1:xHBEhR+t5RzcFJjBLJlax2daXOrTYtr9z4WdKEfWFzg=
|
||||
github.com/miekg/dns v1.1.29/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||
github.com/miekg/dns v1.1.40 h1:pyyPFfGMnciYUk/mXpKkVmeMQjfXqt3FAJ2hy7tPiLA=
|
||||
github.com/miekg/dns v1.1.40/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
|
|
@ -14,14 +14,14 @@ type Handler interface {
|
|||
ServeDNS(rw ResponseWriter, r *dns.Msg) error
|
||||
}
|
||||
|
||||
// ResponseWriter - interface that needs to be implemented for different protocols
|
||||
// ResponseWriter is the interface that needs to be implemented for different protocols
|
||||
type ResponseWriter interface {
|
||||
LocalAddr() net.Addr // LocalAddr - local socket address
|
||||
RemoteAddr() net.Addr // RemoteAddr - remote client socket address
|
||||
WriteMsg(m *dns.Msg) error // WriteMsg - writes response message to the client
|
||||
}
|
||||
|
||||
// DefaultHandler - default Handler implementation
|
||||
// DefaultHandler is the default Handler implementation
|
||||
// that is used by Server if custom handler is not configured
|
||||
var DefaultHandler Handler = &defaultHandler{
|
||||
udpClient: &dns.Client{
|
||||
|
@ -41,7 +41,7 @@ type defaultHandler struct {
|
|||
addr string
|
||||
}
|
||||
|
||||
// ServeDNS - implements Handler interface
|
||||
// ServeDNS implements Handler interface
|
||||
func (h *defaultHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error {
|
||||
// Google DNS
|
||||
res, _, err := h.udpClient.Exchange(r, h.addr)
|
||||
|
|
174
server.go
174
server.go
|
@ -1,27 +1,167 @@
|
|||
package dnscrypt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Server - a simple DNSCrypt server implementation
|
||||
// default read timeout for all reads
|
||||
const defaultReadTimeout = 2 * time.Second
|
||||
|
||||
// in case of TCP we only use defaultReadTimeout for the first read
|
||||
// then we start using defaultTCPIdleTimeout
|
||||
const defaultTCPIdleTimeout = 8 * time.Second
|
||||
|
||||
// helper struct that is used in several SetReadDeadline calls
|
||||
var longTimeAgo = time.Unix(1, 0)
|
||||
|
||||
// ServerDNSCrypt is an interface for a DNSCrypt server
|
||||
type ServerDNSCrypt interface {
|
||||
// ServeTCP listens to TCP connections, queries are then processed by Server.Handler.
|
||||
// It blocks the calling goroutine and to stop it you need to close the listener
|
||||
// or call ServerDNSCrypt.Shutdown.
|
||||
ServeTCP(l net.Listener) error
|
||||
|
||||
// ServeUDP listens to UDP connections, queries are then processed by Server.Handler.
|
||||
// It blocks the calling goroutine and to stop it you need to close the listener
|
||||
// or call ServerDNSCrypt.Shutdown.
|
||||
ServeUDP(l *net.UDPConn) error
|
||||
|
||||
// Shutdown tries to gracefully shutdown the server. It waits until all
|
||||
// connections are processed and only after that it leaves the method.
|
||||
// If context deadline is specified, it will exit earlier
|
||||
// or call ServerDNSCrypt.Shutdown.
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Server is a simple DNSCrypt server implementation
|
||||
type Server struct {
|
||||
// ProviderName - DNSCrypt provider name
|
||||
// ProviderName is a DNSCrypt provider name
|
||||
ProviderName string
|
||||
|
||||
// ResolverCert - contains resolver certificate.
|
||||
// ResolverCert contains resolver certificate.
|
||||
ResolverCert *Cert
|
||||
|
||||
// Handler to invoke. If nil, uses DefaultHandler.
|
||||
Handler Handler
|
||||
|
||||
// make sure init is called only once
|
||||
initOnce sync.Once
|
||||
|
||||
// Shutdown handling
|
||||
// --
|
||||
lock sync.RWMutex // protects access to all the fields below
|
||||
started bool
|
||||
wg sync.WaitGroup // active workers (servers)
|
||||
tcpListeners map[net.Listener]struct{} // track active TCP listeners
|
||||
udpListeners map[*net.UDPConn]struct{} // track active UDP listeners
|
||||
tcpConns map[net.Conn]struct{} // track active connections
|
||||
}
|
||||
|
||||
// serveDNS - serves DNS response
|
||||
func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) {
|
||||
// type check
|
||||
var _ ServerDNSCrypt = &Server{}
|
||||
|
||||
// prepareShutdown - prepares the server to shutdown:
|
||||
// unblocks reads from all connections related to this server
|
||||
// marks the server as stopped
|
||||
// if the server is not started, returns ErrServerNotStarted
|
||||
func (s *Server) prepareShutdown() error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if !s.started {
|
||||
log.Info("Server is not started")
|
||||
return ErrServerNotStarted
|
||||
}
|
||||
|
||||
s.started = false
|
||||
|
||||
// These listeners were passed to us from the outside so we cannot close
|
||||
// them here - this is up to the calling code to do that. Instead of that,
|
||||
// we call Set(Read)Deadline to unblock goroutines that are currently
|
||||
// blocked on reading from those listeners.
|
||||
// For tcpConns we would like to avoid closing them to be able to process
|
||||
// queries before shutting everything down.
|
||||
|
||||
// Unblock reads for all active tcpConns
|
||||
for conn := range s.tcpConns {
|
||||
_ = conn.SetReadDeadline(longTimeAgo)
|
||||
}
|
||||
|
||||
// Unblock reads for all active TCP listeners
|
||||
for l := range s.tcpListeners {
|
||||
switch v := l.(type) {
|
||||
case *net.TCPListener:
|
||||
_ = v.SetDeadline(longTimeAgo)
|
||||
}
|
||||
}
|
||||
|
||||
// Unblock reads for all active UDP listeners
|
||||
for l := range s.udpListeners {
|
||||
_ = l.SetReadDeadline(longTimeAgo)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown tries to gracefully shutdown the server. It waits until all
|
||||
// connections are processed and only after that it leaves the method.
|
||||
// If context deadline is specified, it will exit earlier.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
log.Info("Shutting down the DNSCrypt server")
|
||||
|
||||
err := s.prepareShutdown()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Using this channel to wait until all goroutines finish their work
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
log.Info("Serve goroutines finished their work")
|
||||
close(closed)
|
||||
}()
|
||||
|
||||
// Wait for either all goroutines finish their work
|
||||
// Or for the context deadline
|
||||
select {
|
||||
case <-closed:
|
||||
log.Info("DNSCrypt server has been stopped")
|
||||
case <-ctx.Done():
|
||||
log.Info("DNSCrypt server shutdown has timed out")
|
||||
err = ctx.Err()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// init initializes (lazily) Server properties on startup
|
||||
// this method is called from Server.ServeTCP and Server.ServeUDP
|
||||
func (s *Server) init() {
|
||||
s.tcpConns = map[net.Conn]struct{}{}
|
||||
s.udpListeners = map[*net.UDPConn]struct{}{}
|
||||
s.tcpListeners = map[net.Listener]struct{}{}
|
||||
}
|
||||
|
||||
// isStarted returns true if the server is processing queries right now
|
||||
// it means that Server.ServeTCP and/or Server.ServeUDP have been called
|
||||
func (s *Server) isStarted() bool {
|
||||
s.lock.RLock()
|
||||
started := s.started
|
||||
s.lock.RUnlock()
|
||||
return started
|
||||
}
|
||||
|
||||
// serveDNS serves DNS response
|
||||
func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) error {
|
||||
if r == nil || len(r.Question) != 1 || r.Response {
|
||||
log.Tracef("Invalid query: %v", r)
|
||||
return
|
||||
return ErrInvalidQuery
|
||||
}
|
||||
|
||||
log.Tracef("Handling a DNS query: %s", r.Question[0].Name)
|
||||
|
@ -39,9 +179,11 @@ func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) {
|
|||
reply.SetRcode(r, dns.RcodeServerFailure)
|
||||
_ = rw.WriteMsg(reply)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encrypt - encrypts DNSCrypt response
|
||||
// encrypt encrypts DNSCrypt response
|
||||
func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) {
|
||||
r := EncryptedResponse{
|
||||
EsVersion: q.EsVersion,
|
||||
|
@ -60,7 +202,7 @@ func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) {
|
|||
return r.Encrypt(packet, sharedKey)
|
||||
}
|
||||
|
||||
// decrypt - decrypts the incoming message and returns a DNS message to process
|
||||
// decrypt decrypts the incoming message and returns a DNS message to process
|
||||
func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) {
|
||||
q := EncryptedQuery{
|
||||
EsVersion: s.ResolverCert.EsVersion,
|
||||
|
@ -82,7 +224,7 @@ func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) {
|
|||
return r, q, nil
|
||||
}
|
||||
|
||||
// handleHandshake - handles a TXT request that requests certificate data
|
||||
// handleHandshake handles a TXT request that requests certificate data
|
||||
func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) {
|
||||
m := new(dns.Msg)
|
||||
err := m.Unpack(b)
|
||||
|
@ -120,7 +262,7 @@ func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) {
|
|||
return reply.Pack()
|
||||
}
|
||||
|
||||
// validate - checks if the Server config is properly set
|
||||
// validate checks if the Server config is properly set
|
||||
func (s *Server) validate() bool {
|
||||
if s.ResolverCert == nil {
|
||||
log.Error("ResolverCert must be set")
|
||||
|
@ -139,3 +281,13 @@ func (s *Server) validate() bool {
|
|||
|
||||
return true
|
||||
}
|
||||
|
||||
// getCertTXT serializes the cert TXT record that are to be sent to the client
|
||||
func (s *Server) getCertTXT() (string, error) {
|
||||
certBuf, err := s.ResolverCert.Serialize()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
certTxt := packTxtString(certBuf)
|
||||
return certTxt, nil
|
||||
}
|
||||
|
|
60
server_bench_test.go
Normal file
60
server_bench_test.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package dnscrypt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ameshkov/dnsstamps"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkServeUDP(b *testing.B) {
|
||||
benchmarkServe(b, "udp")
|
||||
}
|
||||
|
||||
func BenchmarkServeTCP(b *testing.B) {
|
||||
benchmarkServe(b, "tcp")
|
||||
}
|
||||
|
||||
func benchmarkServe(b *testing.B, network string) {
|
||||
srv := newTestServer(b, &testHandler{})
|
||||
b.Cleanup(func() {
|
||||
err := srv.Close()
|
||||
require.NoError(b, err)
|
||||
})
|
||||
|
||||
client := &Client{
|
||||
Timeout: 1 * time.Second,
|
||||
Net: network,
|
||||
}
|
||||
|
||||
serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port)
|
||||
if network == "tcp" {
|
||||
serverAddr = fmt.Sprintf("127.0.0.1:%d", srv.TCPAddr().Port)
|
||||
}
|
||||
|
||||
stamp := dnsstamps.ServerStamp{
|
||||
ServerAddrStr: serverAddr,
|
||||
ServerPk: srv.resolverPk,
|
||||
ProviderName: srv.server.ProviderName,
|
||||
Proto: dnsstamps.StampProtoTypeDNSCrypt,
|
||||
}
|
||||
ri, err := client.DialStamp(stamp)
|
||||
require.NoError(b, err)
|
||||
require.NotNil(b, ri)
|
||||
|
||||
conn, err := net.Dial(network, stamp.ServerAddrStr)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
m := createTestMessage()
|
||||
res, err := client.ExchangeConn(conn, m, ri)
|
||||
require.NoError(b, err)
|
||||
assertTestMessageResponse(b, res)
|
||||
}
|
||||
b.StopTimer()
|
||||
}
|
201
server_tcp.go
201
server_tcp.go
|
@ -2,13 +2,17 @@ package dnscrypt
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// TCPResponseWriter - ResponseWriter implementation for TCP
|
||||
// TCPResponseWriter is the ResponseWriter implementation for TCP
|
||||
type TCPResponseWriter struct {
|
||||
tcpConn net.Conn
|
||||
encrypt encryptionFunc
|
||||
|
@ -19,17 +23,17 @@ type TCPResponseWriter struct {
|
|||
// type check
|
||||
var _ ResponseWriter = &TCPResponseWriter{}
|
||||
|
||||
// LocalAddr - server socket local address
|
||||
// LocalAddr is the server socket local address
|
||||
func (w *TCPResponseWriter) LocalAddr() net.Addr {
|
||||
return w.tcpConn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr - client's address
|
||||
// RemoteAddr is the client's address
|
||||
func (w *TCPResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.tcpConn.RemoteAddr()
|
||||
}
|
||||
|
||||
// WriteMsg - writes DNS message to the client
|
||||
// WriteMsg writes DNS message to the client
|
||||
func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
m.Truncate(dnsSize("tcp", w.req))
|
||||
|
||||
|
@ -42,33 +46,44 @@ func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error {
|
|||
return writePrefixed(res, w.tcpConn)
|
||||
}
|
||||
|
||||
// ServeTCP - listens to TCP connections, queries are then processed by Server.Handler.
|
||||
// It blocks the calling goroutine and to stop it you need to close the listener.
|
||||
// ServeTCP listens to TCP connections, queries are then processed by Server.Handler.
|
||||
// It blocks the calling goroutine and to stop it you need to close the listener
|
||||
// or call Server.Shutdown.
|
||||
func (s *Server) ServeTCP(l net.Listener) error {
|
||||
// Check that server is properly configured
|
||||
if !s.validate() {
|
||||
return ErrServerConfig
|
||||
}
|
||||
|
||||
// Serialize the cert right away and prepare it to be sent to the client
|
||||
certBuf, err := s.ResolverCert.Serialize()
|
||||
err := s.prepareServeTCP(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certTxt := packTxtString(certBuf)
|
||||
|
||||
log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr().String())
|
||||
log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr())
|
||||
|
||||
for {
|
||||
// Tracks TCP connection handling goroutines
|
||||
tcpWg := &sync.WaitGroup{}
|
||||
defer s.cleanUpTCP(tcpWg, l)
|
||||
|
||||
// Track active goroutine
|
||||
s.wg.Add(1)
|
||||
|
||||
// Serialize the cert right away and prepare it to be sent to the client
|
||||
certTxt, err := s.getCertTXT()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for s.isStarted() {
|
||||
conn, err := l.Accept()
|
||||
if err == nil {
|
||||
go func() {
|
||||
_ = s.handleTCPConnection(conn, certTxt)
|
||||
_ = conn.Close()
|
||||
}()
|
||||
}
|
||||
|
||||
// Check the error code and exit loop if necessary
|
||||
if err != nil {
|
||||
if !s.isStarted() {
|
||||
// Stopped gracefully
|
||||
break
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Temporary() {
|
||||
// Note that timeout errors will be here (i.e. hitting ReadDeadline)
|
||||
continue
|
||||
}
|
||||
if isConnClosed(err) {
|
||||
log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop")
|
||||
} else {
|
||||
|
@ -76,47 +91,129 @@ func (s *Server) ServeTCP(l net.Listener) error {
|
|||
}
|
||||
break
|
||||
}
|
||||
|
||||
// If we got here, the connection is alive
|
||||
s.lock.Lock()
|
||||
// Track the connection to allow unblocking reads on shutdown.
|
||||
s.tcpConns[conn] = struct{}{}
|
||||
s.lock.Unlock()
|
||||
|
||||
tcpWg.Add(1)
|
||||
go func() {
|
||||
// Ignore error here, it is most probably a legit one
|
||||
// if not, it's written to the debug log
|
||||
_ = s.handleTCPConnection(conn, certTxt)
|
||||
|
||||
// Clean up
|
||||
_ = conn.Close()
|
||||
s.lock.Lock()
|
||||
delete(s.tcpConns, conn)
|
||||
s.lock.Unlock()
|
||||
tcpWg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareServeTCP prepares the server and listener to serving DNSCrypt
|
||||
func (s *Server) prepareServeTCP(l net.Listener) error {
|
||||
// Check that server is properly configured
|
||||
if !s.validate() {
|
||||
return ErrServerConfig
|
||||
}
|
||||
|
||||
// Protect shutdown-related fields
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.initOnce.Do(s.init)
|
||||
|
||||
// Mark the server as started if needed
|
||||
s.started = true
|
||||
|
||||
// Track an active TCP listener
|
||||
s.tcpListeners[l] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanUpTCP waits until all TCP messages before cleaning up
|
||||
func (s *Server) cleanUpTCP(tcpWg *sync.WaitGroup, l net.Listener) {
|
||||
// Wait until all TCP connections are processed
|
||||
tcpWg.Wait()
|
||||
|
||||
// Not using it anymore so can be removed from the active listeners
|
||||
s.lock.Lock()
|
||||
delete(s.tcpListeners, l)
|
||||
s.lock.Unlock()
|
||||
|
||||
// The work is finished
|
||||
s.wg.Done()
|
||||
}
|
||||
|
||||
// handleTCPMsg handles a single TCP message. If this method returns error
|
||||
// the connection will be closed
|
||||
func (s *Server) handleTCPMsg(b []byte, conn net.Conn, certTxt string) error {
|
||||
if len(b) < minDNSPacketSize {
|
||||
// Ignore the packets that are too short
|
||||
return ErrTooShort
|
||||
}
|
||||
|
||||
// First of all, check for "ClientMagic" in the incoming query
|
||||
if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
|
||||
// If there's no ClientMagic in the packet, we assume this
|
||||
// is a plain DNS query requesting the certificate data
|
||||
reply, err := s.handleHandshake(b, certTxt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process a plain DNS query: %w", err)
|
||||
}
|
||||
err = writePrefixed(reply, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write a response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we got here, this is an encrypted DNSCrypt message
|
||||
// We should decrypt it first to get the plain DNS query
|
||||
m, q, err := s.decrypt(b)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt incoming message: %w", err)
|
||||
}
|
||||
rw := &TCPResponseWriter{
|
||||
tcpConn: conn,
|
||||
encrypt: s.encrypt,
|
||||
req: m,
|
||||
query: q,
|
||||
}
|
||||
err = s.serveDNS(rw, m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process a DNS query: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleTCPConnection handles all queries that are coming to the
|
||||
// specified TCP connection.
|
||||
func (s *Server) handleTCPConnection(conn net.Conn, certTxt string) error {
|
||||
for {
|
||||
timeout := defaultReadTimeout
|
||||
|
||||
for s.isStarted() {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
|
||||
b, err := readPrefixed(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(b) < minDNSPacketSize {
|
||||
// Ignore the packets that are too short
|
||||
return ErrTooShort
|
||||
|
||||
err = s.handleTCPMsg(b, conn, certTxt)
|
||||
if err != nil {
|
||||
log.Debug("failed to process DNS query: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
|
||||
// This is an encrypted message, we should decrypt it
|
||||
m, q, err := s.decrypt(b)
|
||||
if err != nil {
|
||||
log.Tracef("failed to decrypt incoming message: %v", err)
|
||||
return err
|
||||
}
|
||||
rw := &TCPResponseWriter{
|
||||
tcpConn: conn,
|
||||
encrypt: s.encrypt,
|
||||
req: m,
|
||||
query: q,
|
||||
}
|
||||
s.serveDNS(rw, m)
|
||||
} else {
|
||||
// Most likely this a DNS message requesting the certificate
|
||||
reply, err := s.handleHandshake(b, certTxt)
|
||||
if err != nil {
|
||||
log.Tracef("Failed to process a plain DNS query: %v", err)
|
||||
return err
|
||||
}
|
||||
err = writePrefixed(reply, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
timeout = defaultTCPIdleTimeout
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,9 +2,11 @@ package dnscrypt
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -13,25 +15,52 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestServerUDPServeCert(t *testing.T) {
|
||||
func TestServer_Shutdown(t *testing.T) {
|
||||
n := runtime.GOMAXPROCS(1)
|
||||
t.Cleanup(func() {
|
||||
runtime.GOMAXPROCS(n)
|
||||
})
|
||||
srv := newTestServer(t, &testHandler{})
|
||||
// Serve* methods are called in different goroutines
|
||||
// give them at least a moment to actually start the server
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
assert.NoError(t, srv.Close())
|
||||
}
|
||||
|
||||
func TestServer_UDPServeCert(t *testing.T) {
|
||||
testServerServeCert(t, "udp")
|
||||
}
|
||||
|
||||
func TestServerTCPServeCert(t *testing.T) {
|
||||
func TestServer_TCPServeCert(t *testing.T) {
|
||||
testServerServeCert(t, "tcp")
|
||||
}
|
||||
|
||||
func TestServerUDPRespondMessages(t *testing.T) {
|
||||
func TestServer_UDPRespondMessages(t *testing.T) {
|
||||
testServerRespondMessages(t, "udp")
|
||||
}
|
||||
|
||||
func TestServerTCPRespondMessages(t *testing.T) {
|
||||
func TestServer_TCPRespondMessages(t *testing.T) {
|
||||
testServerRespondMessages(t, "tcp")
|
||||
}
|
||||
|
||||
func TestServer_ReadTimeout(t *testing.T) {
|
||||
srv := newTestServer(t, &testHandler{})
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, srv.Close())
|
||||
})
|
||||
// Sleep for "defaultReadTimeout" before trying to shutdown the server
|
||||
// The point is to make sure readTimeout is properly handled by
|
||||
// the "Serve*" goroutines and they don't finish their work unexpectedly
|
||||
time.Sleep(defaultReadTimeout)
|
||||
testThisServerRespondMessages(t, "udp", srv)
|
||||
testThisServerRespondMessages(t, "tcp", srv)
|
||||
}
|
||||
|
||||
func testServerServeCert(t *testing.T, network string) {
|
||||
srv := newTestServer(t, &testHandler{})
|
||||
defer srv.Close()
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, srv.Close())
|
||||
})
|
||||
|
||||
client := &Client{
|
||||
Net: network,
|
||||
|
@ -50,7 +79,7 @@ func testServerServeCert(t *testing.T, network string) {
|
|||
Proto: dnsstamps.StampProtoTypeDNSCrypt,
|
||||
}
|
||||
ri, err := client.DialStamp(stamp)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ri)
|
||||
|
||||
assert.Equal(t, ri.ProviderName, srv.server.ProviderName)
|
||||
|
@ -65,8 +94,13 @@ func testServerServeCert(t *testing.T, network string) {
|
|||
|
||||
func testServerRespondMessages(t *testing.T, network string) {
|
||||
srv := newTestServer(t, &testHandler{})
|
||||
defer srv.Close()
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, srv.Close())
|
||||
})
|
||||
testThisServerRespondMessages(t, network, srv)
|
||||
}
|
||||
|
||||
func testThisServerRespondMessages(t *testing.T, network string, srv *testServer) {
|
||||
client := &Client{
|
||||
Timeout: 1 * time.Second,
|
||||
Net: network,
|
||||
|
@ -84,16 +118,16 @@ func testServerRespondMessages(t *testing.T, network string) {
|
|||
Proto: dnsstamps.StampProtoTypeDNSCrypt,
|
||||
}
|
||||
ri, err := client.DialStamp(stamp)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, ri)
|
||||
|
||||
conn, err := net.Dial(network, stamp.ServerAddrStr)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
m := createTestMessage()
|
||||
res, err := client.ExchangeConn(conn, m, ri)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
assertTestMessageResponse(t, res)
|
||||
}
|
||||
}
|
||||
|
@ -114,16 +148,22 @@ func (s *testServer) UDPAddr() *net.UDPAddr {
|
|||
return s.udpConn.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
func (s *testServer) Close() {
|
||||
func (s *testServer) Close() error {
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second))
|
||||
defer cancel()
|
||||
|
||||
err := s.server.Shutdown(ctx)
|
||||
_ = s.udpConn.Close()
|
||||
_ = s.tcpListen.Close()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T, handler Handler) *testServer {
|
||||
func newTestServer(t assert.TestingT, handler Handler) *testServer {
|
||||
rc, err := GenerateResolverConfig("example.org", nil)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
cert, err := rc.CreateCert()
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := &Server{
|
||||
ProviderName: rc.ProviderName,
|
||||
|
@ -132,7 +172,7 @@ func newTestServer(t *testing.T, handler Handler) *testServer {
|
|||
}
|
||||
|
||||
privateKey, err := HexDecodeKey(rc.PrivateKey)
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
publicKey := ed25519.PrivateKey(privateKey).Public().(ed25519.PublicKey)
|
||||
srv := &testServer{
|
||||
server: s,
|
||||
|
@ -140,9 +180,9 @@ func newTestServer(t *testing.T, handler Handler) *testServer {
|
|||
}
|
||||
|
||||
srv.tcpListen, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 0})
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
srv.udpConn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
assert.Nil(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
go s.ServeUDP(srv.udpConn)
|
||||
go s.ServeTCP(srv.tcpListen)
|
||||
|
|
230
server_udp.go
230
server_udp.go
|
@ -2,38 +2,43 @@ package dnscrypt
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error)
|
||||
|
||||
// UDPResponseWriter - ResponseWriter implementation for UDP
|
||||
// UDPResponseWriter is the ResponseWriter implementation for UDP
|
||||
type UDPResponseWriter struct {
|
||||
udpConn *net.UDPConn // UDP connection
|
||||
remoteAddr *net.UDPAddr // Remote peer address
|
||||
localIP net.IP // Local IP (that was used to accept the remote connection)
|
||||
encrypt encryptionFunc // DNSCRypt encryption function
|
||||
req *dns.Msg // DNS query that was processed
|
||||
query EncryptedQuery // DNSCrypt query properties
|
||||
udpConn *net.UDPConn // UDP connection
|
||||
sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP)
|
||||
encrypt encryptionFunc // DNSCRypt encryption function
|
||||
req *dns.Msg // DNS query that was processed
|
||||
query EncryptedQuery // DNSCrypt query properties
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ ResponseWriter = &UDPResponseWriter{}
|
||||
|
||||
// LocalAddr - server socket local address
|
||||
// LocalAddr is the server socket local address
|
||||
func (w *UDPResponseWriter) LocalAddr() net.Addr {
|
||||
return w.udpConn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr - client's address
|
||||
// RemoteAddr is the client's address
|
||||
func (w *UDPResponseWriter) RemoteAddr() net.Addr {
|
||||
return w.remoteAddr
|
||||
return w.udpConn.RemoteAddr()
|
||||
}
|
||||
|
||||
// WriteMsg - writes DNS message to the client
|
||||
// WriteMsg writes DNS message to the client
|
||||
func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||
m.Truncate(dnsSize("udp", w.req))
|
||||
|
||||
|
@ -42,83 +47,176 @@ func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error {
|
|||
log.Tracef("Failed to encrypt the DNS query: %v", err)
|
||||
return err
|
||||
}
|
||||
_, err = dns.WriteToSessionUDP(w.udpConn, res, w.sess)
|
||||
return err
|
||||
}
|
||||
|
||||
// ServeUDP listens to UDP connections, queries are then processed by Server.Handler.
|
||||
// It blocks the calling goroutine and to stop it you need to close the listener
|
||||
// or call Server.Shutdown.
|
||||
func (s *Server) ServeUDP(l *net.UDPConn) error {
|
||||
err := s.prepareServeUDP(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Tracks UDP handling goroutines
|
||||
udpWg := &sync.WaitGroup{}
|
||||
defer s.cleanUpUDP(udpWg, l)
|
||||
|
||||
// Track active goroutine
|
||||
s.wg.Add(1)
|
||||
|
||||
log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr())
|
||||
|
||||
// Serialize the cert right away and prepare it to be sent to the client
|
||||
certTxt, err := s.getCertTXT()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for s.isStarted() {
|
||||
b, sess, err := s.readUDPMsg(l)
|
||||
|
||||
// Check the error code and exit loop if necessary
|
||||
if err != nil {
|
||||
if !s.isStarted() {
|
||||
// Stopped gracefully
|
||||
return nil
|
||||
}
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Temporary() {
|
||||
// Note that timeout errors will be here (i.e. hitting ReadDeadline)
|
||||
continue
|
||||
}
|
||||
if isConnClosed(err) {
|
||||
log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop")
|
||||
} else {
|
||||
log.Info("got error when reading from UDP listen: %s", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if len(b) < minDNSPacketSize {
|
||||
// Ignore the packets that are too short
|
||||
continue
|
||||
}
|
||||
|
||||
udpWg.Add(1)
|
||||
go func() {
|
||||
s.serveUDPMsg(b, certTxt, sess, l)
|
||||
udpWg.Done()
|
||||
}()
|
||||
}
|
||||
|
||||
_, _ = udpWrite(res, w.udpConn, w.remoteAddr, w.localIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeUDP - listens to UDP connections, queries are then processed by Server.Handler.
|
||||
// It blocks the calling goroutine and to stop it you need to close the listener.
|
||||
func (s *Server) ServeUDP(l *net.UDPConn) error {
|
||||
// prepareServeUDP prepares the server and listener to serving DNSCrypt
|
||||
func (s *Server) prepareServeUDP(l *net.UDPConn) error {
|
||||
// Check that server is properly configured
|
||||
if !s.validate() {
|
||||
return ErrServerConfig
|
||||
}
|
||||
|
||||
// set UDP options to allow receiving OOB data
|
||||
err := udpSetOptions(l)
|
||||
err := setUDPSocketOptions(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffer to read incoming messages
|
||||
b := make([]byte, dns.MaxMsgSize)
|
||||
// Protect shutdown-related fields
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.initOnce.Do(s.init)
|
||||
|
||||
// Serialize the cert right away and prepare it to be sent to the client
|
||||
certBuf, err := s.ResolverCert.Serialize()
|
||||
// Mark the server as started.
|
||||
// Note that we don't check if it was started before as
|
||||
// Serve* methods can be called multiple times.
|
||||
s.started = true
|
||||
|
||||
// Track an active UDP listener
|
||||
s.udpListeners[l] = struct{}{}
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanUpUDP waits until all UDP messages before cleaning up
|
||||
func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) {
|
||||
// Wait until UDP messages are processed
|
||||
udpWg.Wait()
|
||||
|
||||
// Not using it anymore so can be removed from the active listeners
|
||||
s.lock.Lock()
|
||||
delete(s.udpListeners, l)
|
||||
s.lock.Unlock()
|
||||
|
||||
// The work is finished
|
||||
s.wg.Done()
|
||||
}
|
||||
|
||||
// readUDPMsg reads incoming UDP message
|
||||
func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) {
|
||||
_ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout))
|
||||
b := make([]byte, dns.MinMsgSize)
|
||||
n, sess, err := dns.ReadFromSessionUDP(l, b)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
certTxt := packTxtString(certBuf)
|
||||
|
||||
// Init oobSize - it will be used later when reading and writing UDP messages
|
||||
oobSize := udpGetOOBSize()
|
||||
|
||||
log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr().String())
|
||||
|
||||
for {
|
||||
n, localIP, addr, err := udpRead(l, b, oobSize)
|
||||
if n < minDNSPacketSize {
|
||||
// Ignore the packets that are too short
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
|
||||
// This is an encrypted message, we should decrypt it
|
||||
m, q, err := s.decrypt(b[:n])
|
||||
if err == nil {
|
||||
rw := &UDPResponseWriter{
|
||||
udpConn: l,
|
||||
remoteAddr: addr,
|
||||
localIP: localIP,
|
||||
encrypt: s.encrypt,
|
||||
req: m,
|
||||
query: q,
|
||||
}
|
||||
go s.serveDNS(rw, m)
|
||||
} else {
|
||||
log.Tracef("Failed to decrypt incoming message len=%d: %v", n, err)
|
||||
}
|
||||
} else {
|
||||
// Most likely this a DNS message requesting the certificate
|
||||
reply, err := s.handleHandshake(b, certTxt)
|
||||
if err != nil {
|
||||
log.Tracef("Failed to process a plain DNS query: %v", err)
|
||||
}
|
||||
if err == nil {
|
||||
_, _ = l.WriteTo(reply, addr)
|
||||
}
|
||||
}
|
||||
return b[:n], sess, err
|
||||
}
|
||||
|
||||
// serveUDPMsg handles incoming DNS message
|
||||
func (s *Server) serveUDPMsg(b []byte, certTxt string, sess *dns.SessionUDP, l *net.UDPConn) {
|
||||
// First of all, check for "ClientMagic" in the incoming query
|
||||
if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) {
|
||||
// If there's no ClientMagic in the packet, we assume this
|
||||
// is a plain DNS query requesting the certificate data
|
||||
reply, err := s.handleHandshake(b, certTxt)
|
||||
if err != nil {
|
||||
if isConnClosed(err) {
|
||||
log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop")
|
||||
} else {
|
||||
log.Info("got error when reading from UDP listen: %s", err)
|
||||
}
|
||||
break
|
||||
log.Tracef("failed to process a plain DNS query: %v", err)
|
||||
}
|
||||
if err == nil {
|
||||
// Ignore errors, we don't care and can't handle them anyway
|
||||
_, _ = dns.WriteToSessionUDP(l, reply, sess)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// If we got here, this is an encrypted DNSCrypt message
|
||||
// We should decrypt it first to get the plain DNS query
|
||||
m, q, err := s.decrypt(b)
|
||||
if err == nil {
|
||||
rw := &UDPResponseWriter{
|
||||
udpConn: l,
|
||||
sess: sess,
|
||||
encrypt: s.encrypt,
|
||||
req: m,
|
||||
query: q,
|
||||
}
|
||||
err = s.serveDNS(rw, m)
|
||||
if err != nil {
|
||||
log.Tracef("failed to process a DNS query: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Tracef("failed to decrypt incoming message len=%d: %v", len(b), err)
|
||||
}
|
||||
}
|
||||
|
||||
// setUDPSocketOptions method is necessary to be able to use dns.ReadFromSessionUDP / dns.WriteToSessionUDP
|
||||
func setUDPSocketOptions(conn *net.UDPConn) error {
|
||||
if runtime.GOOS == "windows" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection.
|
||||
// Try enabling receiving of ECN and packet info for both IP versions.
|
||||
// We expect at least one of those syscalls to succeed.
|
||||
err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
|
||||
err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
|
||||
if err6 != nil && err4 != nil {
|
||||
return err4
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
1
testdata/dnscrypt-cert.opendns.txt
vendored
Normal file
1
testdata/dnscrypt-cert.opendns.txt
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
DNSC\000\001\000\000\200\226E:H\156\203%\134\218\127]\168\239\027u\011$\191\008\239\176F\133\017\171\161\219\154\142i\164\010\239\017f\168dS\210f\197\194\169\171w\2499\1891\155<\130\218@/\155\023v\153#d\024\004\136\180\228K5\233d\180\144\189\218\186\232%\162K\004\021\160\139\225\157}\219\135\163<\215~\223\142/qc78aWoo]\221\184`]\221\184`_\190\235\224
|
83
udp_unix.go
83
udp_unix.go
|
@ -1,83 +0,0 @@
|
|||
// +build aix darwin dragonfly linux netbsd openbsd solaris freebsd
|
||||
|
||||
package dnscrypt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
// udpGetOOBSize - get max. size of received OOB data
|
||||
// It will then be used in the ReadMsgUDP function
|
||||
func udpGetOOBSize() int {
|
||||
oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)
|
||||
oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)
|
||||
|
||||
if len(oob4) > len(oob6) {
|
||||
return len(oob4)
|
||||
}
|
||||
return len(oob6)
|
||||
}
|
||||
|
||||
// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data
|
||||
func udpSetOptions(c *net.UDPConn) error {
|
||||
err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true)
|
||||
err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true)
|
||||
if err6 != nil && err4 != nil {
|
||||
return fmt.Errorf("failed to call SetControlMessage: ipv4: %v ipv6: %v", err4, err6)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// udpRead - receive payload and OOB data from the UDP socket
|
||||
func udpRead(c *net.UDPConn, buf []byte, udpOOBSize int) (int, net.IP, *net.UDPAddr, error) {
|
||||
var oobn int
|
||||
oob := make([]byte, udpOOBSize)
|
||||
var err error
|
||||
var n int
|
||||
var remoteAddr *net.UDPAddr
|
||||
n, oobn, _, remoteAddr, err = c.ReadMsgUDP(buf, oob)
|
||||
if err != nil {
|
||||
return -1, nil, nil, err
|
||||
}
|
||||
|
||||
localIP := udpGetDstFromOOB(oob[:oobn])
|
||||
return n, localIP, remoteAddr, nil
|
||||
}
|
||||
|
||||
// udpWrite - writes to the UDP socket and sets local IP to OOB data
|
||||
func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, localIP net.IP) (int, error) {
|
||||
n, _, err := conn.WriteMsgUDP(bytes, udpMakeOOBWithSrc(localIP), remoteAddr)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// udpGetDstFromOOB - get destination IP from OOB data
|
||||
func udpGetDstFromOOB(oob []byte) net.IP {
|
||||
cm6 := &ipv6.ControlMessage{}
|
||||
if cm6.Parse(oob) == nil && cm6.Dst != nil {
|
||||
return cm6.Dst
|
||||
}
|
||||
|
||||
cm4 := &ipv4.ControlMessage{}
|
||||
if cm4.Parse(oob) == nil && cm4.Dst != nil {
|
||||
return cm4.Dst
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// udpMakeOOBWithSrc - make OOB data with a specified source IP
|
||||
func udpMakeOOBWithSrc(ip net.IP) []byte {
|
||||
if ip.To4() == nil {
|
||||
cm := &ipv6.ControlMessage{}
|
||||
cm.Src = ip
|
||||
return cm.Marshal()
|
||||
}
|
||||
|
||||
cm := &ipv4.ControlMessage{}
|
||||
cm.Src = ip
|
||||
return cm.Marshal()
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
package dnscrypt
|
||||
|
||||
import "net"
|
||||
|
||||
// udpGetOOBSize - get max. size of received OOB data
|
||||
// Does nothing on Windows
|
||||
func udpGetOOBSize() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data
|
||||
// Does nothing on Windows
|
||||
func udpSetOptions(c *net.UDPConn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// udpRead - receive payload from the UDP socket
|
||||
func udpRead(c *net.UDPConn, buf []byte, _ int) (int, net.IP, *net.UDPAddr, error) {
|
||||
n, addr, err := c.ReadFrom(buf)
|
||||
var udpAddr *net.UDPAddr
|
||||
if addr != nil {
|
||||
udpAddr = addr.(*net.UDPAddr)
|
||||
}
|
||||
return n, nil, udpAddr, err
|
||||
}
|
||||
|
||||
// udpWrite - writes to the UDP socket
|
||||
func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ net.IP) (int, error) {
|
||||
return conn.WriteTo(bytes, remoteAddr)
|
||||
}
|
Loading…
Reference in a new issue