Skip to content

Commit afd5336

Browse files
authored
Confine symlink state handling to scanSymlink in Filesystem source (#4807)
* Confine symlink state handling to scanSymlink in Filesystem source
* Fix s.canFollowSymlinks snafu
* Update symlink tests to use starting depth 0
* Missed one
* Remove symlink checking from scanFile; this is now always handled in scanSymlink
* Confine errgroup.Groups to scanDir in the Filesystem source (#4808)
* Move path parameter after rootPath parameter in the Filesystem source
* Move the depth parameter too
* Only create an errgroup.Group inside scanDir (where it's used) in the Filesystem source
1 parent d17df48 commit afd5336

2 files changed

Lines changed: 52 additions & 130 deletions

File tree

pkg/sources/filesystem/filesystem.go

Lines changed: 38 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -125,26 +125,15 @@ func (s *Source) Chunks(ctx trContext.Context, chunksChan chan *sources.Chunk, _
125125
}
126126

127127
if fileInfo.Mode()&os.ModeSymlink != 0 {
128-
if !s.canFollowSymlinks() {
129-
// If the file or directory is a symlink but the followSymlinks is disable ignore the path
130-
logger.Info("skipping, following symlinks is not allowed", "path", cleanPath)
131-
continue
132-
}
133128
// if the root path is a symlink we scan the symlink
134129
ctx.Logger().V(5).Info("Root path is a symlink", "path", cleanPath)
135-
workerPool := new(errgroup.Group)
136-
workerPool.SetLimit(s.concurrency)
137-
initialDepth := 1
138-
err = s.scanSymlink(ctx, chunksChan, workerPool, rootPath, initialDepth, cleanPath)
139-
_ = workerPool.Wait()
130+
initialDepth := 0
131+
err = s.scanSymlink(ctx, chunksChan, rootPath, initialDepth, cleanPath)
140132
s.ClearEncodedResumeInfoFor(rootPath)
141133
} else if fileInfo.IsDir() {
142134
ctx.Logger().V(5).Info("Root path is a dir", "path", cleanPath)
143-
workerPool := new(errgroup.Group)
144-
workerPool.SetLimit(s.concurrency)
145-
initialDepth := 1
146-
err = s.scanDir(ctx, chunksChan, workerPool, rootPath, initialDepth, cleanPath)
147-
_ = workerPool.Wait()
135+
initialDepth := 0
136+
err = s.scanDir(ctx, chunksChan, rootPath, initialDepth, cleanPath)
148137
s.ClearEncodedResumeInfoFor(rootPath)
149138
} else {
150139
if !fileInfo.Mode().IsRegular() {
@@ -156,9 +145,7 @@ func (s *Source) Chunks(ctx trContext.Context, chunksChan chan *sources.Chunk, _
156145
}
157146

158147
if err != nil && !errors.Is(err, io.EOF) {
159-
if !errors.Is(err, skipSymlinkErr) {
160-
logger.Error(err, "error scanning filesystem")
161-
}
148+
logger.Error(err, "error scanning filesystem")
162149
}
163150
}
164151

@@ -168,14 +155,22 @@ func (s *Source) Chunks(ctx trContext.Context, chunksChan chan *sources.Chunk, _
168155
func (s *Source) scanSymlink(
169156
ctx trContext.Context,
170157
chunksChan chan *sources.Chunk,
171-
workerPool *errgroup.Group,
172158
rootPath string,
173159
depth int,
174160
path string,
175161
) error {
162+
if !s.canFollowSymlinks() {
163+
// If the file or directory is a symlink but the followSymlinks is disable ignore the path
164+
ctx.Logger().V(2).Info("skipping, following symlinks is not allowed", "path", path)
165+
return nil
166+
}
167+
168+
depth++
169+
176170
if depth > s.maxSymlinkDepth {
177171
return errors.New("max symlink depth reached")
178172
}
173+
179174
cleanPath := filepath.Clean(path)
180175

181176
resolvedPath, err := os.Readlink(cleanPath)
@@ -196,7 +191,7 @@ func (s *Source) scanSymlink(
196191
"resolvedPath", resolvedPath,
197192
"depth", depth,
198193
)
199-
return s.scanSymlink(ctx, chunksChan, workerPool, rootPath, depth+1, resolvedPath)
194+
return s.scanSymlink(ctx, chunksChan, rootPath, depth, resolvedPath)
200195
}
201196

202197
if fileInfo.IsDir() {
@@ -207,7 +202,7 @@ func (s *Source) scanSymlink(
207202
"depth", depth,
208203
)
209204

210-
return s.scanDir(ctx, chunksChan, workerPool, rootPath, depth+1, resolvedPath)
205+
return s.scanDir(ctx, chunksChan, rootPath, depth, resolvedPath)
211206
}
212207
ctx.Logger().V(5).Info(
213208
"found symlink to file",
@@ -223,25 +218,20 @@ func (s *Source) scanSymlink(
223218
// Resume checks are handled by the calling scanDir function.
224219
resumptionKey := rootPath
225220

226-
workerPool.Go(func() error {
227-
if !fileInfo.Mode().Type().IsRegular() {
228-
ctx.Logger().V(5).Info("skipping non-regular file", "path", resolvedPath)
229-
return nil
230-
}
231-
if err := s.scanFile(ctx, chunksChan, resolvedPath); err != nil {
232-
ctx.Logger().Error(err, "error scanning file", "path", resolvedPath)
233-
}
234-
s.SetEncodedResumeInfoFor(resumptionKey, cleanPath)
221+
if !fileInfo.Mode().Type().IsRegular() {
222+
ctx.Logger().V(5).Info("skipping non-regular file", "path", resolvedPath)
235223
return nil
236-
})
237-
224+
}
225+
if err := s.scanFile(ctx, chunksChan, resolvedPath); err != nil {
226+
ctx.Logger().Error(err, "error scanning file", "path", resolvedPath)
227+
}
228+
s.SetEncodedResumeInfoFor(resumptionKey, cleanPath)
238229
return nil
239230
}
240231

241232
func (s *Source) scanDir(
242233
ctx trContext.Context,
243234
chunksChan chan *sources.Chunk,
244-
workerPool *errgroup.Group,
245235
rootPath string,
246236
depth int,
247237
path string,
@@ -285,6 +275,9 @@ func (s *Source) scanDir(
285275
return fmt.Errorf("readdir error: %w", err)
286276
}
287277

278+
workerPool := new(errgroup.Group)
279+
workerPool.SetLimit(s.concurrency)
280+
288281
for _, entry := range entries {
289282
entryPath := filepath.Join(path, entry.Name())
290283
if s.filter != nil && !s.filter.Pass(entryPath) {
@@ -308,7 +301,7 @@ func (s *Source) scanDir(
308301
// traverse into it to find where to resume.
309302
if entry.IsDir() && strings.HasPrefix(resumeAfter, entryPath+string(filepath.Separator)) {
310303
// Recurse into this directory to find the resume point.
311-
if err := s.scanDir(ctx, chunksChan, workerPool, rootPath, depth, entryPath); err != nil {
304+
if err := s.scanDir(ctx, chunksChan, rootPath, depth, entryPath); err != nil {
312305
ctx.Logger().Error(err, "error scanning directory", "path", entryPath)
313306
}
314307
// After recursing, clear local resumeAfter. The child scanDir will have
@@ -323,17 +316,12 @@ func (s *Source) scanDir(
323316

324317
if entry.Type()&os.ModeSymlink != 0 {
325318
ctx.Logger().V(5).Info("Entry found is a symlink", "path", entryPath)
326-
if !s.canFollowSymlinks() {
327-
// If the file or directory is a symlink but the followSymlinks is disable ignore the path
328-
ctx.Logger().Info("skipping, following symlinks is not allowed", "path", entryPath)
329-
continue
330-
}
331-
if err := s.scanSymlink(ctx, chunksChan, workerPool, rootPath, depth, entryPath); err != nil {
319+
if err := s.scanSymlink(ctx, chunksChan, rootPath, depth, entryPath); err != nil {
332320
ctx.Logger().Error(err, "error scanning symlink", "path", entryPath)
333321
}
334322
} else if entry.IsDir() {
335323
ctx.Logger().V(5).Info("Entry found is a directory", "path", entryPath)
336-
if err := s.scanDir(ctx, chunksChan, workerPool, rootPath, depth, entryPath); err != nil {
324+
if err := s.scanDir(ctx, chunksChan, rootPath, depth, entryPath); err != nil {
337325
ctx.Logger().Error(err, "error scanning directory", "path", entryPath)
338326
}
339327
} else {
@@ -351,20 +339,18 @@ func (s *Source) scanDir(
351339
}
352340
}
353341

342+
_ = workerPool.Wait() // [TODO] Handle errors
343+
354344
return nil
355345
}
356346

357-
var skipSymlinkErr = errors.New("skipping symlink")
358-
359347
func (s *Source) scanFile(ctx trContext.Context, chunksChan chan *sources.Chunk, path string) error {
360348
fileCtx := trContext.WithValues(ctx, "path", path)
361-
fileStat, err := os.Lstat(path)
349+
350+
_, err := os.Lstat(path)
362351
if err != nil {
363352
return fmt.Errorf("unable to stat file: %w", err)
364353
}
365-
if fileStat.Mode()&os.ModeSymlink != 0 {
366-
return skipSymlinkErr
367-
}
368354

369355
// Check if file is binary and should be skipped
370356
if (s.skipBinaries || feature.ForceSkipBinaries.Load()) && common.IsBinary(path) {
@@ -435,28 +421,17 @@ func (s *Source) ChunkUnit(ctx trContext.Context, unit sources.SourceUnit, repor
435421
go func() {
436422
defer close(ch)
437423
if fileInfo.Mode()&os.ModeSymlink != 0 {
438-
if !s.canFollowSymlinks() {
439-
// If the file or directory is a symlink but the followSymlinks is disable ignore the path
440-
logger.Info("skipping, following symlinks is not allowed", "path", cleanPath)
441-
return
442-
}
443424
// if the root path is a symlink we scan the symlink
444425
ctx.Logger().V(5).Info("Root path is a symlink", "path", cleanPath)
445-
workerPool := new(errgroup.Group)
446-
workerPool.SetLimit(s.concurrency)
447-
initialDepth := 1
448-
scanErr = s.scanSymlink(ctx, ch, workerPool, rootPath, initialDepth, cleanPath)
449-
_ = workerPool.Wait()
426+
initialDepth := 0
427+
scanErr = s.scanSymlink(ctx, ch, rootPath, initialDepth, cleanPath)
450428
s.ClearEncodedResumeInfoFor(rootPath)
451429

452430
} else if fileInfo.IsDir() {
453431
ctx.Logger().V(5).Info("Root path is a dir", "path", cleanPath)
454-
workerPool := new(errgroup.Group)
455-
workerPool.SetLimit(s.concurrency)
456-
initialDepth := 1
432+
initialDepth := 0
457433
// TODO: Finer grain error tracking of individual chunks.
458-
scanErr = s.scanDir(ctx, ch, workerPool, rootPath, initialDepth, cleanPath)
459-
_ = workerPool.Wait()
434+
scanErr = s.scanDir(ctx, ch, rootPath, initialDepth, cleanPath)
460435
s.ClearEncodedResumeInfoFor(rootPath)
461436
} else {
462437
ctx.Logger().V(5).Info("Root path is a file", "path", cleanPath)
@@ -480,9 +455,7 @@ func (s *Source) ChunkUnit(ctx trContext.Context, unit sources.SourceUnit, repor
480455
}
481456

482457
if scanErr != nil && !errors.Is(scanErr, io.EOF) {
483-
if !errors.Is(scanErr, skipSymlinkErr) {
484-
logger.Error(scanErr, "error scanning filesystem")
485-
}
458+
logger.Error(scanErr, "error scanning filesystem")
486459
return reporter.ChunkErr(ctx, scanErr)
487460
}
488461
return nil

pkg/sources/filesystem/filesystem_symlink_test.go

Lines changed: 14 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88

99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"
11-
"golang.org/x/sync/errgroup"
1211
"google.golang.org/protobuf/types/known/anypb"
1312

1413
trContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
@@ -506,10 +505,8 @@ func TestScanSymlink_NoError(t *testing.T) {
506505
}
507506
chunks := make(chan *sources.Chunk, 10)
508507
go func() {
509-
workerPool := new(errgroup.Group)
510-
workerPool.SetLimit(src.concurrency)
511-
err := src.scanSymlink(ctx, chunks, workerPool, filepath.Join(baseDir, "A"), 1, filepath.Join(baseDir, "A"))
512-
_ = workerPool.Wait()
508+
path := filepath.Join(baseDir, "A")
509+
err := src.scanSymlink(ctx, chunks, path, 0, path)
513510
require.NoError(t, err)
514511
close(chunks)
515512
}()
@@ -555,19 +552,11 @@ func TestScanSymlink_MaxDepthExceeded(t *testing.T) {
555552
maxSymlinkDepth: 2,
556553
}
557554
chunks := make(chan *sources.Chunk, 10)
558-
workerPool := new(errgroup.Group)
559-
workerPool.SetLimit(src.concurrency)
560-
561-
err = src.scanSymlink(
562-
ctx,
563-
chunks,
564-
workerPool,
565-
filepath.Join(baseDir, "A"),
566-
1,
567-
filepath.Join(baseDir, "A"),
568-
)
569-
_ = workerPool.Wait()
555+
556+
path := filepath.Join(baseDir, "A")
557+
err = src.scanSymlink(ctx, chunks, path, 0, path)
570558
close(chunks)
559+
571560
require.Error(t, err)
572561
require.EqualError(t, err, "max symlink depth reached")
573562
}
@@ -597,18 +586,8 @@ func TestScanSymlink_FileTarget(t *testing.T) {
597586
}
598587

599588
chunks := make(chan *sources.Chunk, 10)
600-
workerPool := new(errgroup.Group)
601-
workerPool.SetLimit(src.concurrency)
602-
603-
err = src.scanSymlink(
604-
ctx,
605-
chunks,
606-
workerPool,
607-
symlinkPath,
608-
1,
609-
symlinkPath,
610-
)
611-
_ = workerPool.Wait()
589+
590+
err = src.scanSymlink(ctx, chunks, symlinkPath, 0, symlinkPath)
612591
require.NoError(t, err)
613592
close(chunks)
614593
var chunkCount int
@@ -639,18 +618,8 @@ func TestScanSymlink_SelfLoop(t *testing.T) {
639618
}
640619

641620
chunks := make(chan *sources.Chunk, 10)
642-
workerPool := new(errgroup.Group)
643-
workerPool.SetLimit(src.concurrency)
644-
645-
err = src.scanSymlink(
646-
ctx,
647-
chunks,
648-
workerPool,
649-
symlinkPath,
650-
1,
651-
symlinkPath,
652-
)
653-
_ = workerPool.Wait()
621+
622+
err = src.scanSymlink(ctx, chunks, symlinkPath, 0, symlinkPath)
654623
close(chunks)
655624
require.Error(t, err)
656625
require.EqualError(t, err, "max symlink depth reached")
@@ -675,18 +644,8 @@ func TestScanSymlink_BrokenSymlink(t *testing.T) {
675644
}
676645

677646
chunks := make(chan *sources.Chunk, 10)
678-
workerPool := new(errgroup.Group)
679-
workerPool.SetLimit(src.concurrency)
680-
681-
err = src.scanSymlink(
682-
ctx,
683-
chunks,
684-
workerPool,
685-
symlinkPath,
686-
0,
687-
symlinkPath,
688-
)
689-
_ = workerPool.Wait()
647+
648+
err = src.scanSymlink(ctx, chunks, symlinkPath, 0, symlinkPath)
690649
close(chunks)
691650
require.Error(t, err)
692651
require.Contains(t, err.Error(), "lstat error")
@@ -714,18 +673,8 @@ func TestScanSymlink_TwoFileLoop(t *testing.T) {
714673
}
715674

716675
chunks := make(chan *sources.Chunk, 10)
717-
workerPool := new(errgroup.Group)
718-
workerPool.SetLimit(src.concurrency)
719-
720-
err = src.scanSymlink(
721-
ctx,
722-
chunks,
723-
workerPool,
724-
fileA,
725-
0,
726-
fileA,
727-
)
728-
_ = workerPool.Wait()
676+
677+
err = src.scanSymlink(ctx, chunks, fileA, 0, fileA)
729678
close(chunks)
730679
require.Error(t, err)
731680
require.EqualError(t, err, "max symlink depth reached")

0 commit comments

Comments
 (0)