merge master into branch and use named returns better
This commit is contained in:
parent
f64f476a1b
commit
91969133ff
7 changed files with 62 additions and 45 deletions
18
cmd/cli.go
18
cmd/cli.go
|
@ -70,11 +70,11 @@ func ParseCLI(args []string, version string) (opts *util.Options, err error) {
|
|||
opts.Logger.Info("Options fully populated")
|
||||
opts.Logger.Debug(fmt.Sprintf("%+v", opts))
|
||||
|
||||
return opts, nil
|
||||
return
|
||||
}
|
||||
|
||||
// Everything that has to do with CLI flags goes here (the posix style, eg. -a and --bbbb).
|
||||
func parseFlags(args []string, version string) (*util.Options, []string, error) {
|
||||
func parseFlags(args []string, version string) (opts *util.Options, flags []string, err error) {
|
||||
flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError)
|
||||
|
||||
flagSet.Usage = func() {
|
||||
|
@ -162,7 +162,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
flagSet.SortFlags = true
|
||||
|
||||
// Parse the flags
|
||||
if err := flagSet.Parse(args[1:]); err != nil {
|
||||
if err = flagSet.Parse(args[1:]); err != nil {
|
||||
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("flag: %w", err)
|
||||
}
|
||||
|
||||
|
@ -172,7 +172,7 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("EDNS MBZ: %w", err)
|
||||
}
|
||||
|
||||
opts := util.Options{
|
||||
opts = &util.Options{
|
||||
Logger: util.InitLogger(*verbosity),
|
||||
IPv4: *ipv4,
|
||||
IPv6: *ipv6,
|
||||
|
@ -243,8 +243,8 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
|
||||
// TODO: DRY
|
||||
if *subnet != "" {
|
||||
if err = util.ParseSubnet(*subnet, &opts); err != nil {
|
||||
return &opts, nil, fmt.Errorf("%w", err)
|
||||
if err = util.ParseSubnet(*subnet, opts); err != nil {
|
||||
return opts, nil, fmt.Errorf("%w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,10 +254,12 @@ func parseFlags(args []string, version string) (*util.Options, []string, error)
|
|||
if *versionFlag {
|
||||
fmt.Printf("awl version %s, built with %s\n", version, runtime.Version())
|
||||
|
||||
return &opts, nil, util.ErrNotError
|
||||
return opts, nil, util.ErrNotError
|
||||
}
|
||||
|
||||
return &opts, flagSet.Args(), nil
|
||||
flags = flagSet.Args()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
var errNoArg = errors.New("no argument given")
|
||||
|
|
|
@ -140,7 +140,7 @@ func ToString(res util.Response, opts *util.Options) (s string, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
return s, nil
|
||||
return
|
||||
}
|
||||
|
||||
func serverExtra(opts *util.Options) string {
|
||||
|
|
|
@ -42,7 +42,7 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (resp util.Response, err
|
|||
|
||||
resolverInf, err := client.Dial(resolver.opts.Request.Server)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("dnscrypt: dial: %w", err)
|
||||
return resp, fmt.Errorf("dnscrypt: dial: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
@ -50,13 +50,15 @@ func (resolver *DNSCryptResolver) LookUp(msg *dns.Msg) (resp util.Response, err
|
|||
rtt := time.Since(now)
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("dnscrypt: exchange: %w", err)
|
||||
return resp, fmt.Errorf("dnscrypt: exchange: %w", err)
|
||||
}
|
||||
|
||||
resp = util.Response{
|
||||
DNS: res,
|
||||
RTT: rtt,
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Info("Request successful")
|
||||
|
||||
return util.Response{
|
||||
DNS: res,
|
||||
RTT: rtt,
|
||||
}, nil
|
||||
return
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
|
|||
|
||||
buf, err := msg.Pack()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: packing: %w", err)
|
||||
return resp, fmt.Errorf("doh: packing: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("https: sending HTTPS request")
|
||||
|
@ -55,7 +55,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
|
|||
|
||||
req, err := http.NewRequest(method, resolver.opts.Request.Server, bytes.NewBuffer(buf))
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: request creation: %w", err)
|
||||
return resp, fmt.Errorf("doh: request creation: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
|
@ -66,23 +66,29 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
|
|||
resp.RTT = time.Since(now)
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: HTTP request: %w", err)
|
||||
// overwrite RTT or else tests will fail
|
||||
resp.RTT = 0
|
||||
|
||||
return resp, fmt.Errorf("doh: HTTP request: %w", err)
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return util.Response{}, &util.ErrHTTPStatus{Code: res.StatusCode}
|
||||
// overwrite RTT or else tests will fail
|
||||
resp.RTT = 0
|
||||
|
||||
return resp, &util.ErrHTTPStatus{Code: res.StatusCode}
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("https: reading response")
|
||||
|
||||
fullRes, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: body read: %w", err)
|
||||
return resp, fmt.Errorf("doh: body read: %w", err)
|
||||
}
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: body close: %w", err)
|
||||
return resp, fmt.Errorf("doh: body close: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("https: unpacking response")
|
||||
|
@ -91,7 +97,7 @@ func (resolver *HTTPSResolver) LookUp(msg *dns.Msg) (resp util.Response, err err
|
|||
|
||||
err = resp.DNS.Unpack(fullRes)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doh: dns message unpack: %w", err)
|
||||
return resp, fmt.Errorf("doh: dns message unpack: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
|
|
@ -38,7 +38,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
|
|||
|
||||
connection, err := quic.DialAddr(resolver.opts.Request.Server, tls, conf)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: dial: %w", err)
|
||||
return resp, fmt.Errorf("doq: dial: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: packing query")
|
||||
|
@ -46,7 +46,7 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
|
|||
// Compress request to over-the-wire
|
||||
buf, err := msg.Pack()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: pack: %w", err)
|
||||
return resp, fmt.Errorf("doq: pack: %w", err)
|
||||
}
|
||||
|
||||
t := time.Now()
|
||||
|
@ -55,21 +55,21 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
|
|||
|
||||
stream, err := connection.OpenStream()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream creation: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream creation: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: writing to stream")
|
||||
|
||||
_, err = stream.Write(buf)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream write: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream write: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: reading stream")
|
||||
|
||||
fullRes, err := io.ReadAll(stream)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream read: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream read: %w", err)
|
||||
}
|
||||
|
||||
resp.RTT = time.Since(t)
|
||||
|
@ -78,14 +78,14 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
|
|||
// Close with error: no error
|
||||
err = connection.CloseWithError(0, "")
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic connection close: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic connection close: %w", err)
|
||||
}
|
||||
|
||||
resolver.opts.Logger.Debug("quic: closing stream")
|
||||
|
||||
err = stream.Close()
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: quic stream close: %w", err)
|
||||
return resp, fmt.Errorf("doq: quic stream close: %w", err)
|
||||
}
|
||||
|
||||
resp.DNS = &dns.Msg{}
|
||||
|
@ -94,8 +94,8 @@ func (resolver *QUICResolver) LookUp(msg *dns.Msg) (resp util.Response, err erro
|
|||
|
||||
err = resp.DNS.Unpack(fullRes)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("doq: unpack: %w", err)
|
||||
return resp, fmt.Errorf("doq: unpack: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err
|
|||
|
||||
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
}
|
||||
|
||||
switch dns.RcodeToString[resp.DNS.MsgHdr.Rcode] {
|
||||
|
@ -64,7 +64,7 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err
|
|||
resp.DNS, resp.RTT, err = dnsClient.Exchange(msg, resolver.opts.Request.Server)
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("badcookie: DNS exchange: %w", err)
|
||||
return resp, fmt.Errorf("badcookie: DNS exchange: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -90,8 +90,8 @@ func (resolver *StandardResolver) LookUp(msg *dns.Msg) (resp util.Response, err
|
|||
}
|
||||
|
||||
if err != nil {
|
||||
return util.Response{}, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
return resp, fmt.Errorf("standard: DNS exchange: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return
|
||||
}
|
||||
|
|
|
@ -22,7 +22,7 @@ type Resolver interface {
|
|||
}
|
||||
|
||||
// LoadResolver loads the respective resolver for performing a DNS query.
|
||||
func LoadResolver(opts *util.Options) (Resolver, error) {
|
||||
func LoadResolver(opts *util.Options) (resolver Resolver, err error) {
|
||||
switch {
|
||||
case opts.HTTPS:
|
||||
opts.Logger.Info("loading DNS-over-HTTPS resolver")
|
||||
|
@ -32,10 +32,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
}
|
||||
|
||||
opts.Request.Server += opts.HTTPSOptions.Endpoint
|
||||
|
||||
return &HTTPSResolver{
|
||||
resolver = &HTTPSResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
case opts.QUIC:
|
||||
opts.Logger.Info("loading DNS-over-QUIC resolver")
|
||||
|
||||
|
@ -43,9 +44,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
|
||||
}
|
||||
|
||||
return &QUICResolver{
|
||||
resolver = &QUICResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
case opts.DNSCrypt:
|
||||
opts.Logger.Info("loading DNSCrypt resolver")
|
||||
|
||||
|
@ -53,9 +56,11 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
opts.Request.Server = "sdns://" + opts.Request.Server
|
||||
}
|
||||
|
||||
return &DNSCryptResolver{
|
||||
resolver = &DNSCryptResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
default:
|
||||
opts.Logger.Info("loading standard/DNS-over-TLS resolver")
|
||||
|
||||
|
@ -63,8 +68,10 @@ func LoadResolver(opts *util.Options) (Resolver, error) {
|
|||
opts.Request.Server = net.JoinHostPort(opts.Request.Server, strconv.Itoa(opts.Request.Port))
|
||||
}
|
||||
|
||||
return &StandardResolver{
|
||||
resolver = &StandardResolver{
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue