From 5ddb58f7036d52917b070e9b1679b04921519ae5 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Fri, 19 Mar 2021 15:42:48 +0300 Subject: [PATCH] Graceful shutdown of the DNSCrypt server (#6) Graceful shutdown of the DNSCrypt server This PR implements Server.Shutdown(ctx context.Context) method that allows to shut down the DNSCrypt server gracefully. Some additional changes that were inadvertently made while doing that: 1. Added benchmark tests 2. Started using dns.ReadFromSessionUDP / dns.WriteToSessionUDP instead of implementing it by ourselves 3. Generally improved tests 4. Added depguard 5. Improved comments overall in the code --- .golangci.yml | 7 + cert.go | 30 ++-- cert_test.go | 16 +- client.go | 12 +- client_test.go | 19 ++- cmd/convert_dnscrypt_wrapper.go | 8 +- cmd/generate.go | 4 +- cmd/lookup.go | 10 +- cmd/main.go | 1 - cmd/server.go | 11 +- constants.go | 78 +++++----- encrypted_query.go | 12 +- encrypted_query_test.go | 6 +- encrypted_response.go | 8 +- encrypted_response_test.go | 8 +- generate.go | 27 ++-- generate_test.go | 8 +- go.mod | 2 +- go.sum | 2 + handler.go | 6 +- server.go | 174 ++++++++++++++++++++-- server_bench_test.go | 60 ++++++++ server_tcp.go | 201 ++++++++++++++++++------- server_test.go | 74 +++++++--- server_udp.go | 230 ++++++++++++++++++++--------- testdata/dnscrypt-cert.opendns.txt | 1 + udp_unix.go | 83 ----------- udp_windows.go | 30 ---- 28 files changed, 744 insertions(+), 384 deletions(-) create mode 100644 server_bench_test.go create mode 100644 testdata/dnscrypt-cert.opendns.txt delete mode 100644 udp_unix.go delete mode 100644 udp_windows.go diff --git a/.golangci.yml b/.golangci.yml index 0dc25f5..bb43eb9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,6 +20,12 @@ linters-settings: min-complexity: 20 lll: line-length: 200 + depguard: + list-type: blacklist + include-go-root: false + packages: + - golang.org/x/net/context # we use context + - log # we use github.com/AdguardTeam/golibs/log linters: enable: @@ -41,6 +47,7 @@ linters: - misspell - stylecheck - unconvert + - depguard disable-all: true fast: true diff --git a/cert.go b/cert.go index 649cc22..d19c391 100644 --- a/cert.go +++ b/cert.go @@ -8,10 +8,10 @@ import ( "time" ) -// Cert - DNSCrypt server certificate +// Cert is a DNSCrypt server certificate // See ResolverConfig for more info on how to create one type Cert struct { - // Serial - a 4 byte serial number in big-endian format. If more than + // Serial is a 4 byte serial number in big-endian format. If more than // one certificates are valid, the client must prefer the certificate // with a higher serial number. Serial uint32 @@ -22,36 +22,36 @@ type Cert struct { // For X25519-XChacha20Poly1305, must be 0x00 0x02. EsVersion CryptoConstruction - // Signature - a 64-byte signature of ( + // Signature is a 64-byte signature of ( // ) using the Ed25519 algorithm and the // provider secret key. Ed25519 must be used in this version of the // protocol. Signature [ed25519.SignatureSize]byte - // ResolverPk - the resolver's short-term public key, which is 32 bytes when using X25519. + // ResolverPk is the resolver's short-term public key, which is 32 bytes when using X25519. // This key is used to encrypt/decrypt DNS queries ResolverPk [keySize]byte - // ResolverSk - the resolver's short-term private key, which is 32 bytes when using X25519. + // ResolverSk is the resolver's short-term private key, which is 32 bytes when using X25519. // Note that it's only used in the server implementation and never serialized/deserialized. // This key is used to encrypt/decrypt DNS queries ResolverSk [keySize]byte - // ClientMagic - the first 8 bytes of a client query that is to be built + // ClientMagic is the first 8 bytes of a client query that is to be built // using the information from this certificate. It may be a truncated // public key. Two valid certificates cannot share the same . ClientMagic [clientMagicSize]byte - // NotAfter - the date the certificate is valid from, as a big-endian + // NotAfter is the date the certificate is valid from, as a big-endian // 4-byte unsigned Unix timestamp. NotBefore uint32 - // NotAfter - the date the certificate is valid until (inclusive), as a + // NotAfter is the date the certificate is valid until (inclusive), as a // big-endian 4-byte unsigned Unix timestamp. NotAfter uint32 } -// Serialize - serializes the cert to bytes +// Serialize serializes the cert to bytes // ::= // // @@ -86,7 +86,7 @@ func (c *Cert) Serialize() ([]byte, error) { return b, nil } -// Deserialize - deserializes certificate from a byte array +// Deserialize deserializes certificate from a byte array // ::= // // @@ -127,7 +127,7 @@ func (c *Cert) Deserialize(b []byte) error { return nil } -// VerifyDate - checks that cert is valid at this moment +// VerifyDate checks that the cert is valid at this moment func (c *Cert) VerifyDate() bool { if c.NotBefore >= c.NotAfter { return false @@ -139,14 +139,14 @@ func (c *Cert) VerifyDate() bool { return true } -// VerifySignature - checks if the cert is properly signed with the specified signature +// VerifySignature checks if the cert is properly signed with the specified signature func (c *Cert) VerifySignature(publicKey ed25519.PublicKey) bool { b := make([]byte, 52) c.writeSigned(b) return ed25519.Verify(publicKey, b, c.Signature[:]) } -// Sign - creates cert.Signature +// Sign creates cert.Signature func (c *Cert) Sign(privateKey ed25519.PrivateKey) { b := make([]byte, 52) c.writeSigned(b) @@ -154,14 +154,14 @@ func (c *Cert) Sign(privateKey ed25519.PrivateKey) { copy(c.Signature[:64], signature[:64]) } -// String - Cert's string representation +// String Cert's string representation func (c *Cert) String() string { return fmt.Sprintf("Certificate Serial=%d NotBefore=%s NotAfter=%s EsVersion=%s", c.Serial, time.Unix(int64(c.NotBefore), 0).String(), time.Unix(int64(c.NotAfter), 0).String(), c.EsVersion.String()) } -// writeSigned - writes ( ) +// writeSigned writes ( ) func (c *Cert) writeSigned(dst []byte) { // copy(dst[:32], c.ResolverPk[:keySize]) diff --git a/cert_test.go b/cert_test.go index edbc6f7..4b1d5e1 100644 --- a/cert_test.go +++ b/cert_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/ed25519" "crypto/rand" + "io/ioutil" "testing" "time" @@ -21,13 +22,13 @@ func TestCertSerialize(t *testing.T) { // serialize b, err := cert.Serialize() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 124, len(b)) // check that we can deserialize it cert2 := Cert{} err = cert2.Deserialize(b) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, cert.Serial, cert2.Serial) assert.Equal(t, cert.NotBefore, cert2.NotBefore) assert.Equal(t, cert.NotAfter, cert2.NotAfter) @@ -39,12 +40,15 @@ func TestCertSerialize(t *testing.T) { func TestCertDeserialize(t *testing.T) { // dig -t txt 2.dnscrypt-cert.opendns.com. -p 443 @208.67.220.220 - b, err := unpackTxtString("DNSC\\000\\001\\000\\000\\200\\226E:H\\156\\203%\\134\\218\\127]\\168\\239\\027u\\011$\\191\\008\\239\\176F\\133\\017\\171\\161\\219\\154\\142i\\164\\010\\239\\017f\\168dS\\210f\\197\\194\\169\\171w\\2499\\1891\\155<\\130\\218@/\\155\\023v\\153#d\\024\\004\\136\\180\\228K5\\233d\\180\\144\\189\\218\\186\\232%\\162K\\004\\021\\160\\139\\225\\157}\\219\\135\\163<\\215~\\223\\142/qc78aWoo]\\221\\184`]\\221\\184`_\\190\\235\\224") - assert.Nil(t, err) + certBytes, err := ioutil.ReadFile("testdata/dnscrypt-cert.opendns.txt") + assert.NoError(t, err) + + b, err := unpackTxtString(string(certBytes)) + assert.NoError(t, err) cert := &Cert{} err = cert.Deserialize(b) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint32(1574811744), cert.Serial) assert.Equal(t, XSalsa20Poly1305, cert.EsVersion) assert.Equal(t, uint32(1574811744), cert.NotBefore) @@ -69,7 +73,7 @@ func generateValidCert(t *testing.T) (*Cert, ed25519.PublicKey, ed25519.PrivateK // generate private key publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) - assert.Nil(t, err) + assert.NoError(t, err) // sign the data cert.Sign(privateKey) diff --git a/client.go b/client.go index 9ca3925..ebf1574 100644 --- a/client.go +++ b/client.go @@ -12,7 +12,7 @@ import ( "github.com/miekg/dns" ) -// Client - DNSCrypt resolver client +// Client is a DNSCrypt resolver client type Client struct { Net string // protocol (can be "udp" or "tcp", by default - "udp") Timeout time.Duration // read/write timeout @@ -124,7 +124,7 @@ func (c *Client) ExchangeConn(conn net.Conn, m *dns.Msg, resolverInfo *ResolverI return res, nil } -// writeQuery - writes query to the network connection +// writeQuery writes query to the network connection // depending on the protocol we may write a 2-byte prefix or not func (c *Client) writeQuery(conn net.Conn, query []byte) error { var err error @@ -145,7 +145,7 @@ func (c *Client) writeQuery(conn net.Conn, query []byte) error { return err } -// readResponse - reads response from the network connection +// readResponse reads response from the network connection // depending on the protocol, we may read a 2-byte prefix or not func (c *Client) readResponse(conn net.Conn) ([]byte, error) { if c.Timeout > 0 { @@ -171,7 +171,7 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) { return readPrefixed(conn) } -// encrypt - encrypts a DNS message using shared key from the resolver info +// encrypt encrypts a DNS message using shared key from the resolver info func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) { q := EncryptedQuery{ EsVersion: resolverInfo.ResolverCert.EsVersion, @@ -185,7 +185,7 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) return q.Encrypt(query, resolverInfo.SharedKey) } -// decrypts - decrypts a DNS message using shared key from the resolver info +// decrypts decrypts a DNS message using a shared key from the resolver info func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) { dr := EncryptedResponse{ EsVersion: resolverInfo.ResolverCert.EsVersion, @@ -203,7 +203,7 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) return res, nil } -// fetchCert - loads DNSCrypt cert from the specified server +// fetchCert loads DNSCrypt cert from the specified server func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) { providerName := stamp.ProviderName if !strings.HasSuffix(providerName, ".") { diff --git a/client_test.go b/client_test.go index e5f294a..6bbd4bd 100644 --- a/client_test.go +++ b/client_test.go @@ -63,7 +63,7 @@ func TestTimeoutOnDialExchange(t *testing.T) { client := Client{Timeout: 300 * time.Millisecond} serverInfo, err := client.Dial(stampStr) - assert.Nil(t, err) + assert.NoError(t, err) // Point it to an IP where there's no DNSCrypt server serverInfo.ServerAddress = "8.8.8.8:5443" @@ -105,12 +105,15 @@ func TestFetchCertPublicResolvers(t *testing.T) { for _, test := range stamps { stamp, err := dnsstamps.NewServerStampFromString(test.stampStr) - assert.Nil(t, err) + assert.NoError(t, err) t.Run(stamp.ProviderName, func(t *testing.T) { - c := &Client{Net: "udp"} + c := &Client{ + Net: "udp", + Timeout: time.Second * 5, + } resolverInfo, err := c.DialStamp(stamp) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, resolverInfo) assert.True(t, resolverInfo.ResolverCert.VerifyDate()) assert.True(t, resolverInfo.ResolverCert.VerifySignature(stamp.ServerPk)) @@ -146,7 +149,7 @@ func TestExchangePublicResolvers(t *testing.T) { for _, test := range stamps { stamp, err := dnsstamps.NewServerStampFromString(test.stampStr) - assert.Nil(t, err) + assert.NoError(t, err) t.Run(stamp.ProviderName, func(t *testing.T) { checkDNSCryptServer(t, test.stampStr, "udp") @@ -158,12 +161,12 @@ func TestExchangePublicResolvers(t *testing.T) { func checkDNSCryptServer(t *testing.T, stampStr string, network string) { client := Client{Net: network, Timeout: 10 * time.Second} resolverInfo, err := client.Dial(stampStr) - assert.Nil(t, err) + assert.NoError(t, err) req := createTestMessage() reply, err := client.Exchange(req, resolverInfo) - assert.Nil(t, err) + assert.NoError(t, err) assertTestMessageResponse(t, reply) } @@ -177,7 +180,7 @@ func createTestMessage() *dns.Msg { return &req } -func assertTestMessageResponse(t *testing.T, reply *dns.Msg) { +func assertTestMessageResponse(t assert.TestingT, reply *dns.Msg) { assert.NotNil(t, reply) assert.Equal(t, 1, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) diff --git a/cmd/convert_dnscrypt_wrapper.go b/cmd/convert_dnscrypt_wrapper.go index fc889ac..e3d5e88 100644 --- a/cmd/convert_dnscrypt_wrapper.go +++ b/cmd/convert_dnscrypt_wrapper.go @@ -12,7 +12,7 @@ import ( "gopkg.in/yaml.v3" ) -// ConvertWrapperArgs - "convert-dnscrypt-wrapper" command arguments +// ConvertWrapperArgs is the "convert-dnscrypt-wrapper" command arguments structure type ConvertWrapperArgs struct { PrivateKeyFile string `short:"p" long:"private-key" description:"Path to the DNSCrypt resolver private key file that is used for signing certificates. Param is required." required:"true"` ResolverSkFile string `short:"r" long:"resolver-secret" description:"Path to the Short-term privacy key file for encrypting/decrypting DNS queries. If not specified, resolver_secret and resolver_public will be randomly generated."` @@ -21,7 +21,7 @@ type ConvertWrapperArgs struct { CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"` } -// convertWrapper - generates DNSCrypt configuration from both dnscrypt and server private keys +// convertWrapper generates DNSCrypt configuration from both dnscrypt and server private keys func convertWrapper(args ConvertWrapperArgs) { log.Info("Generating configuration for %s", args.ProviderName) @@ -71,7 +71,7 @@ func convertWrapper(args ConvertWrapperArgs) { } } -// validateRc - verifies that the certificate is correctly +// validateRc verifies that the certificate is correctly // created and validated for this resolver config. if rc valid returns nil. func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error { cert, err := rc.CreateCert() @@ -90,7 +90,7 @@ func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error { return nil } -// getResolverPk - calculates public key from private key +// getResolverPk calculates public key from private key func getResolverPk(private ed25519.PrivateKey) ed25519.PublicKey { resolverSk := [32]byte{} resolverPk := [32]byte{} diff --git a/cmd/generate.go b/cmd/generate.go index 6d93b4d..6f8d914 100644 --- a/cmd/generate.go +++ b/cmd/generate.go @@ -8,7 +8,7 @@ import ( "gopkg.in/yaml.v3" ) -// GenerateArgs - "generate" command arguments +// GenerateArgs is the "generate" command arguments structure type GenerateArgs struct { ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt provider name. Param is required." required:"true"` Out string `short:"o" long:"out" description:"Path to the resulting config file. Param is required." required:"true"` @@ -16,7 +16,7 @@ type GenerateArgs struct { CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"` } -// generate - generates DNSCrypt server configuration +// generate generates a DNSCrypt server configuration func generate(args GenerateArgs) { log.Info("Generating configuration for %s", args.ProviderName) diff --git a/cmd/lookup.go b/cmd/lookup.go index 0e873d4..e5d31c8 100644 --- a/cmd/lookup.go +++ b/cmd/lookup.go @@ -12,7 +12,7 @@ import ( "github.com/miekg/dns" ) -// LookupStampArgs - "lookup-stamp" command arguments +// LookupStampArgs is the "lookup-stamp" command arguments structure type LookupStampArgs struct { Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"` Stamp string `short:"s" long:"stamp" description:"DNSCrypt resolver stamp. Param is required." required:"true"` @@ -20,7 +20,7 @@ type LookupStampArgs struct { Type string `short:"t" long:"type" description:"DNS query type" default:"A"` } -// LookupArgs - "lookup" command arguments +// LookupArgs is the "lookup" command arguments structure type LookupArgs struct { Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"` ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt resolver provider name. Param is required." required:"true"` @@ -30,7 +30,7 @@ type LookupArgs struct { Type string `short:"t" long:"type" description:"DNS query type" default:"A"` } -// LookupResult - lookup result that contains the cert info and the query response +// LookupResult is the lookup result that contains the cert info and the query response type LookupResult struct { Certificate struct { Serial uint32 `json:"serial"` @@ -42,7 +42,7 @@ type LookupResult struct { Reply *dns.Msg `json:"reply"` } -// lookup - performs a DNS lookup, prints DNSCrypt info and lookup results +// lookup performs a DNS lookup, prints DNSCrypt info and lookup results func lookup(args LookupArgs) { serverPk, err := dnscrypt.HexDecodeKey(args.PublicKey) if err != nil { @@ -64,7 +64,7 @@ func lookup(args LookupArgs) { }) } -// lookupStamp - performs a DNS lookup, prints DNSCrypt cert info and lookup results +// lookupStamp performs a DNS lookup, prints DNSCrypt cert info and lookup results func lookupStamp(args LookupStampArgs) { c := &dnscrypt.Client{ Net: args.Network, diff --git a/cmd/main.go b/cmd/main.go index 077b484..28e2a2d 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -4,7 +4,6 @@ import ( "os" "github.com/AdguardTeam/golibs/log" - goFlags "github.com/jessevdk/go-flags" ) diff --git a/cmd/server.go b/cmd/server.go index a10609d..46295ed 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -13,7 +13,7 @@ import ( "gopkg.in/yaml.v3" ) -// ServerArgs - "server" command arguments +// ServerArgs is the "server" command arguments type ServerArgs struct { Config string `short:"c" long:"config" description:"Path to the DNSCrypt configuration file. Param is required." required:"true"` Forward string `short:"f" long:"forward" description:"Forwards DNS queries to the specified address" default:"94.140.14.140:53"` @@ -21,7 +21,7 @@ type ServerArgs struct { ListenPorts []int `short:"p" long:"port" description:"Listening ports" default:"443"` } -// server - runs a DNSCrypt server +// server runs a DNSCrypt server func server(args ServerArgs) { log.Info("Starting DNSCrypt server") @@ -72,7 +72,7 @@ func server(args ServerArgs) { } } -// createListeners - creates listeners for our server +// createListeners creates listeners for our server func createListeners(args ServerArgs) (tcp []net.Listener, udp []*net.UDPConn) { for _, addr := range args.ListenAddrs { ip := net.ParseIP(addr) @@ -101,7 +101,10 @@ type forwardHandler struct { addr string } -// ServeDNS - implements Handler interface +// type check +var _ dnscrypt.Handler = &forwardHandler{} + +// ServeDNS implements Handler interface func (f *forwardHandler) ServeDNS(rw dnscrypt.ResponseWriter, r *dns.Msg) error { res, err := dns.Exchange(r, f.addr) if err != nil { diff --git a/constants.go b/constants.go index 81a2cf2..2411d78 100644 --- a/constants.go +++ b/constants.go @@ -1,52 +1,58 @@ package dnscrypt -import "errors" +// Error represents a dnscrypt error. +type Error string -var ( - // ErrTooShort - DNS query is shorter than possible - ErrTooShort = errors.New("DNSCrypt message is too short") +func (e Error) Error() string { return "dnscrypt: " + string(e) } - // ErrQueryTooLarge - DNS query is larger than max allowed size - ErrQueryTooLarge = errors.New("DNSCrypt query is too large") +const ( + // ErrTooShort means that the DNS query is shorter than possible + ErrTooShort = Error("message is too short") - // ErrEsVersion - cert contains unsupported es-version - ErrEsVersion = errors.New("unsupported es-version") + // ErrQueryTooLarge means that the DNS query is larger than max allowed size + ErrQueryTooLarge = Error("DNSCrypt query is too large") - // ErrInvalidDate - cert is not valid for the current time - ErrInvalidDate = errors.New("cert has invalid ts-start or ts-end") + // ErrEsVersion means that the cert contains unsupported es-version + ErrEsVersion = Error("unsupported es-version") - // ErrInvalidCertSignature - cert has invalid signature - ErrInvalidCertSignature = errors.New("cert has invalid signature") + // ErrInvalidDate means that the cert is not valid for the current time + ErrInvalidDate = Error("cert has invalid ts-start or ts-end") - // ErrInvalidQuery - failed to decrypt a DNSCrypt query - ErrInvalidQuery = errors.New("DNSCrypt query is invalid and cannot be decrypted") + // ErrInvalidCertSignature means that the cert has invalid signature + ErrInvalidCertSignature = Error("cert has invalid signature") - // ErrInvalidClientMagic - client-magic does not match - ErrInvalidClientMagic = errors.New("DNSCrypt query contains invalid client magic") + // ErrInvalidQuery means that it failed to decrypt a DNSCrypt query + ErrInvalidQuery = Error("DNSCrypt query is invalid and cannot be decrypted") - // ErrInvalidResolverMagic - server-magic does not match - ErrInvalidResolverMagic = errors.New("DNSCrypt response contains invalid resolver magic") + // ErrInvalidClientMagic means that client-magic does not match + ErrInvalidClientMagic = Error("DNSCrypt query contains invalid client magic") - // ErrInvalidResponse - failed to decrypt a DNSCrypt response - ErrInvalidResponse = errors.New("DNSCrypt response is invalid and cannot be decrypted") + // ErrInvalidResolverMagic means that server-magic does not match + ErrInvalidResolverMagic = Error("DNSCrypt response contains invalid resolver magic") - // ErrInvalidPadding - failed to unpad a query - ErrInvalidPadding = errors.New("invalid padding") + // ErrInvalidResponse means that it failed to decrypt a DNSCrypt response + ErrInvalidResponse = Error("DNSCrypt response is invalid and cannot be decrypted") - // ErrInvalidDNSStamp - invalid DNS stamp - ErrInvalidDNSStamp = errors.New("invalid DNS stamp") + // ErrInvalidPadding means that it failed to unpad a query + ErrInvalidPadding = Error("invalid padding") - // ErrFailedToFetchCert - failed to fetch DNSCrypt certificate - ErrFailedToFetchCert = errors.New("failed to fetch DNSCrypt certificate") + // ErrInvalidDNSStamp means an invalid DNS stamp + ErrInvalidDNSStamp = Error("invalid DNS stamp") - // ErrCertTooShort - failed to deserialize cert, too short - ErrCertTooShort = errors.New("cert is too short") + // ErrFailedToFetchCert means that it failed to fetch DNSCrypt certificate + ErrFailedToFetchCert = Error("failed to fetch DNSCrypt certificate") - // ErrCertMagic - invalid cert magic - ErrCertMagic = errors.New("invalid cert magic") + // ErrCertTooShort means that it failed to deserialize cert, too short + ErrCertTooShort = Error("cert is too short") - // ErrServerConfig - failed to start the DNSCrypt server - invalid configuration - ErrServerConfig = errors.New("invalid server configuration") + // ErrCertMagic means an invalid cert magic + ErrCertMagic = Error("invalid cert magic") + + // ErrServerConfig means that it failed to start the DNSCrypt server - invalid configuration + ErrServerConfig = Error("invalid server configuration") + + // ErrServerNotStarted is returned if there's nothing to shutdown + ErrServerNotStarted = Error("server is not started") ) const ( @@ -55,7 +61,7 @@ const ( // Some servers do not work if padded length is less than 256. Example: Quad9 minUDPQuestionSize = 256 - // - maximum allowed query length + // is the maximum allowed query length maxQueryLen = 1252 // Minimum possible DNS packet size @@ -68,7 +74,7 @@ const ( // size of the shared key used to encrypt/decrypt messages sharedKeySize = 32 - // ClientMagic - the first 8 bytes of a client query that is to be built + // ClientMagic is the first 8 bytes of a client query that is to be built // using the information from this certificate. It may be a truncated // public key. Two valid certificates cannot share the same . clientMagicSize = 8 @@ -82,10 +88,10 @@ const ( ) var ( - // certMagic - bytes sequence that must be in the beginning of the serialized cert + // certMagic is a bytes sequence that must be in the beginning of the serialized cert certMagic = [4]byte{0x44, 0x4e, 0x53, 0x43} - // resolverMagic - byte sequence that must be in the beginning of every response + // resolverMagic is a byte sequence that must be in the beginning of every response resolverMagic = []byte{0x72, 0x36, 0x66, 0x6e, 0x76, 0x57, 0x6a, 0x38} ) diff --git a/encrypted_query.go b/encrypted_query.go index 545991e..f64ba8f 100644 --- a/encrypted_query.go +++ b/encrypted_query.go @@ -10,19 +10,19 @@ import ( "golang.org/x/crypto/nacl/secretbox" ) -// EncryptedQuery - a structure for encrypting and decrypting client queries +// EncryptedQuery is a structure for encrypting and decrypting client queries // // ::= // ::= AE( , ) type EncryptedQuery struct { - // EsVersion - encryption to use + // EsVersion is the encryption to use EsVersion CryptoConstruction - // ClientMagic - a 8 byte identifier for the resolver certificate + // ClientMagic is a 8 byte identifier for the resolver certificate // chosen by the client. ClientMagic [clientMagicSize]byte - // ClientPk - the client's public key + // ClientPk is the client's public key ClientPk [keySize]byte // With a 24 bytes nonce, a question sent by a DNSCrypt client must be @@ -36,7 +36,7 @@ type EncryptedQuery struct { Nonce [nonceSize]byte } -// Encrypt - encrypts the specified DNS query, returns encrypted data ready to be sent. +// Encrypt encrypts the specified DNS query, returns encrypted data ready to be sent. // // Note that this method will generate a random nonce automatically. // @@ -79,7 +79,7 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) ( return query, nil } -// Decrypt - decrypts the client query, returns decrypted DNS packet. +// Decrypt decrypts the client query, returns decrypted DNS packet. // // Please note, that before calling this method the following fields must be set: // * ClientMagic -- to verify the query diff --git a/encrypted_query_test.go b/encrypted_query_test.go index d70643f..cffb75c 100644 --- a/encrypted_query_test.go +++ b/encrypted_query_test.go @@ -23,7 +23,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction) // Generate client shared key clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey) - assert.Nil(t, err) + assert.NoError(t, err) clientMagic := [clientMagicSize]byte{} _, _ = rand.Read(clientMagic[:]) @@ -40,7 +40,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction) // Encrypt it encrypted, err := q1.Encrypt(packet, clientSharedKey) - assert.Nil(t, err) + assert.NoError(t, err) // Now let's try decrypting it q2 := EncryptedQuery{ @@ -50,7 +50,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction) // Decrypt it decrypted, err := q2.Decrypt(encrypted, serverSecretKey) - assert.Nil(t, err) + assert.NoError(t, err) // Check that packet is the same assert.True(t, bytes.Equal(packet, decrypted)) diff --git a/encrypted_response.go b/encrypted_response.go index 30d8405..212cb87 100644 --- a/encrypted_response.go +++ b/encrypted_response.go @@ -10,12 +10,12 @@ import ( "golang.org/x/crypto/nacl/secretbox" ) -// EncryptedResponse - structure for encrypting/decrypting server responses +// EncryptedResponse is a structure for encrypting/decrypting server responses // // ::= // ::= AE(, , ) type EncryptedResponse struct { - // EsVersion - encryption to use + // EsVersion is the encryption to use EsVersion CryptoConstruction // Nonce - ::= @@ -23,7 +23,7 @@ type EncryptedResponse struct { Nonce [nonceSize]byte } -// Encrypt - encrypts the server response +// Encrypt encrypts the server response // // EsVersion must be set. // Nonce needs to be set to "client-nonce". @@ -57,7 +57,7 @@ func (r *EncryptedResponse) Encrypt(packet []byte, sharedKey [sharedKeySize]byte return response, nil } -// Decrypt - decrypts the server response +// Decrypt decrypts the server response // // EsVersion must be set. func (r *EncryptedResponse) Decrypt(response []byte, sharedKey [sharedKeySize]byte) ([]byte, error) { diff --git a/encrypted_response_test.go b/encrypted_response_test.go index 0de8dbe..6c755cc 100644 --- a/encrypted_response_test.go +++ b/encrypted_response_test.go @@ -24,11 +24,11 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi // Generate client shared key clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey) - assert.Nil(t, err) + assert.NoError(t, err) // Generate server shared key serverSharedKey, err := computeSharedKey(esVersion, &serverSecretKey, &clientPublicKey) - assert.Nil(t, err) + assert.NoError(t, err) r1 := &EncryptedResponse{ EsVersion: esVersion, @@ -42,7 +42,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi // Encrypt it encrypted, err := r1.Encrypt(packet, serverSharedKey) - assert.Nil(t, err) + assert.NoError(t, err) // Now let's try decrypting it r2 := &EncryptedResponse{ @@ -51,7 +51,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi // Decrypt it decrypted, err := r2.Decrypt(encrypted, clientSharedKey) - assert.Nil(t, err) + assert.NoError(t, err) // Check that packet is the same assert.True(t, bytes.Equal(packet, decrypted)) diff --git a/generate.go b/generate.go index 3891e31..430d955 100644 --- a/generate.go +++ b/generate.go @@ -14,36 +14,37 @@ import ( const dnsCryptV2Prefix = "2.dnscrypt-cert." -// ResolverConfig - DNSCrypt resolver configuration +// ResolverConfig is the DNSCrypt resolver configuration type ResolverConfig struct { // DNSCrypt provider name ProviderName string `yaml:"provider_name"` - // PublicKey - DNSCrypt resolver public key + // PublicKey is the DNSCrypt resolver public key PublicKey string `yaml:"public_key"` - // PrivateKey - DNSCrypt resolver private key + // PrivateKey is the DNSCrypt resolver private key // The main and only purpose of this key is to sign the certificate PrivateKey string `yaml:"private_key"` - // ResolverSk - hex-encoded short-term private key. + // ResolverSk is a hex-encoded short-term private key. // This key is used to encrypt/decrypt DNS queries. // If not set, we'll generate a new random ResolverSk and ResolverPk. ResolverSk string `yaml:"resolver_secret"` - // ResolverPk - hex-encoded short-term public key corresponding to ResolverSk. + // ResolverPk is a hex-encoded short-term public key corresponding to ResolverSk. // This key is used to encrypt/decrypt DNS queries. ResolverPk string `yaml:"resolver_public"` - // EsVersion - crypto to use in this resolver + // EsVersion is the crypto to use in this resolver EsVersion CryptoConstruction `yaml:"es_version"` - // CertificateTTL - time-to-live for the certificate that is generated using this ResolverConfig. + // CertificateTTL is the time-to-live value for the certificate that is + // generated using this ResolverConfig. // If not set, we'll use 1 year by default. CertificateTTL time.Duration `yaml:"certificate_ttl"` } -// CreateCert - generates a signed Cert to be used by Server +// CreateCert generates a signed Cert to be used by Server func (rc *ResolverConfig) CreateCert() (*Cert, error) { log.Printf("Creating signed DNSCrypt certificate") @@ -98,7 +99,7 @@ func (rc *ResolverConfig) CreateCert() (*Cert, error) { return cert, nil } -// CreateStamp - generates a DNS stamp for this resolver +// CreateStamp generates a DNS stamp for this resolver func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error) { stamp := dnsstamps.ServerStamp{ ProviderName: rc.ProviderName, @@ -115,7 +116,7 @@ func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error return stamp, nil } -// GenerateResolverConfig - generates resolver configuration for a given provider name. +// GenerateResolverConfig generates resolver configuration for a given provider name. // providerName is mandatory. If needed, "2.dnscrypt-cert." prefix is added to it. // privateKey is optional. If not set, it will be generated automatically. func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey) (ResolverConfig, error) { @@ -145,18 +146,18 @@ func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey) return rc, nil } -// HexEncodeKey - encodes a byte slice to a hex-encoded string. +// HexEncodeKey encodes a byte slice to a hex-encoded string. func HexEncodeKey(b []byte) string { return strings.ToUpper(hex.EncodeToString(b)) } -// HexDecodeKey - decodes a hex-encoded string with (optional) colons +// HexDecodeKey decodes a hex-encoded string with (optional) colons // to a byte array. func HexDecodeKey(str string) ([]byte, error) { return hex.DecodeString(strings.ReplaceAll(str, ":", "")) } -// generateRandomKeyPair - generates a random key-pair +// generateRandomKeyPair generates a random key-pair func generateRandomKeyPair() (privateKey [keySize]byte, publicKey [keySize]byte) { privateKey = [keySize]byte{} publicKey = [keySize]byte{} diff --git a/generate_test.go b/generate_test.go index 6b46079..a6099d2 100644 --- a/generate_test.go +++ b/generate_test.go @@ -15,24 +15,24 @@ func TestHexEncodeKey(t *testing.T) { func TestHexDecodeKey(t *testing.T) { b, err := HexDecodeKey("01:02:03:04") - assert.Nil(t, err) + assert.NoError(t, err) assert.True(t, bytes.Equal(b, []byte{1, 2, 3, 4})) } func TestGenerateResolverConfig(t *testing.T) { rc, err := GenerateResolverConfig("example.org", nil) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "2.dnscrypt-cert.example.org", rc.ProviderName) assert.Equal(t, ed25519.PrivateKeySize*2, len(rc.PrivateKey)) assert.Equal(t, keySize*2, len(rc.ResolverSk)) assert.Equal(t, keySize*2, len(rc.ResolverPk)) cert, err := rc.CreateCert() - assert.Nil(t, err) + assert.NoError(t, err) assert.True(t, cert.VerifyDate()) publicKey, err := HexDecodeKey(rc.PublicKey) - assert.Nil(t, err) + assert.NoError(t, err) assert.True(t, cert.VerifySignature(publicKey)) } diff --git a/go.mod b/go.mod index 9a29d25..1c90381 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 github.com/ameshkov/dnsstamps v1.0.1 github.com/jessevdk/go-flags v1.4.0 - github.com/miekg/dns v1.1.29 + github.com/miekg/dns v1.1.40 github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e diff --git a/go.sum b/go.sum index c07da94..4dbbe1e 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGAR github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/miekg/dns v1.1.29 h1:xHBEhR+t5RzcFJjBLJlax2daXOrTYtr9z4WdKEfWFzg= github.com/miekg/dns v1.1.29/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= +github.com/miekg/dns v1.1.40 h1:pyyPFfGMnciYUk/mXpKkVmeMQjfXqt3FAJ2hy7tPiLA= +github.com/miekg/dns v1.1.40/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/handler.go b/handler.go index eb3430f..2943ae5 100644 --- a/handler.go +++ b/handler.go @@ -14,14 +14,14 @@ type Handler interface { ServeDNS(rw ResponseWriter, r *dns.Msg) error } -// ResponseWriter - interface that needs to be implemented for different protocols +// ResponseWriter is the interface that needs to be implemented for different protocols type ResponseWriter interface { LocalAddr() net.Addr // LocalAddr - local socket address RemoteAddr() net.Addr // RemoteAddr - remote client socket address WriteMsg(m *dns.Msg) error // WriteMsg - writes response message to the client } -// DefaultHandler - default Handler implementation +// DefaultHandler is the default Handler implementation // that is used by Server if custom handler is not configured var DefaultHandler Handler = &defaultHandler{ udpClient: &dns.Client{ @@ -41,7 +41,7 @@ type defaultHandler struct { addr string } -// ServeDNS - implements Handler interface +// ServeDNS implements Handler interface func (h *defaultHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error { // Google DNS res, _, err := h.udpClient.Exchange(r, h.addr) diff --git a/server.go b/server.go index 0e08557..d127242 100644 --- a/server.go +++ b/server.go @@ -1,27 +1,167 @@ package dnscrypt import ( + "context" + "net" + "sync" + "time" + "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) -// Server - a simple DNSCrypt server implementation +// default read timeout for all reads +const defaultReadTimeout = 2 * time.Second + +// in case of TCP we only use defaultReadTimeout for the first read +// then we start using defaultTCPIdleTimeout +const defaultTCPIdleTimeout = 8 * time.Second + +// helper struct that is used in several SetReadDeadline calls +var longTimeAgo = time.Unix(1, 0) + +// ServerDNSCrypt is an interface for a DNSCrypt server +type ServerDNSCrypt interface { + // ServeTCP listens to TCP connections, queries are then processed by Server.Handler. + // It blocks the calling goroutine and to stop it you need to close the listener + // or call ServerDNSCrypt.Shutdown. + ServeTCP(l net.Listener) error + + // ServeUDP listens to UDP connections, queries are then processed by Server.Handler. + // It blocks the calling goroutine and to stop it you need to close the listener + // or call ServerDNSCrypt.Shutdown. + ServeUDP(l *net.UDPConn) error + + // Shutdown tries to gracefully shutdown the server. It waits until all + // connections are processed and only after that it leaves the method. + // If context deadline is specified, it will exit earlier + // or call ServerDNSCrypt.Shutdown. + Shutdown(ctx context.Context) error +} + +// Server is a simple DNSCrypt server implementation type Server struct { - // ProviderName - DNSCrypt provider name + // ProviderName is a DNSCrypt provider name ProviderName string - // ResolverCert - contains resolver certificate. + // ResolverCert contains resolver certificate. ResolverCert *Cert // Handler to invoke. If nil, uses DefaultHandler. Handler Handler + + // make sure init is called only once + initOnce sync.Once + + // Shutdown handling + // -- + lock sync.RWMutex // protects access to all the fields below + started bool + wg sync.WaitGroup // active workers (servers) + tcpListeners map[net.Listener]struct{} // track active TCP listeners + udpListeners map[*net.UDPConn]struct{} // track active UDP listeners + tcpConns map[net.Conn]struct{} // track active connections } -// serveDNS - serves DNS response -func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) { +// type check +var _ ServerDNSCrypt = &Server{} + +// prepareShutdown - prepares the server to shutdown: +// unblocks reads from all connections related to this server +// marks the server as stopped +// if the server is not started, returns ErrServerNotStarted +func (s *Server) prepareShutdown() error { + s.lock.Lock() + defer s.lock.Unlock() + + if !s.started { + log.Info("Server is not started") + return ErrServerNotStarted + } + + s.started = false + + // These listeners were passed to us from the outside so we cannot close + // them here - this is up to the calling code to do that. Instead of that, + // we call Set(Read)Deadline to unblock goroutines that are currently + // blocked on reading from those listeners. + // For tcpConns we would like to avoid closing them to be able to process + // queries before shutting everything down. + + // Unblock reads for all active tcpConns + for conn := range s.tcpConns { + _ = conn.SetReadDeadline(longTimeAgo) + } + + // Unblock reads for all active TCP listeners + for l := range s.tcpListeners { + switch v := l.(type) { + case *net.TCPListener: + _ = v.SetDeadline(longTimeAgo) + } + } + + // Unblock reads for all active UDP listeners + for l := range s.udpListeners { + _ = l.SetReadDeadline(longTimeAgo) + } + + return nil +} + +// Shutdown tries to gracefully shutdown the server. It waits until all +// connections are processed and only after that it leaves the method. +// If context deadline is specified, it will exit earlier. +func (s *Server) Shutdown(ctx context.Context) error { + log.Info("Shutting down the DNSCrypt server") + + err := s.prepareShutdown() + if err != nil { + return err + } + + // Using this channel to wait until all goroutines finish their work + closed := make(chan struct{}) + go func() { + s.wg.Wait() + log.Info("Serve goroutines finished their work") + close(closed) + }() + + // Wait for either all goroutines finish their work + // Or for the context deadline + select { + case <-closed: + log.Info("DNSCrypt server has been stopped") + case <-ctx.Done(): + log.Info("DNSCrypt server shutdown has timed out") + err = ctx.Err() + } + + return err +} + +// init initializes (lazily) Server properties on startup +// this method is called from Server.ServeTCP and Server.ServeUDP +func (s *Server) init() { + s.tcpConns = map[net.Conn]struct{}{} + s.udpListeners = map[*net.UDPConn]struct{}{} + s.tcpListeners = map[net.Listener]struct{}{} +} + +// isStarted returns true if the server is processing queries right now +// it means that Server.ServeTCP and/or Server.ServeUDP have been called +func (s *Server) isStarted() bool { + s.lock.RLock() + started := s.started + s.lock.RUnlock() + return started +} + +// serveDNS serves DNS response +func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) error { if r == nil || len(r.Question) != 1 || r.Response { - log.Tracef("Invalid query: %v", r) - return + return ErrInvalidQuery } log.Tracef("Handling a DNS query: %s", r.Question[0].Name) @@ -39,9 +179,11 @@ func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) { reply.SetRcode(r, dns.RcodeServerFailure) _ = rw.WriteMsg(reply) } + + return nil } -// encrypt - encrypts DNSCrypt response +// encrypt encrypts DNSCrypt response func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) { r := EncryptedResponse{ EsVersion: q.EsVersion, @@ -60,7 +202,7 @@ func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) { return r.Encrypt(packet, sharedKey) } -// decrypt - decrypts the incoming message and returns a DNS message to process +// decrypt decrypts the incoming message and returns a DNS message to process func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) { q := EncryptedQuery{ EsVersion: s.ResolverCert.EsVersion, @@ -82,7 +224,7 @@ func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) { return r, q, nil } -// handleHandshake - handles a TXT request that requests certificate data +// handleHandshake handles a TXT request that requests certificate data func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) { m := new(dns.Msg) err := m.Unpack(b) @@ -120,7 +262,7 @@ func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) { return reply.Pack() } -// validate - checks if the Server config is properly set +// validate checks if the Server config is properly set func (s *Server) validate() bool { if s.ResolverCert == nil { log.Error("ResolverCert must be set") @@ -139,3 +281,13 @@ func (s *Server) validate() bool { return true } + +// getCertTXT serializes the cert TXT record that are to be sent to the client +func (s *Server) getCertTXT() (string, error) { + certBuf, err := s.ResolverCert.Serialize() + if err != nil { + return "", err + } + certTxt := packTxtString(certBuf) + return certTxt, nil +} diff --git a/server_bench_test.go b/server_bench_test.go new file mode 100644 index 0000000..e473276 --- /dev/null +++ b/server_bench_test.go @@ -0,0 +1,60 @@ +package dnscrypt + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/ameshkov/dnsstamps" + "github.com/stretchr/testify/require" +) + +func BenchmarkServeUDP(b *testing.B) { + benchmarkServe(b, "udp") +} + +func BenchmarkServeTCP(b *testing.B) { + benchmarkServe(b, "tcp") +} + +func benchmarkServe(b *testing.B, network string) { + srv := newTestServer(b, &testHandler{}) + b.Cleanup(func() { + err := srv.Close() + require.NoError(b, err) + }) + + client := &Client{ + Timeout: 1 * time.Second, + Net: network, + } + + serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port) + if network == "tcp" { + serverAddr = fmt.Sprintf("127.0.0.1:%d", srv.TCPAddr().Port) + } + + stamp := dnsstamps.ServerStamp{ + ServerAddrStr: serverAddr, + ServerPk: srv.resolverPk, + ProviderName: srv.server.ProviderName, + Proto: dnsstamps.StampProtoTypeDNSCrypt, + } + ri, err := client.DialStamp(stamp) + require.NoError(b, err) + require.NotNil(b, ri) + + conn, err := net.Dial(network, stamp.ServerAddrStr) + require.NoError(b, err) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m := createTestMessage() + res, err := client.ExchangeConn(conn, m, ri) + require.NoError(b, err) + assertTestMessageResponse(b, res) + } + b.StopTimer() +} diff --git a/server_tcp.go b/server_tcp.go index 4913def..84f9a8f 100644 --- a/server_tcp.go +++ b/server_tcp.go @@ -2,13 +2,17 @@ package dnscrypt import ( "bytes" + "errors" + "fmt" "net" + "sync" + "time" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) -// TCPResponseWriter - ResponseWriter implementation for TCP +// TCPResponseWriter is the ResponseWriter implementation for TCP type TCPResponseWriter struct { tcpConn net.Conn encrypt encryptionFunc @@ -19,17 +23,17 @@ type TCPResponseWriter struct { // type check var _ ResponseWriter = &TCPResponseWriter{} -// LocalAddr - server socket local address +// LocalAddr is the server socket local address func (w *TCPResponseWriter) LocalAddr() net.Addr { return w.tcpConn.LocalAddr() } -// RemoteAddr - client's address +// RemoteAddr is the client's address func (w *TCPResponseWriter) RemoteAddr() net.Addr { return w.tcpConn.RemoteAddr() } -// WriteMsg - writes DNS message to the client +// WriteMsg writes DNS message to the client func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error { m.Truncate(dnsSize("tcp", w.req)) @@ -42,33 +46,44 @@ func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error { return writePrefixed(res, w.tcpConn) } -// ServeTCP - listens to TCP connections, queries are then processed by Server.Handler. -// It blocks the calling goroutine and to stop it you need to close the listener. +// ServeTCP listens to TCP connections, queries are then processed by Server.Handler. +// It blocks the calling goroutine and to stop it you need to close the listener +// or call Server.Shutdown. func (s *Server) ServeTCP(l net.Listener) error { - // Check that server is properly configured - if !s.validate() { - return ErrServerConfig - } - - // Serialize the cert right away and prepare it to be sent to the client - certBuf, err := s.ResolverCert.Serialize() + err := s.prepareServeTCP(l) if err != nil { return err } - certTxt := packTxtString(certBuf) - log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr().String()) + log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr()) - for { + // Tracks TCP connection handling goroutines + tcpWg := &sync.WaitGroup{} + defer s.cleanUpTCP(tcpWg, l) + + // Track active goroutine + s.wg.Add(1) + + // Serialize the cert right away and prepare it to be sent to the client + certTxt, err := s.getCertTXT() + if err != nil { + return err + } + + for s.isStarted() { conn, err := l.Accept() - if err == nil { - go func() { - _ = s.handleTCPConnection(conn, certTxt) - _ = conn.Close() - }() - } + // Check the error code and exit loop if necessary if err != nil { + if !s.isStarted() { + // Stopped gracefully + break + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Temporary() { + // Note that timeout errors will be here (i.e. hitting ReadDeadline) + continue + } if isConnClosed(err) { log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop") } else { @@ -76,47 +91,129 @@ func (s *Server) ServeTCP(l net.Listener) error { } break } + + // If we got here, the connection is alive + s.lock.Lock() + // Track the connection to allow unblocking reads on shutdown. + s.tcpConns[conn] = struct{}{} + s.lock.Unlock() + + tcpWg.Add(1) + go func() { + // Ignore error here, it is most probably a legit one + // if not, it's written to the debug log + _ = s.handleTCPConnection(conn, certTxt) + + // Clean up + _ = conn.Close() + s.lock.Lock() + delete(s.tcpConns, conn) + s.lock.Unlock() + tcpWg.Done() + }() } return nil } +// prepareServeTCP prepares the server and listener to serving DNSCrypt +func (s *Server) prepareServeTCP(l net.Listener) error { + // Check that server is properly configured + if !s.validate() { + return ErrServerConfig + } + + // Protect shutdown-related fields + s.lock.Lock() + defer s.lock.Unlock() + s.initOnce.Do(s.init) + + // Mark the server as started if needed + s.started = true + + // Track an active TCP listener + s.tcpListeners[l] = struct{}{} + return nil +} + +// cleanUpTCP waits until all TCP messages before cleaning up +func (s *Server) cleanUpTCP(tcpWg *sync.WaitGroup, l net.Listener) { + // Wait until all TCP connections are processed + tcpWg.Wait() + + // Not using it anymore so can be removed from the active listeners + s.lock.Lock() + delete(s.tcpListeners, l) + s.lock.Unlock() + + // The work is finished + s.wg.Done() +} + +// handleTCPMsg handles a single TCP message. If this method returns error +// the connection will be closed +func (s *Server) handleTCPMsg(b []byte, conn net.Conn, certTxt string) error { + if len(b) < minDNSPacketSize { + // Ignore the packets that are too short + return ErrTooShort + } + + // First of all, check for "ClientMagic" in the incoming query + if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { + // If there's no ClientMagic in the packet, we assume this + // is a plain DNS query requesting the certificate data + reply, err := s.handleHandshake(b, certTxt) + if err != nil { + return fmt.Errorf("failed to process a plain DNS query: %w", err) + } + err = writePrefixed(reply, conn) + if err != nil { + return fmt.Errorf("failed to write a response: %w", err) + } + return nil + } + + // If we got here, this is an encrypted DNSCrypt message + // We should decrypt it first to get the plain DNS query + m, q, err := s.decrypt(b) + if err != nil { + return fmt.Errorf("failed to decrypt incoming message: %w", err) + } + rw := &TCPResponseWriter{ + tcpConn: conn, + encrypt: s.encrypt, + req: m, + query: q, + } + err = s.serveDNS(rw, m) + if err != nil { + return fmt.Errorf("failed to process a DNS query: %w", err) + } + + return nil +} + +// handleTCPConnection handles all queries that are coming to the +// specified TCP connection. func (s *Server) handleTCPConnection(conn net.Conn, certTxt string) error { - for { + timeout := defaultReadTimeout + + for s.isStarted() { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + b, err := readPrefixed(conn) if err != nil { return err } - if len(b) < minDNSPacketSize { - // Ignore the packets that are too short - return ErrTooShort + + err = s.handleTCPMsg(b, conn, certTxt) + if err != nil { + log.Debug("failed to process DNS query: %v", err) + return err } - if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { - // This is an encrypted message, we should decrypt it - m, q, err := s.decrypt(b) - if err != nil { - log.Tracef("failed to decrypt incoming message: %v", err) - return err - } - rw := &TCPResponseWriter{ - tcpConn: conn, - encrypt: s.encrypt, - req: m, - query: q, - } - s.serveDNS(rw, m) - } else { - // Most likely this a DNS message requesting the certificate - reply, err := s.handleHandshake(b, certTxt) - if err != nil { - log.Tracef("Failed to process a plain DNS query: %v", err) - return err - } - err = writePrefixed(reply, conn) - if err != nil { - return err - } - } + timeout = defaultTCPIdleTimeout } + + return nil } diff --git a/server_test.go b/server_test.go index f981925..d1063e0 100644 --- a/server_test.go +++ b/server_test.go @@ -2,9 +2,11 @@ package dnscrypt import ( "bytes" + "context" "crypto/ed25519" "fmt" "net" + "runtime" "testing" "time" @@ -13,25 +15,52 @@ import ( "github.com/stretchr/testify/assert" ) -func TestServerUDPServeCert(t *testing.T) { +func TestServer_Shutdown(t *testing.T) { + n := runtime.GOMAXPROCS(1) + t.Cleanup(func() { + runtime.GOMAXPROCS(n) + }) + srv := newTestServer(t, &testHandler{}) + // Serve* methods are called in different goroutines + // give them at least a moment to actually start the server + time.Sleep(10 * time.Millisecond) + assert.NoError(t, srv.Close()) +} + +func TestServer_UDPServeCert(t *testing.T) { testServerServeCert(t, "udp") } -func TestServerTCPServeCert(t *testing.T) { +func TestServer_TCPServeCert(t *testing.T) { testServerServeCert(t, "tcp") } -func TestServerUDPRespondMessages(t *testing.T) { +func TestServer_UDPRespondMessages(t *testing.T) { testServerRespondMessages(t, "udp") } -func TestServerTCPRespondMessages(t *testing.T) { +func TestServer_TCPRespondMessages(t *testing.T) { testServerRespondMessages(t, "tcp") } +func TestServer_ReadTimeout(t *testing.T) { + srv := newTestServer(t, &testHandler{}) + t.Cleanup(func() { + assert.NoError(t, srv.Close()) + }) + // Sleep for "defaultReadTimeout" before trying to shutdown the server + // The point is to make sure readTimeout is properly handled by + // the "Serve*" goroutines and they don't finish their work unexpectedly + time.Sleep(defaultReadTimeout) + testThisServerRespondMessages(t, "udp", srv) + testThisServerRespondMessages(t, "tcp", srv) +} + func testServerServeCert(t *testing.T, network string) { srv := newTestServer(t, &testHandler{}) - defer srv.Close() + t.Cleanup(func() { + assert.NoError(t, srv.Close()) + }) client := &Client{ Net: network, @@ -50,7 +79,7 @@ func testServerServeCert(t *testing.T, network string) { Proto: dnsstamps.StampProtoTypeDNSCrypt, } ri, err := client.DialStamp(stamp) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, ri) assert.Equal(t, ri.ProviderName, srv.server.ProviderName) @@ -65,8 +94,13 @@ func testServerServeCert(t *testing.T, network string) { func testServerRespondMessages(t *testing.T, network string) { srv := newTestServer(t, &testHandler{}) - defer srv.Close() + t.Cleanup(func() { + assert.NoError(t, srv.Close()) + }) + testThisServerRespondMessages(t, network, srv) +} +func testThisServerRespondMessages(t *testing.T, network string, srv *testServer) { client := &Client{ Timeout: 1 * time.Second, Net: network, @@ -84,16 +118,16 @@ func testServerRespondMessages(t *testing.T, network string) { Proto: dnsstamps.StampProtoTypeDNSCrypt, } ri, err := client.DialStamp(stamp) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, ri) conn, err := net.Dial(network, stamp.ServerAddrStr) - assert.Nil(t, err) + assert.NoError(t, err) for i := 0; i < 10; i++ { m := createTestMessage() res, err := client.ExchangeConn(conn, m, ri) - assert.Nil(t, err) + assert.NoError(t, err) assertTestMessageResponse(t, res) } } @@ -114,16 +148,22 @@ func (s *testServer) UDPAddr() *net.UDPAddr { return s.udpConn.LocalAddr().(*net.UDPAddr) } -func (s *testServer) Close() { +func (s *testServer) Close() error { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + + err := s.server.Shutdown(ctx) _ = s.udpConn.Close() _ = s.tcpListen.Close() + + return err } -func newTestServer(t *testing.T, handler Handler) *testServer { +func newTestServer(t assert.TestingT, handler Handler) *testServer { rc, err := GenerateResolverConfig("example.org", nil) - assert.Nil(t, err) + assert.NoError(t, err) cert, err := rc.CreateCert() - assert.Nil(t, err) + assert.NoError(t, err) s := &Server{ ProviderName: rc.ProviderName, @@ -132,7 +172,7 @@ func newTestServer(t *testing.T, handler Handler) *testServer { } privateKey, err := HexDecodeKey(rc.PrivateKey) - assert.Nil(t, err) + assert.NoError(t, err) publicKey := ed25519.PrivateKey(privateKey).Public().(ed25519.PublicKey) srv := &testServer{ server: s, @@ -140,9 +180,9 @@ func newTestServer(t *testing.T, handler Handler) *testServer { } srv.tcpListen, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 0}) - assert.Nil(t, err) + assert.NoError(t, err) srv.udpConn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - assert.Nil(t, err) + assert.NoError(t, err) go s.ServeUDP(srv.udpConn) go s.ServeTCP(srv.tcpListen) diff --git a/server_udp.go b/server_udp.go index bb467a5..51ba232 100644 --- a/server_udp.go +++ b/server_udp.go @@ -2,38 +2,43 @@ package dnscrypt import ( "bytes" + "errors" "net" + "runtime" + "sync" + "time" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error) -// UDPResponseWriter - ResponseWriter implementation for UDP +// UDPResponseWriter is the ResponseWriter implementation for UDP type UDPResponseWriter struct { - udpConn *net.UDPConn // UDP connection - remoteAddr *net.UDPAddr // Remote peer address - localIP net.IP // Local IP (that was used to accept the remote connection) - encrypt encryptionFunc // DNSCRypt encryption function - req *dns.Msg // DNS query that was processed - query EncryptedQuery // DNSCrypt query properties + udpConn *net.UDPConn // UDP connection + sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP) + encrypt encryptionFunc // DNSCRypt encryption function + req *dns.Msg // DNS query that was processed + query EncryptedQuery // DNSCrypt query properties } // type check var _ ResponseWriter = &UDPResponseWriter{} -// LocalAddr - server socket local address +// LocalAddr is the server socket local address func (w *UDPResponseWriter) LocalAddr() net.Addr { return w.udpConn.LocalAddr() } -// RemoteAddr - client's address +// RemoteAddr is the client's address func (w *UDPResponseWriter) RemoteAddr() net.Addr { - return w.remoteAddr + return w.udpConn.RemoteAddr() } -// WriteMsg - writes DNS message to the client +// WriteMsg writes DNS message to the client func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error { m.Truncate(dnsSize("udp", w.req)) @@ -42,83 +47,176 @@ func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error { log.Tracef("Failed to encrypt the DNS query: %v", err) return err } + _, err = dns.WriteToSessionUDP(w.udpConn, res, w.sess) + return err +} + +// ServeUDP listens to UDP connections, queries are then processed by Server.Handler. +// It blocks the calling goroutine and to stop it you need to close the listener +// or call Server.Shutdown. +func (s *Server) ServeUDP(l *net.UDPConn) error { + err := s.prepareServeUDP(l) + if err != nil { + return err + } + + // Tracks UDP handling goroutines + udpWg := &sync.WaitGroup{} + defer s.cleanUpUDP(udpWg, l) + + // Track active goroutine + s.wg.Add(1) + + log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr()) + + // Serialize the cert right away and prepare it to be sent to the client + certTxt, err := s.getCertTXT() + if err != nil { + return err + } + + for s.isStarted() { + b, sess, err := s.readUDPMsg(l) + + // Check the error code and exit loop if necessary + if err != nil { + if !s.isStarted() { + // Stopped gracefully + return nil + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Temporary() { + // Note that timeout errors will be here (i.e. hitting ReadDeadline) + continue + } + if isConnClosed(err) { + log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop") + } else { + log.Info("got error when reading from UDP listen: %s", err) + } + return err + } + + if len(b) < minDNSPacketSize { + // Ignore the packets that are too short + continue + } + + udpWg.Add(1) + go func() { + s.serveUDPMsg(b, certTxt, sess, l) + udpWg.Done() + }() + } - _, _ = udpWrite(res, w.udpConn, w.remoteAddr, w.localIP) return nil } -// ServeUDP - listens to UDP connections, queries are then processed by Server.Handler. -// It blocks the calling goroutine and to stop it you need to close the listener. -func (s *Server) ServeUDP(l *net.UDPConn) error { +// prepareServeUDP prepares the server and listener to serving DNSCrypt +func (s *Server) prepareServeUDP(l *net.UDPConn) error { // Check that server is properly configured if !s.validate() { return ErrServerConfig } // set UDP options to allow receiving OOB data - err := udpSetOptions(l) + err := setUDPSocketOptions(l) if err != nil { return err } - // Buffer to read incoming messages - b := make([]byte, dns.MaxMsgSize) + // Protect shutdown-related fields + s.lock.Lock() + defer s.lock.Unlock() + s.initOnce.Do(s.init) - // Serialize the cert right away and prepare it to be sent to the client - certBuf, err := s.ResolverCert.Serialize() + // Mark the server as started. + // Note that we don't check if it was started before as + // Serve* methods can be called multiple times. + s.started = true + + // Track an active UDP listener + s.udpListeners[l] = struct{}{} + return err +} + +// cleanUpUDP waits until all UDP messages before cleaning up +func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) { + // Wait until UDP messages are processed + udpWg.Wait() + + // Not using it anymore so can be removed from the active listeners + s.lock.Lock() + delete(s.udpListeners, l) + s.lock.Unlock() + + // The work is finished + s.wg.Done() +} + +// readUDPMsg reads incoming UDP message +func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) { + _ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout)) + b := make([]byte, dns.MinMsgSize) + n, sess, err := dns.ReadFromSessionUDP(l, b) if err != nil { - return err + return nil, nil, err } - certTxt := packTxtString(certBuf) - // Init oobSize - it will be used later when reading and writing UDP messages - oobSize := udpGetOOBSize() - - log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr().String()) - - for { - n, localIP, addr, err := udpRead(l, b, oobSize) - if n < minDNSPacketSize { - // Ignore the packets that are too short - continue - } - - if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { - // This is an encrypted message, we should decrypt it - m, q, err := s.decrypt(b[:n]) - if err == nil { - rw := &UDPResponseWriter{ - udpConn: l, - remoteAddr: addr, - localIP: localIP, - encrypt: s.encrypt, - req: m, - query: q, - } - go s.serveDNS(rw, m) - } else { - log.Tracef("Failed to decrypt incoming message len=%d: %v", n, err) - } - } else { - // Most likely this a DNS message requesting the certificate - reply, err := s.handleHandshake(b, certTxt) - if err != nil { - log.Tracef("Failed to process a plain DNS query: %v", err) - } - if err == nil { - _, _ = l.WriteTo(reply, addr) - } - } + return b[:n], sess, err +} +// serveUDPMsg handles incoming DNS message +func (s *Server) serveUDPMsg(b []byte, certTxt string, sess *dns.SessionUDP, l *net.UDPConn) { + // First of all, check for "ClientMagic" in the incoming query + if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { + // If there's no ClientMagic in the packet, we assume this + // is a plain DNS query requesting the certificate data + reply, err := s.handleHandshake(b, certTxt) if err != nil { - if isConnClosed(err) { - log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop") - } else { - log.Info("got error when reading from UDP listen: %s", err) - } - break + log.Tracef("failed to process a plain DNS query: %v", err) } + if err == nil { + // Ignore errors, we don't care and can't handle them anyway + _, _ = dns.WriteToSessionUDP(l, reply, sess) + } + + return } + // If we got here, this is an encrypted DNSCrypt message + // We should decrypt it first to get the plain DNS query + m, q, err := s.decrypt(b) + if err == nil { + rw := &UDPResponseWriter{ + udpConn: l, + sess: sess, + encrypt: s.encrypt, + req: m, + query: q, + } + err = s.serveDNS(rw, m) + if err != nil { + log.Tracef("failed to process a DNS query: %v", err) + } + } else { + log.Tracef("failed to decrypt incoming message len=%d: %v", len(b), err) + } +} + +// setUDPSocketOptions method is necessary to be able to use dns.ReadFromSessionUDP / dns.WriteToSessionUDP +func setUDPSocketOptions(conn *net.UDPConn) error { + if runtime.GOOS == "windows" { + return nil + } + + // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. + // Try enabling receiving of ECN and packet info for both IP versions. + // We expect at least one of those syscalls to succeed. + err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) + err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) + if err6 != nil && err4 != nil { + return err4 + } return nil } diff --git a/testdata/dnscrypt-cert.opendns.txt b/testdata/dnscrypt-cert.opendns.txt new file mode 100644 index 0000000..18950f6 --- /dev/null +++ b/testdata/dnscrypt-cert.opendns.txt @@ -0,0 +1 @@ +DNSC\000\001\000\000\200\226E:H\156\203%\134\218\127]\168\239\027u\011$\191\008\239\176F\133\017\171\161\219\154\142i\164\010\239\017f\168dS\210f\197\194\169\171w\2499\1891\155<\130\218@/\155\023v\153#d\024\004\136\180\228K5\233d\180\144\189\218\186\232%\162K\004\021\160\139\225\157}\219\135\163<\215~\223\142/qc78aWoo]\221\184`]\221\184`_\190\235\224 \ No newline at end of file diff --git a/udp_unix.go b/udp_unix.go deleted file mode 100644 index 8d1f50c..0000000 --- a/udp_unix.go +++ /dev/null @@ -1,83 +0,0 @@ -// +build aix darwin dragonfly linux netbsd openbsd solaris freebsd - -package dnscrypt - -import ( - "fmt" - "net" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -// udpGetOOBSize - get max. size of received OOB data -// It will then be used in the ReadMsgUDP function -func udpGetOOBSize() int { - oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface) - oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface) - - if len(oob4) > len(oob6) { - return len(oob4) - } - return len(oob6) -} - -// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data -func udpSetOptions(c *net.UDPConn) error { - err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) - err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) - if err6 != nil && err4 != nil { - return fmt.Errorf("failed to call SetControlMessage: ipv4: %v ipv6: %v", err4, err6) - } - return nil -} - -// udpRead - receive payload and OOB data from the UDP socket -func udpRead(c *net.UDPConn, buf []byte, udpOOBSize int) (int, net.IP, *net.UDPAddr, error) { - var oobn int - oob := make([]byte, udpOOBSize) - var err error - var n int - var remoteAddr *net.UDPAddr - n, oobn, _, remoteAddr, err = c.ReadMsgUDP(buf, oob) - if err != nil { - return -1, nil, nil, err - } - - localIP := udpGetDstFromOOB(oob[:oobn]) - return n, localIP, remoteAddr, nil -} - -// udpWrite - writes to the UDP socket and sets local IP to OOB data -func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, localIP net.IP) (int, error) { - n, _, err := conn.WriteMsgUDP(bytes, udpMakeOOBWithSrc(localIP), remoteAddr) - return n, err -} - -// udpGetDstFromOOB - get destination IP from OOB data -func udpGetDstFromOOB(oob []byte) net.IP { - cm6 := &ipv6.ControlMessage{} - if cm6.Parse(oob) == nil && cm6.Dst != nil { - return cm6.Dst - } - - cm4 := &ipv4.ControlMessage{} - if cm4.Parse(oob) == nil && cm4.Dst != nil { - return cm4.Dst - } - - return nil -} - -// udpMakeOOBWithSrc - make OOB data with a specified source IP -func udpMakeOOBWithSrc(ip net.IP) []byte { - if ip.To4() == nil { - cm := &ipv6.ControlMessage{} - cm.Src = ip - return cm.Marshal() - } - - cm := &ipv4.ControlMessage{} - cm.Src = ip - return cm.Marshal() -} diff --git a/udp_windows.go b/udp_windows.go deleted file mode 100644 index 74ac910..0000000 --- a/udp_windows.go +++ /dev/null @@ -1,30 +0,0 @@ -package dnscrypt - -import "net" - -// udpGetOOBSize - get max. size of received OOB data -// Does nothing on Windows -func udpGetOOBSize() int { - return 0 -} - -// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data -// Does nothing on Windows -func udpSetOptions(c *net.UDPConn) error { - return nil -} - -// udpRead - receive payload from the UDP socket -func udpRead(c *net.UDPConn, buf []byte, _ int) (int, net.IP, *net.UDPAddr, error) { - n, addr, err := c.ReadFrom(buf) - var udpAddr *net.UDPAddr - if addr != nil { - udpAddr = addr.(*net.UDPAddr) - } - return n, nil, udpAddr, err -} - -// udpWrite - writes to the UDP socket -func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ net.IP) (int, error) { - return conn.WriteTo(bytes, remoteAddr) -}