diff --git a/.golangci.yml b/.golangci.yml index 0dc25f5..bb43eb9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,6 +20,12 @@ linters-settings: min-complexity: 20 lll: line-length: 200 + depguard: + list-type: blacklist + include-go-root: false + packages: + - golang.org/x/net/context # we use context + - log # we use github.com/AdguardTeam/golibs/log linters: enable: @@ -41,6 +47,7 @@ linters: - misspell - stylecheck - unconvert + - depguard disable-all: true fast: true diff --git a/cert.go b/cert.go index 649cc22..d19c391 100644 --- a/cert.go +++ b/cert.go @@ -8,10 +8,10 @@ import ( "time" ) -// Cert - DNSCrypt server certificate +// Cert is a DNSCrypt server certificate // See ResolverConfig for more info on how to create one type Cert struct { - // Serial - a 4 byte serial number in big-endian format. If more than + // Serial is a 4 byte serial number in big-endian format. If more than // one certificates are valid, the client must prefer the certificate // with a higher serial number. Serial uint32 @@ -22,36 +22,36 @@ type Cert struct { // For X25519-XChacha20Poly1305, must be 0x00 0x02. EsVersion CryptoConstruction - // Signature - a 64-byte signature of ( + // Signature is a 64-byte signature of ( // ) using the Ed25519 algorithm and the // provider secret key. Ed25519 must be used in this version of the // protocol. Signature [ed25519.SignatureSize]byte - // ResolverPk - the resolver's short-term public key, which is 32 bytes when using X25519. + // ResolverPk is the resolver's short-term public key, which is 32 bytes when using X25519. // This key is used to encrypt/decrypt DNS queries ResolverPk [keySize]byte - // ResolverSk - the resolver's short-term private key, which is 32 bytes when using X25519. + // ResolverSk is the resolver's short-term private key, which is 32 bytes when using X25519. // Note that it's only used in the server implementation and never serialized/deserialized. // This key is used to encrypt/decrypt DNS queries ResolverSk [keySize]byte - // ClientMagic - the first 8 bytes of a client query that is to be built + // ClientMagic is the first 8 bytes of a client query that is to be built // using the information from this certificate. It may be a truncated // public key. Two valid certificates cannot share the same . ClientMagic [clientMagicSize]byte - // NotAfter - the date the certificate is valid from, as a big-endian + // NotAfter is the date the certificate is valid from, as a big-endian // 4-byte unsigned Unix timestamp. NotBefore uint32 - // NotAfter - the date the certificate is valid until (inclusive), as a + // NotAfter is the date the certificate is valid until (inclusive), as a // big-endian 4-byte unsigned Unix timestamp. NotAfter uint32 } -// Serialize - serializes the cert to bytes +// Serialize serializes the cert to bytes // ::= // // @@ -86,7 +86,7 @@ func (c *Cert) Serialize() ([]byte, error) { return b, nil } -// Deserialize - deserializes certificate from a byte array +// Deserialize deserializes certificate from a byte array // ::= // // @@ -127,7 +127,7 @@ func (c *Cert) Deserialize(b []byte) error { return nil } -// VerifyDate - checks that cert is valid at this moment +// VerifyDate checks that the cert is valid at this moment func (c *Cert) VerifyDate() bool { if c.NotBefore >= c.NotAfter { return false @@ -139,14 +139,14 @@ func (c *Cert) VerifyDate() bool { return true } -// VerifySignature - checks if the cert is properly signed with the specified signature +// VerifySignature checks if the cert is properly signed with the specified signature func (c *Cert) VerifySignature(publicKey ed25519.PublicKey) bool { b := make([]byte, 52) c.writeSigned(b) return ed25519.Verify(publicKey, b, c.Signature[:]) } -// Sign - creates cert.Signature +// Sign creates cert.Signature func (c *Cert) Sign(privateKey ed25519.PrivateKey) { b := make([]byte, 52) c.writeSigned(b) @@ -154,14 +154,14 @@ func (c *Cert) Sign(privateKey ed25519.PrivateKey) { copy(c.Signature[:64], signature[:64]) } -// String - Cert's string representation +// String Cert's string representation func (c *Cert) String() string { return fmt.Sprintf("Certificate Serial=%d NotBefore=%s NotAfter=%s EsVersion=%s", c.Serial, time.Unix(int64(c.NotBefore), 0).String(), time.Unix(int64(c.NotAfter), 0).String(), c.EsVersion.String()) } -// writeSigned - writes ( ) +// writeSigned writes ( ) func (c *Cert) writeSigned(dst []byte) { // copy(dst[:32], c.ResolverPk[:keySize]) diff --git a/cert_test.go b/cert_test.go index edbc6f7..4b1d5e1 100644 --- a/cert_test.go +++ b/cert_test.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/ed25519" "crypto/rand" + "io/ioutil" "testing" "time" @@ -21,13 +22,13 @@ func TestCertSerialize(t *testing.T) { // serialize b, err := cert.Serialize() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, 124, len(b)) // check that we can deserialize it cert2 := Cert{} err = cert2.Deserialize(b) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, cert.Serial, cert2.Serial) assert.Equal(t, cert.NotBefore, cert2.NotBefore) assert.Equal(t, cert.NotAfter, cert2.NotAfter) @@ -39,12 +40,15 @@ func TestCertSerialize(t *testing.T) { func TestCertDeserialize(t *testing.T) { // dig -t txt 2.dnscrypt-cert.opendns.com. -p 443 @208.67.220.220 - b, err := unpackTxtString("DNSC\\000\\001\\000\\000\\200\\226E:H\\156\\203%\\134\\218\\127]\\168\\239\\027u\\011$\\191\\008\\239\\176F\\133\\017\\171\\161\\219\\154\\142i\\164\\010\\239\\017f\\168dS\\210f\\197\\194\\169\\171w\\2499\\1891\\155<\\130\\218@/\\155\\023v\\153#d\\024\\004\\136\\180\\228K5\\233d\\180\\144\\189\\218\\186\\232%\\162K\\004\\021\\160\\139\\225\\157}\\219\\135\\163<\\215~\\223\\142/qc78aWoo]\\221\\184`]\\221\\184`_\\190\\235\\224") - assert.Nil(t, err) + certBytes, err := ioutil.ReadFile("testdata/dnscrypt-cert.opendns.txt") + assert.NoError(t, err) + + b, err := unpackTxtString(string(certBytes)) + assert.NoError(t, err) cert := &Cert{} err = cert.Deserialize(b) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, uint32(1574811744), cert.Serial) assert.Equal(t, XSalsa20Poly1305, cert.EsVersion) assert.Equal(t, uint32(1574811744), cert.NotBefore) @@ -69,7 +73,7 @@ func generateValidCert(t *testing.T) (*Cert, ed25519.PublicKey, ed25519.PrivateK // generate private key publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader) - assert.Nil(t, err) + assert.NoError(t, err) // sign the data cert.Sign(privateKey) diff --git a/client.go b/client.go index 9ca3925..ebf1574 100644 --- a/client.go +++ b/client.go @@ -12,7 +12,7 @@ import ( "github.com/miekg/dns" ) -// Client - DNSCrypt resolver client +// Client is a DNSCrypt resolver client type Client struct { Net string // protocol (can be "udp" or "tcp", by default - "udp") Timeout time.Duration // read/write timeout @@ -124,7 +124,7 @@ func (c *Client) ExchangeConn(conn net.Conn, m *dns.Msg, resolverInfo *ResolverI return res, nil } -// writeQuery - writes query to the network connection +// writeQuery writes query to the network connection // depending on the protocol we may write a 2-byte prefix or not func (c *Client) writeQuery(conn net.Conn, query []byte) error { var err error @@ -145,7 +145,7 @@ func (c *Client) writeQuery(conn net.Conn, query []byte) error { return err } -// readResponse - reads response from the network connection +// readResponse reads response from the network connection // depending on the protocol, we may read a 2-byte prefix or not func (c *Client) readResponse(conn net.Conn) ([]byte, error) { if c.Timeout > 0 { @@ -171,7 +171,7 @@ func (c *Client) readResponse(conn net.Conn) ([]byte, error) { return readPrefixed(conn) } -// encrypt - encrypts a DNS message using shared key from the resolver info +// encrypt encrypts a DNS message using shared key from the resolver info func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) { q := EncryptedQuery{ EsVersion: resolverInfo.ResolverCert.EsVersion, @@ -185,7 +185,7 @@ func (c *Client) encrypt(m *dns.Msg, resolverInfo *ResolverInfo) ([]byte, error) return q.Encrypt(query, resolverInfo.SharedKey) } -// decrypts - decrypts a DNS message using shared key from the resolver info +// decrypts decrypts a DNS message using a shared key from the resolver info func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) { dr := EncryptedResponse{ EsVersion: resolverInfo.ResolverCert.EsVersion, @@ -203,7 +203,7 @@ func (c *Client) decrypt(b []byte, resolverInfo *ResolverInfo) (*dns.Msg, error) return res, nil } -// fetchCert - loads DNSCrypt cert from the specified server +// fetchCert loads DNSCrypt cert from the specified server func (c *Client) fetchCert(stamp dnsstamps.ServerStamp) (*Cert, error) { providerName := stamp.ProviderName if !strings.HasSuffix(providerName, ".") { diff --git a/client_test.go b/client_test.go index e5f294a..6bbd4bd 100644 --- a/client_test.go +++ b/client_test.go @@ -63,7 +63,7 @@ func TestTimeoutOnDialExchange(t *testing.T) { client := Client{Timeout: 300 * time.Millisecond} serverInfo, err := client.Dial(stampStr) - assert.Nil(t, err) + assert.NoError(t, err) // Point it to an IP where there's no DNSCrypt server serverInfo.ServerAddress = "8.8.8.8:5443" @@ -105,12 +105,15 @@ func TestFetchCertPublicResolvers(t *testing.T) { for _, test := range stamps { stamp, err := dnsstamps.NewServerStampFromString(test.stampStr) - assert.Nil(t, err) + assert.NoError(t, err) t.Run(stamp.ProviderName, func(t *testing.T) { - c := &Client{Net: "udp"} + c := &Client{ + Net: "udp", + Timeout: time.Second * 5, + } resolverInfo, err := c.DialStamp(stamp) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, resolverInfo) assert.True(t, resolverInfo.ResolverCert.VerifyDate()) assert.True(t, resolverInfo.ResolverCert.VerifySignature(stamp.ServerPk)) @@ -146,7 +149,7 @@ func TestExchangePublicResolvers(t *testing.T) { for _, test := range stamps { stamp, err := dnsstamps.NewServerStampFromString(test.stampStr) - assert.Nil(t, err) + assert.NoError(t, err) t.Run(stamp.ProviderName, func(t *testing.T) { checkDNSCryptServer(t, test.stampStr, "udp") @@ -158,12 +161,12 @@ func TestExchangePublicResolvers(t *testing.T) { func checkDNSCryptServer(t *testing.T, stampStr string, network string) { client := Client{Net: network, Timeout: 10 * time.Second} resolverInfo, err := client.Dial(stampStr) - assert.Nil(t, err) + assert.NoError(t, err) req := createTestMessage() reply, err := client.Exchange(req, resolverInfo) - assert.Nil(t, err) + assert.NoError(t, err) assertTestMessageResponse(t, reply) } @@ -177,7 +180,7 @@ func createTestMessage() *dns.Msg { return &req } -func assertTestMessageResponse(t *testing.T, reply *dns.Msg) { +func assertTestMessageResponse(t assert.TestingT, reply *dns.Msg) { assert.NotNil(t, reply) assert.Equal(t, 1, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) diff --git a/cmd/convert_dnscrypt_wrapper.go b/cmd/convert_dnscrypt_wrapper.go index fc889ac..e3d5e88 100644 --- a/cmd/convert_dnscrypt_wrapper.go +++ b/cmd/convert_dnscrypt_wrapper.go @@ -12,7 +12,7 @@ import ( "gopkg.in/yaml.v3" ) -// ConvertWrapperArgs - "convert-dnscrypt-wrapper" command arguments +// ConvertWrapperArgs is the "convert-dnscrypt-wrapper" command arguments structure type ConvertWrapperArgs struct { PrivateKeyFile string `short:"p" long:"private-key" description:"Path to the DNSCrypt resolver private key file that is used for signing certificates. Param is required." required:"true"` ResolverSkFile string `short:"r" long:"resolver-secret" description:"Path to the Short-term privacy key file for encrypting/decrypting DNS queries. If not specified, resolver_secret and resolver_public will be randomly generated."` @@ -21,7 +21,7 @@ type ConvertWrapperArgs struct { CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"` } -// convertWrapper - generates DNSCrypt configuration from both dnscrypt and server private keys +// convertWrapper generates DNSCrypt configuration from both dnscrypt and server private keys func convertWrapper(args ConvertWrapperArgs) { log.Info("Generating configuration for %s", args.ProviderName) @@ -71,7 +71,7 @@ func convertWrapper(args ConvertWrapperArgs) { } } -// validateRc - verifies that the certificate is correctly +// validateRc verifies that the certificate is correctly // created and validated for this resolver config. if rc valid returns nil. func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error { cert, err := rc.CreateCert() @@ -90,7 +90,7 @@ func validateRc(rc dnscrypt.ResolverConfig, publicKey ed25519.PublicKey) error { return nil } -// getResolverPk - calculates public key from private key +// getResolverPk calculates public key from private key func getResolverPk(private ed25519.PrivateKey) ed25519.PublicKey { resolverSk := [32]byte{} resolverPk := [32]byte{} diff --git a/cmd/generate.go b/cmd/generate.go index 6d93b4d..6f8d914 100644 --- a/cmd/generate.go +++ b/cmd/generate.go @@ -8,7 +8,7 @@ import ( "gopkg.in/yaml.v3" ) -// GenerateArgs - "generate" command arguments +// GenerateArgs is the "generate" command arguments structure type GenerateArgs struct { ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt provider name. Param is required." required:"true"` Out string `short:"o" long:"out" description:"Path to the resulting config file. Param is required." required:"true"` @@ -16,7 +16,7 @@ type GenerateArgs struct { CertificateTTL int `short:"t" long:"ttl" description:"Certificate time-to-live (seconds)"` } -// generate - generates DNSCrypt server configuration +// generate generates a DNSCrypt server configuration func generate(args GenerateArgs) { log.Info("Generating configuration for %s", args.ProviderName) diff --git a/cmd/lookup.go b/cmd/lookup.go index 0e873d4..e5d31c8 100644 --- a/cmd/lookup.go +++ b/cmd/lookup.go @@ -12,7 +12,7 @@ import ( "github.com/miekg/dns" ) -// LookupStampArgs - "lookup-stamp" command arguments +// LookupStampArgs is the "lookup-stamp" command arguments structure type LookupStampArgs struct { Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"` Stamp string `short:"s" long:"stamp" description:"DNSCrypt resolver stamp. Param is required." required:"true"` @@ -20,7 +20,7 @@ type LookupStampArgs struct { Type string `short:"t" long:"type" description:"DNS query type" default:"A"` } -// LookupArgs - "lookup" command arguments +// LookupArgs is the "lookup" command arguments structure type LookupArgs struct { Network string `short:"n" long:"network" description:"network type (tcp/udp)" default:"udp"` ProviderName string `short:"p" long:"provider-name" description:"DNSCrypt resolver provider name. Param is required." required:"true"` @@ -30,7 +30,7 @@ type LookupArgs struct { Type string `short:"t" long:"type" description:"DNS query type" default:"A"` } -// LookupResult - lookup result that contains the cert info and the query response +// LookupResult is the lookup result that contains the cert info and the query response type LookupResult struct { Certificate struct { Serial uint32 `json:"serial"` @@ -42,7 +42,7 @@ type LookupResult struct { Reply *dns.Msg `json:"reply"` } -// lookup - performs a DNS lookup, prints DNSCrypt info and lookup results +// lookup performs a DNS lookup, prints DNSCrypt info and lookup results func lookup(args LookupArgs) { serverPk, err := dnscrypt.HexDecodeKey(args.PublicKey) if err != nil { @@ -64,7 +64,7 @@ func lookup(args LookupArgs) { }) } -// lookupStamp - performs a DNS lookup, prints DNSCrypt cert info and lookup results +// lookupStamp performs a DNS lookup, prints DNSCrypt cert info and lookup results func lookupStamp(args LookupStampArgs) { c := &dnscrypt.Client{ Net: args.Network, diff --git a/cmd/main.go b/cmd/main.go index 077b484..28e2a2d 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -4,7 +4,6 @@ import ( "os" "github.com/AdguardTeam/golibs/log" - goFlags "github.com/jessevdk/go-flags" ) diff --git a/cmd/server.go b/cmd/server.go index a10609d..46295ed 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -13,7 +13,7 @@ import ( "gopkg.in/yaml.v3" ) -// ServerArgs - "server" command arguments +// ServerArgs is the "server" command arguments type ServerArgs struct { Config string `short:"c" long:"config" description:"Path to the DNSCrypt configuration file. Param is required." required:"true"` Forward string `short:"f" long:"forward" description:"Forwards DNS queries to the specified address" default:"94.140.14.140:53"` @@ -21,7 +21,7 @@ type ServerArgs struct { ListenPorts []int `short:"p" long:"port" description:"Listening ports" default:"443"` } -// server - runs a DNSCrypt server +// server runs a DNSCrypt server func server(args ServerArgs) { log.Info("Starting DNSCrypt server") @@ -72,7 +72,7 @@ func server(args ServerArgs) { } } -// createListeners - creates listeners for our server +// createListeners creates listeners for our server func createListeners(args ServerArgs) (tcp []net.Listener, udp []*net.UDPConn) { for _, addr := range args.ListenAddrs { ip := net.ParseIP(addr) @@ -101,7 +101,10 @@ type forwardHandler struct { addr string } -// ServeDNS - implements Handler interface +// type check +var _ dnscrypt.Handler = &forwardHandler{} + +// ServeDNS implements Handler interface func (f *forwardHandler) ServeDNS(rw dnscrypt.ResponseWriter, r *dns.Msg) error { res, err := dns.Exchange(r, f.addr) if err != nil { diff --git a/constants.go b/constants.go index 81a2cf2..2411d78 100644 --- a/constants.go +++ b/constants.go @@ -1,52 +1,58 @@ package dnscrypt -import "errors" +// Error represents a dnscrypt error. +type Error string -var ( - // ErrTooShort - DNS query is shorter than possible - ErrTooShort = errors.New("DNSCrypt message is too short") +func (e Error) Error() string { return "dnscrypt: " + string(e) } - // ErrQueryTooLarge - DNS query is larger than max allowed size - ErrQueryTooLarge = errors.New("DNSCrypt query is too large") +const ( + // ErrTooShort means that the DNS query is shorter than possible + ErrTooShort = Error("message is too short") - // ErrEsVersion - cert contains unsupported es-version - ErrEsVersion = errors.New("unsupported es-version") + // ErrQueryTooLarge means that the DNS query is larger than max allowed size + ErrQueryTooLarge = Error("DNSCrypt query is too large") - // ErrInvalidDate - cert is not valid for the current time - ErrInvalidDate = errors.New("cert has invalid ts-start or ts-end") + // ErrEsVersion means that the cert contains unsupported es-version + ErrEsVersion = Error("unsupported es-version") - // ErrInvalidCertSignature - cert has invalid signature - ErrInvalidCertSignature = errors.New("cert has invalid signature") + // ErrInvalidDate means that the cert is not valid for the current time + ErrInvalidDate = Error("cert has invalid ts-start or ts-end") - // ErrInvalidQuery - failed to decrypt a DNSCrypt query - ErrInvalidQuery = errors.New("DNSCrypt query is invalid and cannot be decrypted") + // ErrInvalidCertSignature means that the cert has invalid signature + ErrInvalidCertSignature = Error("cert has invalid signature") - // ErrInvalidClientMagic - client-magic does not match - ErrInvalidClientMagic = errors.New("DNSCrypt query contains invalid client magic") + // ErrInvalidQuery means that it failed to decrypt a DNSCrypt query + ErrInvalidQuery = Error("DNSCrypt query is invalid and cannot be decrypted") - // ErrInvalidResolverMagic - server-magic does not match - ErrInvalidResolverMagic = errors.New("DNSCrypt response contains invalid resolver magic") + // ErrInvalidClientMagic means that client-magic does not match + ErrInvalidClientMagic = Error("DNSCrypt query contains invalid client magic") - // ErrInvalidResponse - failed to decrypt a DNSCrypt response - ErrInvalidResponse = errors.New("DNSCrypt response is invalid and cannot be decrypted") + // ErrInvalidResolverMagic means that server-magic does not match + ErrInvalidResolverMagic = Error("DNSCrypt response contains invalid resolver magic") - // ErrInvalidPadding - failed to unpad a query - ErrInvalidPadding = errors.New("invalid padding") + // ErrInvalidResponse means that it failed to decrypt a DNSCrypt response + ErrInvalidResponse = Error("DNSCrypt response is invalid and cannot be decrypted") - // ErrInvalidDNSStamp - invalid DNS stamp - ErrInvalidDNSStamp = errors.New("invalid DNS stamp") + // ErrInvalidPadding means that it failed to unpad a query + ErrInvalidPadding = Error("invalid padding") - // ErrFailedToFetchCert - failed to fetch DNSCrypt certificate - ErrFailedToFetchCert = errors.New("failed to fetch DNSCrypt certificate") + // ErrInvalidDNSStamp means an invalid DNS stamp + ErrInvalidDNSStamp = Error("invalid DNS stamp") - // ErrCertTooShort - failed to deserialize cert, too short - ErrCertTooShort = errors.New("cert is too short") + // ErrFailedToFetchCert means that it failed to fetch DNSCrypt certificate + ErrFailedToFetchCert = Error("failed to fetch DNSCrypt certificate") - // ErrCertMagic - invalid cert magic - ErrCertMagic = errors.New("invalid cert magic") + // ErrCertTooShort means that it failed to deserialize cert, too short + ErrCertTooShort = Error("cert is too short") - // ErrServerConfig - failed to start the DNSCrypt server - invalid configuration - ErrServerConfig = errors.New("invalid server configuration") + // ErrCertMagic means an invalid cert magic + ErrCertMagic = Error("invalid cert magic") + + // ErrServerConfig means that it failed to start the DNSCrypt server - invalid configuration + ErrServerConfig = Error("invalid server configuration") + + // ErrServerNotStarted is returned if there's nothing to shutdown + ErrServerNotStarted = Error("server is not started") ) const ( @@ -55,7 +61,7 @@ const ( // Some servers do not work if padded length is less than 256. Example: Quad9 minUDPQuestionSize = 256 - // - maximum allowed query length + // is the maximum allowed query length maxQueryLen = 1252 // Minimum possible DNS packet size @@ -68,7 +74,7 @@ const ( // size of the shared key used to encrypt/decrypt messages sharedKeySize = 32 - // ClientMagic - the first 8 bytes of a client query that is to be built + // ClientMagic is the first 8 bytes of a client query that is to be built // using the information from this certificate. It may be a truncated // public key. Two valid certificates cannot share the same . clientMagicSize = 8 @@ -82,10 +88,10 @@ const ( ) var ( - // certMagic - bytes sequence that must be in the beginning of the serialized cert + // certMagic is a bytes sequence that must be in the beginning of the serialized cert certMagic = [4]byte{0x44, 0x4e, 0x53, 0x43} - // resolverMagic - byte sequence that must be in the beginning of every response + // resolverMagic is a byte sequence that must be in the beginning of every response resolverMagic = []byte{0x72, 0x36, 0x66, 0x6e, 0x76, 0x57, 0x6a, 0x38} ) diff --git a/encrypted_query.go b/encrypted_query.go index 545991e..f64ba8f 100644 --- a/encrypted_query.go +++ b/encrypted_query.go @@ -10,19 +10,19 @@ import ( "golang.org/x/crypto/nacl/secretbox" ) -// EncryptedQuery - a structure for encrypting and decrypting client queries +// EncryptedQuery is a structure for encrypting and decrypting client queries // // ::= // ::= AE( , ) type EncryptedQuery struct { - // EsVersion - encryption to use + // EsVersion is the encryption to use EsVersion CryptoConstruction - // ClientMagic - a 8 byte identifier for the resolver certificate + // ClientMagic is a 8 byte identifier for the resolver certificate // chosen by the client. ClientMagic [clientMagicSize]byte - // ClientPk - the client's public key + // ClientPk is the client's public key ClientPk [keySize]byte // With a 24 bytes nonce, a question sent by a DNSCrypt client must be @@ -36,7 +36,7 @@ type EncryptedQuery struct { Nonce [nonceSize]byte } -// Encrypt - encrypts the specified DNS query, returns encrypted data ready to be sent. +// Encrypt encrypts the specified DNS query, returns encrypted data ready to be sent. // // Note that this method will generate a random nonce automatically. // @@ -79,7 +79,7 @@ func (q *EncryptedQuery) Encrypt(packet []byte, sharedKey [sharedKeySize]byte) ( return query, nil } -// Decrypt - decrypts the client query, returns decrypted DNS packet. +// Decrypt decrypts the client query, returns decrypted DNS packet. // // Please note, that before calling this method the following fields must be set: // * ClientMagic -- to verify the query diff --git a/encrypted_query_test.go b/encrypted_query_test.go index d70643f..cffb75c 100644 --- a/encrypted_query_test.go +++ b/encrypted_query_test.go @@ -23,7 +23,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction) // Generate client shared key clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey) - assert.Nil(t, err) + assert.NoError(t, err) clientMagic := [clientMagicSize]byte{} _, _ = rand.Read(clientMagic[:]) @@ -40,7 +40,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction) // Encrypt it encrypted, err := q1.Encrypt(packet, clientSharedKey) - assert.Nil(t, err) + assert.NoError(t, err) // Now let's try decrypting it q2 := EncryptedQuery{ @@ -50,7 +50,7 @@ func testDNSCryptQueryEncryptDecrypt(t *testing.T, esVersion CryptoConstruction) // Decrypt it decrypted, err := q2.Decrypt(encrypted, serverSecretKey) - assert.Nil(t, err) + assert.NoError(t, err) // Check that packet is the same assert.True(t, bytes.Equal(packet, decrypted)) diff --git a/encrypted_response.go b/encrypted_response.go index 30d8405..212cb87 100644 --- a/encrypted_response.go +++ b/encrypted_response.go @@ -10,12 +10,12 @@ import ( "golang.org/x/crypto/nacl/secretbox" ) -// EncryptedResponse - structure for encrypting/decrypting server responses +// EncryptedResponse is a structure for encrypting/decrypting server responses // // ::= // ::= AE(, , ) type EncryptedResponse struct { - // EsVersion - encryption to use + // EsVersion is the encryption to use EsVersion CryptoConstruction // Nonce - ::= @@ -23,7 +23,7 @@ type EncryptedResponse struct { Nonce [nonceSize]byte } -// Encrypt - encrypts the server response +// Encrypt encrypts the server response // // EsVersion must be set. // Nonce needs to be set to "client-nonce". @@ -57,7 +57,7 @@ func (r *EncryptedResponse) Encrypt(packet []byte, sharedKey [sharedKeySize]byte return response, nil } -// Decrypt - decrypts the server response +// Decrypt decrypts the server response // // EsVersion must be set. func (r *EncryptedResponse) Decrypt(response []byte, sharedKey [sharedKeySize]byte) ([]byte, error) { diff --git a/encrypted_response_test.go b/encrypted_response_test.go index 0de8dbe..6c755cc 100644 --- a/encrypted_response_test.go +++ b/encrypted_response_test.go @@ -24,11 +24,11 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi // Generate client shared key clientSharedKey, err := computeSharedKey(esVersion, &clientSecretKey, &serverPublicKey) - assert.Nil(t, err) + assert.NoError(t, err) // Generate server shared key serverSharedKey, err := computeSharedKey(esVersion, &serverSecretKey, &clientPublicKey) - assert.Nil(t, err) + assert.NoError(t, err) r1 := &EncryptedResponse{ EsVersion: esVersion, @@ -42,7 +42,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi // Encrypt it encrypted, err := r1.Encrypt(packet, serverSharedKey) - assert.Nil(t, err) + assert.NoError(t, err) // Now let's try decrypting it r2 := &EncryptedResponse{ @@ -51,7 +51,7 @@ func testDNSCryptResponseEncryptDecrypt(t *testing.T, esVersion CryptoConstructi // Decrypt it decrypted, err := r2.Decrypt(encrypted, clientSharedKey) - assert.Nil(t, err) + assert.NoError(t, err) // Check that packet is the same assert.True(t, bytes.Equal(packet, decrypted)) diff --git a/generate.go b/generate.go index 3891e31..430d955 100644 --- a/generate.go +++ b/generate.go @@ -14,36 +14,37 @@ import ( const dnsCryptV2Prefix = "2.dnscrypt-cert." -// ResolverConfig - DNSCrypt resolver configuration +// ResolverConfig is the DNSCrypt resolver configuration type ResolverConfig struct { // DNSCrypt provider name ProviderName string `yaml:"provider_name"` - // PublicKey - DNSCrypt resolver public key + // PublicKey is the DNSCrypt resolver public key PublicKey string `yaml:"public_key"` - // PrivateKey - DNSCrypt resolver private key + // PrivateKey is the DNSCrypt resolver private key // The main and only purpose of this key is to sign the certificate PrivateKey string `yaml:"private_key"` - // ResolverSk - hex-encoded short-term private key. + // ResolverSk is a hex-encoded short-term private key. // This key is used to encrypt/decrypt DNS queries. // If not set, we'll generate a new random ResolverSk and ResolverPk. ResolverSk string `yaml:"resolver_secret"` - // ResolverPk - hex-encoded short-term public key corresponding to ResolverSk. + // ResolverPk is a hex-encoded short-term public key corresponding to ResolverSk. // This key is used to encrypt/decrypt DNS queries. ResolverPk string `yaml:"resolver_public"` - // EsVersion - crypto to use in this resolver + // EsVersion is the crypto to use in this resolver EsVersion CryptoConstruction `yaml:"es_version"` - // CertificateTTL - time-to-live for the certificate that is generated using this ResolverConfig. + // CertificateTTL is the time-to-live value for the certificate that is + // generated using this ResolverConfig. // If not set, we'll use 1 year by default. CertificateTTL time.Duration `yaml:"certificate_ttl"` } -// CreateCert - generates a signed Cert to be used by Server +// CreateCert generates a signed Cert to be used by Server func (rc *ResolverConfig) CreateCert() (*Cert, error) { log.Printf("Creating signed DNSCrypt certificate") @@ -98,7 +99,7 @@ func (rc *ResolverConfig) CreateCert() (*Cert, error) { return cert, nil } -// CreateStamp - generates a DNS stamp for this resolver +// CreateStamp generates a DNS stamp for this resolver func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error) { stamp := dnsstamps.ServerStamp{ ProviderName: rc.ProviderName, @@ -115,7 +116,7 @@ func (rc *ResolverConfig) CreateStamp(addr string) (dnsstamps.ServerStamp, error return stamp, nil } -// GenerateResolverConfig - generates resolver configuration for a given provider name. +// GenerateResolverConfig generates resolver configuration for a given provider name. // providerName is mandatory. If needed, "2.dnscrypt-cert." prefix is added to it. // privateKey is optional. If not set, it will be generated automatically. func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey) (ResolverConfig, error) { @@ -145,18 +146,18 @@ func GenerateResolverConfig(providerName string, privateKey ed25519.PrivateKey) return rc, nil } -// HexEncodeKey - encodes a byte slice to a hex-encoded string. +// HexEncodeKey encodes a byte slice to a hex-encoded string. func HexEncodeKey(b []byte) string { return strings.ToUpper(hex.EncodeToString(b)) } -// HexDecodeKey - decodes a hex-encoded string with (optional) colons +// HexDecodeKey decodes a hex-encoded string with (optional) colons // to a byte array. func HexDecodeKey(str string) ([]byte, error) { return hex.DecodeString(strings.ReplaceAll(str, ":", "")) } -// generateRandomKeyPair - generates a random key-pair +// generateRandomKeyPair generates a random key-pair func generateRandomKeyPair() (privateKey [keySize]byte, publicKey [keySize]byte) { privateKey = [keySize]byte{} publicKey = [keySize]byte{} diff --git a/generate_test.go b/generate_test.go index 6b46079..a6099d2 100644 --- a/generate_test.go +++ b/generate_test.go @@ -15,24 +15,24 @@ func TestHexEncodeKey(t *testing.T) { func TestHexDecodeKey(t *testing.T) { b, err := HexDecodeKey("01:02:03:04") - assert.Nil(t, err) + assert.NoError(t, err) assert.True(t, bytes.Equal(b, []byte{1, 2, 3, 4})) } func TestGenerateResolverConfig(t *testing.T) { rc, err := GenerateResolverConfig("example.org", nil) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, "2.dnscrypt-cert.example.org", rc.ProviderName) assert.Equal(t, ed25519.PrivateKeySize*2, len(rc.PrivateKey)) assert.Equal(t, keySize*2, len(rc.ResolverSk)) assert.Equal(t, keySize*2, len(rc.ResolverPk)) cert, err := rc.CreateCert() - assert.Nil(t, err) + assert.NoError(t, err) assert.True(t, cert.VerifyDate()) publicKey, err := HexDecodeKey(rc.PublicKey) - assert.Nil(t, err) + assert.NoError(t, err) assert.True(t, cert.VerifySignature(publicKey)) } diff --git a/go.mod b/go.mod index 9a29d25..1c90381 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/aead/poly1305 v0.0.0-20180717145839-3fee0db0b635 github.com/ameshkov/dnsstamps v1.0.1 github.com/jessevdk/go-flags v1.4.0 - github.com/miekg/dns v1.1.29 + github.com/miekg/dns v1.1.40 github.com/stretchr/testify v1.6.1 golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e diff --git a/go.sum b/go.sum index c07da94..4dbbe1e 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGAR github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/miekg/dns v1.1.29 h1:xHBEhR+t5RzcFJjBLJlax2daXOrTYtr9z4WdKEfWFzg= github.com/miekg/dns v1.1.29/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= +github.com/miekg/dns v1.1.40 h1:pyyPFfGMnciYUk/mXpKkVmeMQjfXqt3FAJ2hy7tPiLA= +github.com/miekg/dns v1.1.40/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/handler.go b/handler.go index eb3430f..2943ae5 100644 --- a/handler.go +++ b/handler.go @@ -14,14 +14,14 @@ type Handler interface { ServeDNS(rw ResponseWriter, r *dns.Msg) error } -// ResponseWriter - interface that needs to be implemented for different protocols +// ResponseWriter is the interface that needs to be implemented for different protocols type ResponseWriter interface { LocalAddr() net.Addr // LocalAddr - local socket address RemoteAddr() net.Addr // RemoteAddr - remote client socket address WriteMsg(m *dns.Msg) error // WriteMsg - writes response message to the client } -// DefaultHandler - default Handler implementation +// DefaultHandler is the default Handler implementation // that is used by Server if custom handler is not configured var DefaultHandler Handler = &defaultHandler{ udpClient: &dns.Client{ @@ -41,7 +41,7 @@ type defaultHandler struct { addr string } -// ServeDNS - implements Handler interface +// ServeDNS implements Handler interface func (h *defaultHandler) ServeDNS(rw ResponseWriter, r *dns.Msg) error { // Google DNS res, _, err := h.udpClient.Exchange(r, h.addr) diff --git a/server.go b/server.go index 0e08557..d127242 100644 --- a/server.go +++ b/server.go @@ -1,27 +1,167 @@ package dnscrypt import ( + "context" + "net" + "sync" + "time" + "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) -// Server - a simple DNSCrypt server implementation +// default read timeout for all reads +const defaultReadTimeout = 2 * time.Second + +// in case of TCP we only use defaultReadTimeout for the first read +// then we start using defaultTCPIdleTimeout +const defaultTCPIdleTimeout = 8 * time.Second + +// helper struct that is used in several SetReadDeadline calls +var longTimeAgo = time.Unix(1, 0) + +// ServerDNSCrypt is an interface for a DNSCrypt server +type ServerDNSCrypt interface { + // ServeTCP listens to TCP connections, queries are then processed by Server.Handler. + // It blocks the calling goroutine and to stop it you need to close the listener + // or call ServerDNSCrypt.Shutdown. + ServeTCP(l net.Listener) error + + // ServeUDP listens to UDP connections, queries are then processed by Server.Handler. + // It blocks the calling goroutine and to stop it you need to close the listener + // or call ServerDNSCrypt.Shutdown. + ServeUDP(l *net.UDPConn) error + + // Shutdown tries to gracefully shutdown the server. It waits until all + // connections are processed and only after that it leaves the method. + // If context deadline is specified, it will exit earlier + // or call ServerDNSCrypt.Shutdown. + Shutdown(ctx context.Context) error +} + +// Server is a simple DNSCrypt server implementation type Server struct { - // ProviderName - DNSCrypt provider name + // ProviderName is a DNSCrypt provider name ProviderName string - // ResolverCert - contains resolver certificate. + // ResolverCert contains resolver certificate. ResolverCert *Cert // Handler to invoke. If nil, uses DefaultHandler. Handler Handler + + // make sure init is called only once + initOnce sync.Once + + // Shutdown handling + // -- + lock sync.RWMutex // protects access to all the fields below + started bool + wg sync.WaitGroup // active workers (servers) + tcpListeners map[net.Listener]struct{} // track active TCP listeners + udpListeners map[*net.UDPConn]struct{} // track active UDP listeners + tcpConns map[net.Conn]struct{} // track active connections } -// serveDNS - serves DNS response -func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) { +// type check +var _ ServerDNSCrypt = &Server{} + +// prepareShutdown - prepares the server to shutdown: +// unblocks reads from all connections related to this server +// marks the server as stopped +// if the server is not started, returns ErrServerNotStarted +func (s *Server) prepareShutdown() error { + s.lock.Lock() + defer s.lock.Unlock() + + if !s.started { + log.Info("Server is not started") + return ErrServerNotStarted + } + + s.started = false + + // These listeners were passed to us from the outside so we cannot close + // them here - this is up to the calling code to do that. Instead of that, + // we call Set(Read)Deadline to unblock goroutines that are currently + // blocked on reading from those listeners. + // For tcpConns we would like to avoid closing them to be able to process + // queries before shutting everything down. + + // Unblock reads for all active tcpConns + for conn := range s.tcpConns { + _ = conn.SetReadDeadline(longTimeAgo) + } + + // Unblock reads for all active TCP listeners + for l := range s.tcpListeners { + switch v := l.(type) { + case *net.TCPListener: + _ = v.SetDeadline(longTimeAgo) + } + } + + // Unblock reads for all active UDP listeners + for l := range s.udpListeners { + _ = l.SetReadDeadline(longTimeAgo) + } + + return nil +} + +// Shutdown tries to gracefully shutdown the server. It waits until all +// connections are processed and only after that it leaves the method. +// If context deadline is specified, it will exit earlier. +func (s *Server) Shutdown(ctx context.Context) error { + log.Info("Shutting down the DNSCrypt server") + + err := s.prepareShutdown() + if err != nil { + return err + } + + // Using this channel to wait until all goroutines finish their work + closed := make(chan struct{}) + go func() { + s.wg.Wait() + log.Info("Serve goroutines finished their work") + close(closed) + }() + + // Wait for either all goroutines finish their work + // Or for the context deadline + select { + case <-closed: + log.Info("DNSCrypt server has been stopped") + case <-ctx.Done(): + log.Info("DNSCrypt server shutdown has timed out") + err = ctx.Err() + } + + return err +} + +// init initializes (lazily) Server properties on startup +// this method is called from Server.ServeTCP and Server.ServeUDP +func (s *Server) init() { + s.tcpConns = map[net.Conn]struct{}{} + s.udpListeners = map[*net.UDPConn]struct{}{} + s.tcpListeners = map[net.Listener]struct{}{} +} + +// isStarted returns true if the server is processing queries right now +// it means that Server.ServeTCP and/or Server.ServeUDP have been called +func (s *Server) isStarted() bool { + s.lock.RLock() + started := s.started + s.lock.RUnlock() + return started +} + +// serveDNS serves DNS response +func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) error { if r == nil || len(r.Question) != 1 || r.Response { - log.Tracef("Invalid query: %v", r) - return + return ErrInvalidQuery } log.Tracef("Handling a DNS query: %s", r.Question[0].Name) @@ -39,9 +179,11 @@ func (s *Server) serveDNS(rw ResponseWriter, r *dns.Msg) { reply.SetRcode(r, dns.RcodeServerFailure) _ = rw.WriteMsg(reply) } + + return nil } -// encrypt - encrypts DNSCrypt response +// encrypt encrypts DNSCrypt response func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) { r := EncryptedResponse{ EsVersion: q.EsVersion, @@ -60,7 +202,7 @@ func (s *Server) encrypt(m *dns.Msg, q EncryptedQuery) ([]byte, error) { return r.Encrypt(packet, sharedKey) } -// decrypt - decrypts the incoming message and returns a DNS message to process +// decrypt decrypts the incoming message and returns a DNS message to process func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) { q := EncryptedQuery{ EsVersion: s.ResolverCert.EsVersion, @@ -82,7 +224,7 @@ func (s *Server) decrypt(b []byte) (*dns.Msg, EncryptedQuery, error) { return r, q, nil } -// handleHandshake - handles a TXT request that requests certificate data +// handleHandshake handles a TXT request that requests certificate data func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) { m := new(dns.Msg) err := m.Unpack(b) @@ -120,7 +262,7 @@ func (s *Server) handleHandshake(b []byte, certTxt string) ([]byte, error) { return reply.Pack() } -// validate - checks if the Server config is properly set +// validate checks if the Server config is properly set func (s *Server) validate() bool { if s.ResolverCert == nil { log.Error("ResolverCert must be set") @@ -139,3 +281,13 @@ func (s *Server) validate() bool { return true } + +// getCertTXT serializes the cert TXT record that are to be sent to the client +func (s *Server) getCertTXT() (string, error) { + certBuf, err := s.ResolverCert.Serialize() + if err != nil { + return "", err + } + certTxt := packTxtString(certBuf) + return certTxt, nil +} diff --git a/server_bench_test.go b/server_bench_test.go new file mode 100644 index 0000000..e473276 --- /dev/null +++ b/server_bench_test.go @@ -0,0 +1,60 @@ +package dnscrypt + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/ameshkov/dnsstamps" + "github.com/stretchr/testify/require" +) + +func BenchmarkServeUDP(b *testing.B) { + benchmarkServe(b, "udp") +} + +func BenchmarkServeTCP(b *testing.B) { + benchmarkServe(b, "tcp") +} + +func benchmarkServe(b *testing.B, network string) { + srv := newTestServer(b, &testHandler{}) + b.Cleanup(func() { + err := srv.Close() + require.NoError(b, err) + }) + + client := &Client{ + Timeout: 1 * time.Second, + Net: network, + } + + serverAddr := fmt.Sprintf("127.0.0.1:%d", srv.UDPAddr().Port) + if network == "tcp" { + serverAddr = fmt.Sprintf("127.0.0.1:%d", srv.TCPAddr().Port) + } + + stamp := dnsstamps.ServerStamp{ + ServerAddrStr: serverAddr, + ServerPk: srv.resolverPk, + ProviderName: srv.server.ProviderName, + Proto: dnsstamps.StampProtoTypeDNSCrypt, + } + ri, err := client.DialStamp(stamp) + require.NoError(b, err) + require.NotNil(b, ri) + + conn, err := net.Dial(network, stamp.ServerAddrStr) + require.NoError(b, err) + + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m := createTestMessage() + res, err := client.ExchangeConn(conn, m, ri) + require.NoError(b, err) + assertTestMessageResponse(b, res) + } + b.StopTimer() +} diff --git a/server_tcp.go b/server_tcp.go index 4913def..84f9a8f 100644 --- a/server_tcp.go +++ b/server_tcp.go @@ -2,13 +2,17 @@ package dnscrypt import ( "bytes" + "errors" + "fmt" "net" + "sync" + "time" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) -// TCPResponseWriter - ResponseWriter implementation for TCP +// TCPResponseWriter is the ResponseWriter implementation for TCP type TCPResponseWriter struct { tcpConn net.Conn encrypt encryptionFunc @@ -19,17 +23,17 @@ type TCPResponseWriter struct { // type check var _ ResponseWriter = &TCPResponseWriter{} -// LocalAddr - server socket local address +// LocalAddr is the server socket local address func (w *TCPResponseWriter) LocalAddr() net.Addr { return w.tcpConn.LocalAddr() } -// RemoteAddr - client's address +// RemoteAddr is the client's address func (w *TCPResponseWriter) RemoteAddr() net.Addr { return w.tcpConn.RemoteAddr() } -// WriteMsg - writes DNS message to the client +// WriteMsg writes DNS message to the client func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error { m.Truncate(dnsSize("tcp", w.req)) @@ -42,33 +46,44 @@ func (w *TCPResponseWriter) WriteMsg(m *dns.Msg) error { return writePrefixed(res, w.tcpConn) } -// ServeTCP - listens to TCP connections, queries are then processed by Server.Handler. -// It blocks the calling goroutine and to stop it you need to close the listener. +// ServeTCP listens to TCP connections, queries are then processed by Server.Handler. +// It blocks the calling goroutine and to stop it you need to close the listener +// or call Server.Shutdown. func (s *Server) ServeTCP(l net.Listener) error { - // Check that server is properly configured - if !s.validate() { - return ErrServerConfig - } - - // Serialize the cert right away and prepare it to be sent to the client - certBuf, err := s.ResolverCert.Serialize() + err := s.prepareServeTCP(l) if err != nil { return err } - certTxt := packTxtString(certBuf) - log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr().String()) + log.Info("Entering DNSCrypt TCP listening loop tcp://%s", l.Addr()) - for { + // Tracks TCP connection handling goroutines + tcpWg := &sync.WaitGroup{} + defer s.cleanUpTCP(tcpWg, l) + + // Track active goroutine + s.wg.Add(1) + + // Serialize the cert right away and prepare it to be sent to the client + certTxt, err := s.getCertTXT() + if err != nil { + return err + } + + for s.isStarted() { conn, err := l.Accept() - if err == nil { - go func() { - _ = s.handleTCPConnection(conn, certTxt) - _ = conn.Close() - }() - } + // Check the error code and exit loop if necessary if err != nil { + if !s.isStarted() { + // Stopped gracefully + break + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Temporary() { + // Note that timeout errors will be here (i.e. hitting ReadDeadline) + continue + } if isConnClosed(err) { log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop") } else { @@ -76,47 +91,129 @@ func (s *Server) ServeTCP(l net.Listener) error { } break } + + // If we got here, the connection is alive + s.lock.Lock() + // Track the connection to allow unblocking reads on shutdown. + s.tcpConns[conn] = struct{}{} + s.lock.Unlock() + + tcpWg.Add(1) + go func() { + // Ignore error here, it is most probably a legit one + // if not, it's written to the debug log + _ = s.handleTCPConnection(conn, certTxt) + + // Clean up + _ = conn.Close() + s.lock.Lock() + delete(s.tcpConns, conn) + s.lock.Unlock() + tcpWg.Done() + }() } return nil } +// prepareServeTCP prepares the server and listener to serving DNSCrypt +func (s *Server) prepareServeTCP(l net.Listener) error { + // Check that server is properly configured + if !s.validate() { + return ErrServerConfig + } + + // Protect shutdown-related fields + s.lock.Lock() + defer s.lock.Unlock() + s.initOnce.Do(s.init) + + // Mark the server as started if needed + s.started = true + + // Track an active TCP listener + s.tcpListeners[l] = struct{}{} + return nil +} + +// cleanUpTCP waits until all TCP messages before cleaning up +func (s *Server) cleanUpTCP(tcpWg *sync.WaitGroup, l net.Listener) { + // Wait until all TCP connections are processed + tcpWg.Wait() + + // Not using it anymore so can be removed from the active listeners + s.lock.Lock() + delete(s.tcpListeners, l) + s.lock.Unlock() + + // The work is finished + s.wg.Done() +} + +// handleTCPMsg handles a single TCP message. If this method returns error +// the connection will be closed +func (s *Server) handleTCPMsg(b []byte, conn net.Conn, certTxt string) error { + if len(b) < minDNSPacketSize { + // Ignore the packets that are too short + return ErrTooShort + } + + // First of all, check for "ClientMagic" in the incoming query + if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { + // If there's no ClientMagic in the packet, we assume this + // is a plain DNS query requesting the certificate data + reply, err := s.handleHandshake(b, certTxt) + if err != nil { + return fmt.Errorf("failed to process a plain DNS query: %w", err) + } + err = writePrefixed(reply, conn) + if err != nil { + return fmt.Errorf("failed to write a response: %w", err) + } + return nil + } + + // If we got here, this is an encrypted DNSCrypt message + // We should decrypt it first to get the plain DNS query + m, q, err := s.decrypt(b) + if err != nil { + return fmt.Errorf("failed to decrypt incoming message: %w", err) + } + rw := &TCPResponseWriter{ + tcpConn: conn, + encrypt: s.encrypt, + req: m, + query: q, + } + err = s.serveDNS(rw, m) + if err != nil { + return fmt.Errorf("failed to process a DNS query: %w", err) + } + + return nil +} + +// handleTCPConnection handles all queries that are coming to the +// specified TCP connection. func (s *Server) handleTCPConnection(conn net.Conn, certTxt string) error { - for { + timeout := defaultReadTimeout + + for s.isStarted() { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + b, err := readPrefixed(conn) if err != nil { return err } - if len(b) < minDNSPacketSize { - // Ignore the packets that are too short - return ErrTooShort + + err = s.handleTCPMsg(b, conn, certTxt) + if err != nil { + log.Debug("failed to process DNS query: %v", err) + return err } - if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { - // This is an encrypted message, we should decrypt it - m, q, err := s.decrypt(b) - if err != nil { - log.Tracef("failed to decrypt incoming message: %v", err) - return err - } - rw := &TCPResponseWriter{ - tcpConn: conn, - encrypt: s.encrypt, - req: m, - query: q, - } - s.serveDNS(rw, m) - } else { - // Most likely this a DNS message requesting the certificate - reply, err := s.handleHandshake(b, certTxt) - if err != nil { - log.Tracef("Failed to process a plain DNS query: %v", err) - return err - } - err = writePrefixed(reply, conn) - if err != nil { - return err - } - } + timeout = defaultTCPIdleTimeout } + + return nil } diff --git a/server_test.go b/server_test.go index f981925..d1063e0 100644 --- a/server_test.go +++ b/server_test.go @@ -2,9 +2,11 @@ package dnscrypt import ( "bytes" + "context" "crypto/ed25519" "fmt" "net" + "runtime" "testing" "time" @@ -13,25 +15,52 @@ import ( "github.com/stretchr/testify/assert" ) -func TestServerUDPServeCert(t *testing.T) { +func TestServer_Shutdown(t *testing.T) { + n := runtime.GOMAXPROCS(1) + t.Cleanup(func() { + runtime.GOMAXPROCS(n) + }) + srv := newTestServer(t, &testHandler{}) + // Serve* methods are called in different goroutines + // give them at least a moment to actually start the server + time.Sleep(10 * time.Millisecond) + assert.NoError(t, srv.Close()) +} + +func TestServer_UDPServeCert(t *testing.T) { testServerServeCert(t, "udp") } -func TestServerTCPServeCert(t *testing.T) { +func TestServer_TCPServeCert(t *testing.T) { testServerServeCert(t, "tcp") } -func TestServerUDPRespondMessages(t *testing.T) { +func TestServer_UDPRespondMessages(t *testing.T) { testServerRespondMessages(t, "udp") } -func TestServerTCPRespondMessages(t *testing.T) { +func TestServer_TCPRespondMessages(t *testing.T) { testServerRespondMessages(t, "tcp") } +func TestServer_ReadTimeout(t *testing.T) { + srv := newTestServer(t, &testHandler{}) + t.Cleanup(func() { + assert.NoError(t, srv.Close()) + }) + // Sleep for "defaultReadTimeout" before trying to shutdown the server + // The point is to make sure readTimeout is properly handled by + // the "Serve*" goroutines and they don't finish their work unexpectedly + time.Sleep(defaultReadTimeout) + testThisServerRespondMessages(t, "udp", srv) + testThisServerRespondMessages(t, "tcp", srv) +} + func testServerServeCert(t *testing.T, network string) { srv := newTestServer(t, &testHandler{}) - defer srv.Close() + t.Cleanup(func() { + assert.NoError(t, srv.Close()) + }) client := &Client{ Net: network, @@ -50,7 +79,7 @@ func testServerServeCert(t *testing.T, network string) { Proto: dnsstamps.StampProtoTypeDNSCrypt, } ri, err := client.DialStamp(stamp) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, ri) assert.Equal(t, ri.ProviderName, srv.server.ProviderName) @@ -65,8 +94,13 @@ func testServerServeCert(t *testing.T, network string) { func testServerRespondMessages(t *testing.T, network string) { srv := newTestServer(t, &testHandler{}) - defer srv.Close() + t.Cleanup(func() { + assert.NoError(t, srv.Close()) + }) + testThisServerRespondMessages(t, network, srv) +} +func testThisServerRespondMessages(t *testing.T, network string, srv *testServer) { client := &Client{ Timeout: 1 * time.Second, Net: network, @@ -84,16 +118,16 @@ func testServerRespondMessages(t *testing.T, network string) { Proto: dnsstamps.StampProtoTypeDNSCrypt, } ri, err := client.DialStamp(stamp) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, ri) conn, err := net.Dial(network, stamp.ServerAddrStr) - assert.Nil(t, err) + assert.NoError(t, err) for i := 0; i < 10; i++ { m := createTestMessage() res, err := client.ExchangeConn(conn, m, ri) - assert.Nil(t, err) + assert.NoError(t, err) assertTestMessageResponse(t, res) } } @@ -114,16 +148,22 @@ func (s *testServer) UDPAddr() *net.UDPAddr { return s.udpConn.LocalAddr().(*net.UDPAddr) } -func (s *testServer) Close() { +func (s *testServer) Close() error { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(time.Second)) + defer cancel() + + err := s.server.Shutdown(ctx) _ = s.udpConn.Close() _ = s.tcpListen.Close() + + return err } -func newTestServer(t *testing.T, handler Handler) *testServer { +func newTestServer(t assert.TestingT, handler Handler) *testServer { rc, err := GenerateResolverConfig("example.org", nil) - assert.Nil(t, err) + assert.NoError(t, err) cert, err := rc.CreateCert() - assert.Nil(t, err) + assert.NoError(t, err) s := &Server{ ProviderName: rc.ProviderName, @@ -132,7 +172,7 @@ func newTestServer(t *testing.T, handler Handler) *testServer { } privateKey, err := HexDecodeKey(rc.PrivateKey) - assert.Nil(t, err) + assert.NoError(t, err) publicKey := ed25519.PrivateKey(privateKey).Public().(ed25519.PublicKey) srv := &testServer{ server: s, @@ -140,9 +180,9 @@ func newTestServer(t *testing.T, handler Handler) *testServer { } srv.tcpListen, err = net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4zero, Port: 0}) - assert.Nil(t, err) + assert.NoError(t, err) srv.udpConn, err = net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) - assert.Nil(t, err) + assert.NoError(t, err) go s.ServeUDP(srv.udpConn) go s.ServeTCP(srv.tcpListen) diff --git a/server_udp.go b/server_udp.go index bb467a5..51ba232 100644 --- a/server_udp.go +++ b/server_udp.go @@ -2,38 +2,43 @@ package dnscrypt import ( "bytes" + "errors" "net" + "runtime" + "sync" + "time" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) type encryptionFunc func(m *dns.Msg, q EncryptedQuery) ([]byte, error) -// UDPResponseWriter - ResponseWriter implementation for UDP +// UDPResponseWriter is the ResponseWriter implementation for UDP type UDPResponseWriter struct { - udpConn *net.UDPConn // UDP connection - remoteAddr *net.UDPAddr // Remote peer address - localIP net.IP // Local IP (that was used to accept the remote connection) - encrypt encryptionFunc // DNSCRypt encryption function - req *dns.Msg // DNS query that was processed - query EncryptedQuery // DNSCrypt query properties + udpConn *net.UDPConn // UDP connection + sess *dns.SessionUDP // SessionUDP (necessary to use dns.WriteToSessionUDP) + encrypt encryptionFunc // DNSCRypt encryption function + req *dns.Msg // DNS query that was processed + query EncryptedQuery // DNSCrypt query properties } // type check var _ ResponseWriter = &UDPResponseWriter{} -// LocalAddr - server socket local address +// LocalAddr is the server socket local address func (w *UDPResponseWriter) LocalAddr() net.Addr { return w.udpConn.LocalAddr() } -// RemoteAddr - client's address +// RemoteAddr is the client's address func (w *UDPResponseWriter) RemoteAddr() net.Addr { - return w.remoteAddr + return w.udpConn.RemoteAddr() } -// WriteMsg - writes DNS message to the client +// WriteMsg writes DNS message to the client func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error { m.Truncate(dnsSize("udp", w.req)) @@ -42,83 +47,176 @@ func (w *UDPResponseWriter) WriteMsg(m *dns.Msg) error { log.Tracef("Failed to encrypt the DNS query: %v", err) return err } + _, err = dns.WriteToSessionUDP(w.udpConn, res, w.sess) + return err +} + +// ServeUDP listens to UDP connections, queries are then processed by Server.Handler. +// It blocks the calling goroutine and to stop it you need to close the listener +// or call Server.Shutdown. +func (s *Server) ServeUDP(l *net.UDPConn) error { + err := s.prepareServeUDP(l) + if err != nil { + return err + } + + // Tracks UDP handling goroutines + udpWg := &sync.WaitGroup{} + defer s.cleanUpUDP(udpWg, l) + + // Track active goroutine + s.wg.Add(1) + + log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr()) + + // Serialize the cert right away and prepare it to be sent to the client + certTxt, err := s.getCertTXT() + if err != nil { + return err + } + + for s.isStarted() { + b, sess, err := s.readUDPMsg(l) + + // Check the error code and exit loop if necessary + if err != nil { + if !s.isStarted() { + // Stopped gracefully + return nil + } + var netErr net.Error + if errors.As(err, &netErr) && netErr.Temporary() { + // Note that timeout errors will be here (i.e. hitting ReadDeadline) + continue + } + if isConnClosed(err) { + log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop") + } else { + log.Info("got error when reading from UDP listen: %s", err) + } + return err + } + + if len(b) < minDNSPacketSize { + // Ignore the packets that are too short + continue + } + + udpWg.Add(1) + go func() { + s.serveUDPMsg(b, certTxt, sess, l) + udpWg.Done() + }() + } - _, _ = udpWrite(res, w.udpConn, w.remoteAddr, w.localIP) return nil } -// ServeUDP - listens to UDP connections, queries are then processed by Server.Handler. -// It blocks the calling goroutine and to stop it you need to close the listener. -func (s *Server) ServeUDP(l *net.UDPConn) error { +// prepareServeUDP prepares the server and listener to serving DNSCrypt +func (s *Server) prepareServeUDP(l *net.UDPConn) error { // Check that server is properly configured if !s.validate() { return ErrServerConfig } // set UDP options to allow receiving OOB data - err := udpSetOptions(l) + err := setUDPSocketOptions(l) if err != nil { return err } - // Buffer to read incoming messages - b := make([]byte, dns.MaxMsgSize) + // Protect shutdown-related fields + s.lock.Lock() + defer s.lock.Unlock() + s.initOnce.Do(s.init) - // Serialize the cert right away and prepare it to be sent to the client - certBuf, err := s.ResolverCert.Serialize() + // Mark the server as started. + // Note that we don't check if it was started before as + // Serve* methods can be called multiple times. + s.started = true + + // Track an active UDP listener + s.udpListeners[l] = struct{}{} + return err +} + +// cleanUpUDP waits until all UDP messages before cleaning up +func (s *Server) cleanUpUDP(udpWg *sync.WaitGroup, l *net.UDPConn) { + // Wait until UDP messages are processed + udpWg.Wait() + + // Not using it anymore so can be removed from the active listeners + s.lock.Lock() + delete(s.udpListeners, l) + s.lock.Unlock() + + // The work is finished + s.wg.Done() +} + +// readUDPMsg reads incoming UDP message +func (s *Server) readUDPMsg(l *net.UDPConn) ([]byte, *dns.SessionUDP, error) { + _ = l.SetReadDeadline(time.Now().Add(defaultReadTimeout)) + b := make([]byte, dns.MinMsgSize) + n, sess, err := dns.ReadFromSessionUDP(l, b) if err != nil { - return err + return nil, nil, err } - certTxt := packTxtString(certBuf) - // Init oobSize - it will be used later when reading and writing UDP messages - oobSize := udpGetOOBSize() - - log.Info("Entering DNSCrypt UDP listening loop on udp://%s", l.LocalAddr().String()) - - for { - n, localIP, addr, err := udpRead(l, b, oobSize) - if n < minDNSPacketSize { - // Ignore the packets that are too short - continue - } - - if bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { - // This is an encrypted message, we should decrypt it - m, q, err := s.decrypt(b[:n]) - if err == nil { - rw := &UDPResponseWriter{ - udpConn: l, - remoteAddr: addr, - localIP: localIP, - encrypt: s.encrypt, - req: m, - query: q, - } - go s.serveDNS(rw, m) - } else { - log.Tracef("Failed to decrypt incoming message len=%d: %v", n, err) - } - } else { - // Most likely this a DNS message requesting the certificate - reply, err := s.handleHandshake(b, certTxt) - if err != nil { - log.Tracef("Failed to process a plain DNS query: %v", err) - } - if err == nil { - _, _ = l.WriteTo(reply, addr) - } - } + return b[:n], sess, err +} +// serveUDPMsg handles incoming DNS message +func (s *Server) serveUDPMsg(b []byte, certTxt string, sess *dns.SessionUDP, l *net.UDPConn) { + // First of all, check for "ClientMagic" in the incoming query + if !bytes.Equal(b[:clientMagicSize], s.ResolverCert.ClientMagic[:]) { + // If there's no ClientMagic in the packet, we assume this + // is a plain DNS query requesting the certificate data + reply, err := s.handleHandshake(b, certTxt) if err != nil { - if isConnClosed(err) { - log.Info("udpListen.ReadFrom() returned because we're reading from a closed connection, exiting loop") - } else { - log.Info("got error when reading from UDP listen: %s", err) - } - break + log.Tracef("failed to process a plain DNS query: %v", err) } + if err == nil { + // Ignore errors, we don't care and can't handle them anyway + _, _ = dns.WriteToSessionUDP(l, reply, sess) + } + + return } + // If we got here, this is an encrypted DNSCrypt message + // We should decrypt it first to get the plain DNS query + m, q, err := s.decrypt(b) + if err == nil { + rw := &UDPResponseWriter{ + udpConn: l, + sess: sess, + encrypt: s.encrypt, + req: m, + query: q, + } + err = s.serveDNS(rw, m) + if err != nil { + log.Tracef("failed to process a DNS query: %v", err) + } + } else { + log.Tracef("failed to decrypt incoming message len=%d: %v", len(b), err) + } +} + +// setUDPSocketOptions method is necessary to be able to use dns.ReadFromSessionUDP / dns.WriteToSessionUDP +func setUDPSocketOptions(conn *net.UDPConn) error { + if runtime.GOOS == "windows" { + return nil + } + + // We don't know if this a IPv4-only, IPv6-only or a IPv4-and-IPv6 connection. + // Try enabling receiving of ECN and packet info for both IP versions. + // We expect at least one of those syscalls to succeed. + err6 := ipv6.NewPacketConn(conn).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) + err4 := ipv4.NewPacketConn(conn).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) + if err6 != nil && err4 != nil { + return err4 + } return nil } diff --git a/testdata/dnscrypt-cert.opendns.txt b/testdata/dnscrypt-cert.opendns.txt new file mode 100644 index 0000000..18950f6 --- /dev/null +++ b/testdata/dnscrypt-cert.opendns.txt @@ -0,0 +1 @@ +DNSC\000\001\000\000\200\226E:H\156\203%\134\218\127]\168\239\027u\011$\191\008\239\176F\133\017\171\161\219\154\142i\164\010\239\017f\168dS\210f\197\194\169\171w\2499\1891\155<\130\218@/\155\023v\153#d\024\004\136\180\228K5\233d\180\144\189\218\186\232%\162K\004\021\160\139\225\157}\219\135\163<\215~\223\142/qc78aWoo]\221\184`]\221\184`_\190\235\224 \ No newline at end of file diff --git a/udp_unix.go b/udp_unix.go deleted file mode 100644 index 8d1f50c..0000000 --- a/udp_unix.go +++ /dev/null @@ -1,83 +0,0 @@ -// +build aix darwin dragonfly linux netbsd openbsd solaris freebsd - -package dnscrypt - -import ( - "fmt" - "net" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -// udpGetOOBSize - get max. size of received OOB data -// It will then be used in the ReadMsgUDP function -func udpGetOOBSize() int { - oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface) - oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface) - - if len(oob4) > len(oob6) { - return len(oob4) - } - return len(oob6) -} - -// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data -func udpSetOptions(c *net.UDPConn) error { - err6 := ipv6.NewPacketConn(c).SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true) - err4 := ipv4.NewPacketConn(c).SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true) - if err6 != nil && err4 != nil { - return fmt.Errorf("failed to call SetControlMessage: ipv4: %v ipv6: %v", err4, err6) - } - return nil -} - -// udpRead - receive payload and OOB data from the UDP socket -func udpRead(c *net.UDPConn, buf []byte, udpOOBSize int) (int, net.IP, *net.UDPAddr, error) { - var oobn int - oob := make([]byte, udpOOBSize) - var err error - var n int - var remoteAddr *net.UDPAddr - n, oobn, _, remoteAddr, err = c.ReadMsgUDP(buf, oob) - if err != nil { - return -1, nil, nil, err - } - - localIP := udpGetDstFromOOB(oob[:oobn]) - return n, localIP, remoteAddr, nil -} - -// udpWrite - writes to the UDP socket and sets local IP to OOB data -func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, localIP net.IP) (int, error) { - n, _, err := conn.WriteMsgUDP(bytes, udpMakeOOBWithSrc(localIP), remoteAddr) - return n, err -} - -// udpGetDstFromOOB - get destination IP from OOB data -func udpGetDstFromOOB(oob []byte) net.IP { - cm6 := &ipv6.ControlMessage{} - if cm6.Parse(oob) == nil && cm6.Dst != nil { - return cm6.Dst - } - - cm4 := &ipv4.ControlMessage{} - if cm4.Parse(oob) == nil && cm4.Dst != nil { - return cm4.Dst - } - - return nil -} - -// udpMakeOOBWithSrc - make OOB data with a specified source IP -func udpMakeOOBWithSrc(ip net.IP) []byte { - if ip.To4() == nil { - cm := &ipv6.ControlMessage{} - cm.Src = ip - return cm.Marshal() - } - - cm := &ipv4.ControlMessage{} - cm.Src = ip - return cm.Marshal() -} diff --git a/udp_windows.go b/udp_windows.go deleted file mode 100644 index 74ac910..0000000 --- a/udp_windows.go +++ /dev/null @@ -1,30 +0,0 @@ -package dnscrypt - -import "net" - -// udpGetOOBSize - get max. size of received OOB data -// Does nothing on Windows -func udpGetOOBSize() int { - return 0 -} - -// udpSetOptions - set options on a UDP socket to be able to receive the necessary OOB data -// Does nothing on Windows -func udpSetOptions(c *net.UDPConn) error { - return nil -} - -// udpRead - receive payload from the UDP socket -func udpRead(c *net.UDPConn, buf []byte, _ int) (int, net.IP, *net.UDPAddr, error) { - n, addr, err := c.ReadFrom(buf) - var udpAddr *net.UDPAddr - if addr != nil { - udpAddr = addr.(*net.UDPAddr) - } - return n, nil, udpAddr, err -} - -// udpWrite - writes to the UDP socket -func udpWrite(bytes []byte, conn *net.UDPConn, remoteAddr *net.UDPAddr, _ net.IP) (int, error) { - return conn.WriteTo(bytes, remoteAddr) -}