diff --git a/cmd/cli.go b/cmd/cli.go index 62bfe76..3c06037 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -18,6 +18,63 @@ import ( // ParseCLI parses arguments given from the CLI and passes them into an `Options` // struct. func ParseCLI(args []string, version string) (*util.Options, error) { + // Parse the standard flags + opts, misc, err := parseFlags(args, version) + if err != nil { + return opts, err + } + + // Parse all the arguments that don't start with - or -- + // This includes the dig-style (+) options + err = ParseMiscArgs(misc, opts) + if err != nil { + return opts, err + } + + opts.Logger.Info("Dig/Drill flags parsed") + opts.Logger.Debug(fmt.Sprintf("%+v", opts)) + + // Special options and exceptions time + + if opts.Request.Port == 0 { + if opts.TLS || opts.QUIC { + opts.Request.Port = 853 + } else { + opts.Request.Port = 53 + } + } + + opts.Logger.Info("Port set to", opts.Request.Port) + + // Set timeout to 0.5 seconds if set below 0.5 + if opts.Request.Timeout < (time.Second / 2) { + opts.Request.Timeout = (time.Second / 2) + } + + if opts.Request.Retries < 0 { + opts.Request.Retries = 0 + } + + if opts.Trace { + if opts.TLS || opts.HTTPS || opts.QUIC { + opts.Logger.Warn("Every query after the root query will only use UDP/TCP") + } + + if opts.Reverse { + opts.Logger.Error("Reverse queries are not currently supported") + } + + opts.RD = true + } + + opts.Logger.Info("Options fully populated") + opts.Logger.Debug(fmt.Sprintf("%+v", opts)) + + return opts, nil +} + +// Everything that has to do with CLI flags goes here (the posix style, eg. -a and --bbbb). +func parseFlags(args []string, version string) (*util.Options, []string, error) { flagSet := flag.NewFlagSet(args[0], flag.ContinueOnError) flagSet.Usage = func() { @@ -46,6 +103,7 @@ func ParseCLI(args []string, version string) (*util.Options, error) { ipv4 = flagSet.Bool("4", false, "force IPv4", flag.OptShorthand('4')) ipv6 = flagSet.Bool("6", false, "force IPv6", flag.OptShorthand('6')) reverse = flagSet.Bool("reverse", false, "do a reverse lookup", flag.OptShorthand('x')) + trace = flagSet.Bool("trace", false, "trace from the root") timeout = flagSet.Float32("timeout", 5, "Timeout, in `seconds`") retry = flagSet.Int("retries", 2, "number of `times` to retry") @@ -101,23 +159,24 @@ func ParseCLI(args []string, version string) (*util.Options, error) { ) // Don't sort the flags when -h is given - flagSet.SortFlags = false + flagSet.SortFlags = true // Parse the flags if err := flagSet.Parse(args[1:]); err != nil { - return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("flag: %w", err) + return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("flag: %w", err) } // TODO: DRY, dumb dumb. mbz, err := strconv.ParseInt(*mbzflag, 0, 16) if err != nil { - return &util.Options{Logger: util.InitLogger(*verbosity)}, fmt.Errorf("EDNS MBZ: %w", err) + return &util.Options{Logger: util.InitLogger(*verbosity)}, nil, fmt.Errorf("EDNS MBZ: %w", err) } opts := util.Options{ Logger: util.InitLogger(*verbosity), IPv4: *ipv4, IPv6: *ipv6, + Trace: *trace, Short: *short, TCP: *tcp, DNSCrypt: *dnscrypt, @@ -185,7 +244,7 @@ func ParseCLI(args []string, version string) (*util.Options, error) { // TODO: DRY if *subnet != "" { if err = util.ParseSubnet(*subnet, &opts); err != nil { - return &opts, fmt.Errorf("%w", err) + return &opts, nil, fmt.Errorf("%w", err) } } @@ -195,42 +254,10 @@ func ParseCLI(args []string, version string) (*util.Options, error) { if *versionFlag { fmt.Printf("awl version %s, built with %s\n", version, runtime.Version()) - return &opts, util.ErrNotError + return &opts, nil, util.ErrNotError } - // Parse all the arguments that don't start with - or -- - // This includes the dig-style (+) options - err = ParseMiscArgs(flagSet.Args(), &opts) - if err != nil { - return &opts, err - } - - opts.Logger.Info("Dig/Drill flags parsed") - opts.Logger.Debug(fmt.Sprintf("%+v", opts)) - - if opts.Request.Port == 0 { - if opts.TLS || opts.QUIC { - opts.Request.Port = 853 - } else { - opts.Request.Port = 53 - } - } - - opts.Logger.Info("Port set to", opts.Request.Port) - - // Set timeout to 0.5 seconds if set below 0.5 - if opts.Request.Timeout < (time.Second / 2) { - opts.Request.Timeout = (time.Second / 2) - } - - if opts.Request.Retries < 0 { - opts.Request.Retries = 0 - } - - opts.Logger.Info("Options fully populated") - opts.Logger.Debug(fmt.Sprintf("%+v", opts)) - - return &opts, nil + return &opts, flagSet.Args(), nil } var errNoArg = errors.New("no argument given") diff --git a/cmd/dig.go b/cmd/dig.go index 024d889..289004e 100644 --- a/cmd/dig.go +++ b/cmd/dig.go @@ -25,6 +25,18 @@ func ParseDig(arg string, opts *util.Options) error { opts.Logger.Info("Setting", arg) switch arg { + case "trace", "notrace": + opts.Trace = isNo + if isNo { + opts.DNSSEC = true + opts.Display.Comments = false + opts.Display.Question = false + opts.Display.Opt = false + opts.Display.Answer = true + opts.Display.Authority = true + opts.Display.Additional = false + opts.Display.Statistics = false + } // Set DNS query flags case "aa", "aaflag", "aaonly": opts.AA = isNo diff --git a/cmd/dig_test.go b/cmd/dig_test.go index cb18f48..fcbb9ea 100644 --- a/cmd/dig_test.go +++ b/cmd/dig_test.go @@ -60,6 +60,7 @@ func FuzzDig(f *testing.F) { "all", "noall", "idnout", "noidnout", "class", "noclass", + "trace", "notrace", "invalid", } diff --git a/completions/fish.fish b/completions/fish.fish index 5d74c37..1f1b68c 100644 --- a/completions/fish.fish +++ b/completions/fish.fish @@ -38,7 +38,7 @@ complete -f -c awl -s V -l version -d 'Print version and exit' # complete -f -c awl -a '+search +nosearch' -d 'Set whether to use searchlist' # complete -f -c awl -a '+showsearch +noshowsearch' -d 'Search with intermediate results' -# complete -f -c awl -a '+recurse +norecurse' -d 'Recursive mode' +complete -f -c awl -a '+recurse +norecurse' -d 'Recursive mode' complete -f -c awl -l no-truncate -a '+ignore +noignore' -d 'Dont revert to TCP for TC responses.' # complete -f -c awl -a '+fail +nofail' -d 'Dont try next server on SERVFAIL' # complete -f -c awl -a '+besteffort +nobesteffort' -d 'Try to parse even illegal messages' @@ -59,8 +59,8 @@ complete -f -c awl -a '+all +noall' -d 'Set or clear all display flags' complete -f -c awl -a '+qr +noqr' -d 'Print question before sending' # complete -f -c awl -a '+nssearch +nonssearch' -d 'Search all authoritative nameservers' complete -f -c awl -a '+identify +noidentify' -d 'ID responders in short answers' -# complete -f -c awl -a '+trace +notrace' -d 'Trace delegation down from root' -complete -f -c awl -l dnssec -a '+dnssec +nodnssec' -d 'Request DNSSEC records' +complete -f -c awl -a '+trace +notrace' -d 'Trace delegation down from root' +complete -f -c awl -l dnssec -a '+dnssec +nodnssec +do +nodo' -d 'Request DNSSEC records' complete -f -c awl -a '+nsid +nonsid' -d 'Request Name Server ID' # complete -f -c awl -a '+multiline +nomultiline' -d 'Print records in an expanded format' # complete -f -c awl -a '+onesoa +noonesoa' -d 'AXFR prints only one soa record' diff --git a/completions/zsh.zsh b/completions/zsh.zsh index 269bd12..ac46304 100644 --- a/completions/zsh.zsh +++ b/completions/zsh.zsh @@ -28,7 +28,7 @@ local -a alts args '*+'{no,}'keepopen[keep TCP socket open between queries]' '*+'{no,}'recurse[set the RD (recursion desired) bit in the query]' # '*+'{no,}'nssearch[search all authoritative nameservers]' - # '*+'{no,}'trace[trace delegation down from root]' + '*+'{no,}'trace[trace delegation down from root]' # '*+'{no,}'cmd[print initial comment in output]' '*+'{no,}'short[print terse output]' '*+'{no,}'identify[print IP and port of responder]' @@ -98,6 +98,7 @@ _arguments -s -C $args \ '*-'{j,-json}'+[present the results as JSON]' \ '*-'{X,-xml}'+[present the results as XML]' \ '*-'{y,-yaml}'+[present the results as YAML]' \ + '*--trace+[trace from the root]' \ '*: :->args' && ret=0 if [[ -n $state ]]; then @@ -108,4 +109,4 @@ if [[ -n $state ]]; then fi fi -return ret \ No newline at end of file +return ret diff --git a/doc/awl.1.scd b/doc/awl.1.scd index e43d93d..3035adf 100644 --- a/doc/awl.1.scd +++ b/doc/awl.1.scd @@ -114,6 +114,11 @@ Many options are inherited from *dig*(1). Set the timeout period. Floating point numbers are accepted. 0.5 seconds is the minimum. +*--trace*, *+trace* + Trace the path of the query from the root, acting like its own resolver. + This option enables DNSSEC. + When *@server* is specified, this will only affect the initial query. + *--retries* _int_, *+tries*=_int_, *+retry*=_int_ Set the number of retries. Retry is one more than tries, dig style. diff --git a/main.go b/main.go index 987bb39..f5c01b4 100644 --- a/main.go +++ b/main.go @@ -5,12 +5,15 @@ package main import ( "errors" "fmt" + "math/rand" "os" "strings" + "time" cli "git.froth.zone/sam/awl/cmd" "git.froth.zone/sam/awl/pkg/query" "git.froth.zone/sam/awl/pkg/util" + "github.com/miekg/dns" ) var version = "DEV" @@ -28,42 +31,103 @@ func main() { } func run(args []string) (opts *util.Options, code int, err error) { + //nolint:gosec //Secure source not needed + r := rand.New(rand.NewSource(time.Now().Unix())) + opts, err = cli.ParseCLI(args, version) if err != nil { return opts, 1, fmt.Errorf("parse: %w", err) } - var resp util.Response + var ( + resp util.Response + keepTracing bool + tempDomain string + tempQueryType uint16 + ) - // Retry queries if a query fails - for i := 0; i <= opts.Request.Retries; i++ { - resp, err = query.CreateQuery(opts) - if err == nil { - break - } else if i != opts.Request.Retries { - opts.Logger.Warn("Retrying request, error:", err) + for ok := true; ok; ok = keepTracing { + if keepTracing { + opts.Request.Name = tempDomain + opts.Request.Type = tempQueryType + } else { + tempDomain = opts.Request.Name + tempQueryType = opts.Request.Type + + // Override the query because it needs to be done + opts.Request.Name = "." + opts.Request.Type = dns.TypeNS } - } + // Retry queries if a query fails + for i := 0; i <= opts.Request.Retries; i++ { + resp, err = query.CreateQuery(opts) + if err == nil { + keepTracing = opts.Trace && (!resp.DNS.Authoritative || (opts.Request.Name == "." && tempDomain != ".")) - // Query failed, make it fail - if err != nil { - return opts, 9, fmt.Errorf("query: %w", err) - } + break + } else if i != opts.Request.Retries { + opts.Logger.Warn("Retrying request, error:", err) + } + } - var str string - if opts.JSON || opts.XML || opts.YAML { - str, err = query.PrintSpecial(resp, opts) + // Query failed, make it fail if err != nil { - return opts, 10, fmt.Errorf("format print: %w", err) + return opts, 9, fmt.Errorf("query: %w", err) } - } else { - str, err = query.ToString(resp, opts) - if err != nil { - return opts, 15, fmt.Errorf("standard print: %w", err) + + var str string + if opts.JSON || opts.XML || opts.YAML { + str, err = query.PrintSpecial(resp, opts) + if err != nil { + return opts, 10, fmt.Errorf("format print: %w", err) + } + } else { + str, err = query.ToString(resp, opts) + if err != nil { + return opts, 15, fmt.Errorf("standard print: %w", err) + } + } + + fmt.Println(str) + + if keepTracing { + var records []dns.RR + + if opts.Request.Name == "." { + records = resp.DNS.Answer + } else { + records = resp.DNS.Ns + } + + want := func(rr dns.RR) bool { + temp := strings.Split(rr.String(), "\t") + + return temp[len(temp)-2] == "NS" + } + + i := 0 + + for _, x := range records { + if want(x) { + records[i] = x + i++ + } + } + + records = records[:i] + randomRR := records[r.Intn(len(records))] + + v := strings.Split(randomRR.String(), "\t") + opts.Request.Server = strings.TrimSuffix(v[len(v)-1], ".") + + opts.TLS = false + opts.HTTPS = false + opts.QUIC = false + + opts.RD = false + opts.Request.Port = 53 } } - fmt.Println(str) - return opts, 0, nil } diff --git a/main_test.go b/main_test.go index 949230c..66e5840 100644 --- a/main_test.go +++ b/main_test.go @@ -27,6 +27,17 @@ func TestRun(t *testing.T) { } } +func TestTrace(t *testing.T) { + domains := []string{"git.froth.zone", "google.com", "amazon.com", "freecumextremist.com", "dns.froth.zone", "sleepy.cafe", "pkg.go.dev"} + + for i := range domains { + args := []string{"awl", "+trace", domains[i], "@1.1.1.1"} + _, code, err := run(args) + assert.NilError(t, err) + assert.Equal(t, code, 0) + } +} + func TestHelp(t *testing.T) { // t.Parallel() args := []string{"awl", "-h"} diff --git a/pkg/util/options.go b/pkg/util/options.go index b462a18..8848944 100644 --- a/pkg/util/options.go +++ b/pkg/util/options.go @@ -67,6 +67,9 @@ type Options struct { IPv4 bool `json:"forceIPv4" example:"false"` // Force IPv6 only IPv6 bool `json:"forceIPv6" example:"false"` + + // Trace from the root + Trace bool `json:"trace" example:"false"` } // HTTPSOptions are options exclusively for DNS-over-HTTPS queries.