diff --git a/internal/cli/format.go b/internal/cli/format.go index 5adce3e..b8b3195 100644 --- a/internal/cli/format.go +++ b/internal/cli/format.go @@ -1,6 +1,7 @@ package cli import ( + "bufio" "context" "errors" "fmt" @@ -236,20 +237,29 @@ func (f *Format) Run() error { }) eg.Go(func() (err error) { - var walker walk.Walker - - if Cli.Stdin { - walker, err = walk.NewPathReader(os.Stdin) - } else if len(Cli.Paths) > 0 { - walker, err = walk.NewPathList(Cli.Paths) - } else { - walker, err = walk.New(Cli.Walk, Cli.TreeRoot) - } + walker, err := walk.New(Cli.Walk, Cli.TreeRoot) if err != nil { return fmt.Errorf("%w: failed to create walker", err) } + paths := Cli.Paths + + if len(paths) == 0 && Cli.Stdin { + // read in all the paths + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + paths = append(paths, scanner.Text()) + } + } + + if len(paths) > 0 { + walker, err = walk.NewSelector(Cli.TreeRoot, paths, walker) + if err != nil { + return fmt.Errorf("%w: failed to create selector", err) + } + } + defer close(pathsCh) return cache.ChangeSet(ctx, walker, pathsCh) }) diff --git a/internal/walk/selector.go b/internal/walk/selector.go new file mode 100644 index 0000000..1b5952f --- /dev/null +++ b/internal/walk/selector.go @@ -0,0 +1,59 @@ +package walk + +import ( + "context" + "fmt" + "git.numtide.com/numtide/treefmt/internal/format" + "github.com/gobwas/glob" + "io/fs" + "os" + "path" + "path/filepath" +) + +type selector struct { + root string + globs []glob.Glob + delegate Walker +} + +func (s selector) Root() string { + return s.root +} + +func (s selector) Walk(ctx context.Context, fn filepath.WalkFunc) error { + return s.delegate.Walk(ctx, func(path string, info fs.FileInfo, err error) error { + for _, g := range s.globs { + if !g.Match(path) { + continue + } + return fn(path, info, err) + } + return nil + }) +} + +func NewSelector(root string, paths []string, delegate Walker) (Walker, error) { + var fullPaths []string + for _, p := range paths { + info, err := os.Lstat(p) + if err != nil { + return nil, fmt.Errorf("%w: failed to lstat %v", err, p) + } + + if info.Mode()&os.ModeSymlink == os.ModeSymlink { + return nil, fmt.Errorf("%v is a symlink which is not supported", p) + } else if info.IsDir() { + p = path.Join(p, "/*") + } + + fullPaths = append(fullPaths, p) + } + + globs, err := format.CompileGlobs(fullPaths) + if err != nil { + return nil, err + } + + return &selector{root, globs, delegate}, nil +}