@@ -157,41 +157,13 @@ func newConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *Conc
157157// containing a single HydratedFile with an error is returned.
158158// Version can be empty, but must be the same for all objects.
159159func (s * S3Concurrent ) GetAllConcurrently (bucket , version string , objects []types.Object ) chan HydratedFile {
160-
161- if s .manager == nil {
162- output := make (chan HydratedFile , 1 )
163- output <- HydratedFile {Error : errors .New ("error getting files from S3, Concurrency Manager not initialised" )}
164- close (output )
165- return output
166- }
167-
168- if s .manager .memoryTotalSize < s .manager .calculateRequiredMemoryFor (objects ) {
169- output := make (chan HydratedFile , 1 )
170- output <- HydratedFile {Error : fmt .Errorf ("error: bytes requested greater than max allowed by server (%v)" , s .manager .memoryTotalSize )}
171- close (output )
172- return output
173- }
174- // Secure memory for all objects upfront.
175- s .manager .secureMemory (objects ) // 0.
176-
177- processFunc := func (input types.Object ) HydratedFile {
178- buf := bytes .NewBuffer (make ([]byte , 0 , int (* input .Size )))
179- key := aws .ToString (input .Key )
180- err := s .Get (bucket , key , version , buf )
181-
182- return HydratedFile {
183- Key : key ,
184- Data : buf .Bytes (),
185- Error : err ,
186- }
187- }
188- return s .manager .Process (processFunc , objects )
160+ return s .GetAllConcurrentlyWithContext (context .Background (), bucket , version , objects )
189161}
190162
191- // GetAllConcurrently gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
163+ // GetAllConcurrentlyWithContext gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
192164// to the returned output channel. The closure of this channel is handled, however it's the caller's
193165// responsibility to purge the channel, and handle any errors present in the HydratedFiles.
194- // If the ConcurrencyManager is not initialised before calling GetAllConcurrently , an output channel
166+ // If the ConcurrencyManager is not initialised before calling GetAllConcurrentlyWithContext , an output channel
195167// containing a single HydratedFile with an error is returned.
196168// Version can be empty, but must be the same for all objects.
197169func (s * S3Concurrent ) GetAllConcurrentlyWithContext (
@@ -229,19 +201,35 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
229201 close (output )
230202 return output
231203 }
204+
232205 // Secure memory for all objects upfront.
233206 s .manager .secureMemory (objects ) // 0.
234207
235- // IMPORTANT: ensure memory is released if context cancels before processing finishes
208+ // Track which objects have been dispatched to workers
209+ // so we know which memory to release on cancellation
210+ var dispatchedMutex sync.Mutex
211+ dispatchedObjects := make (map [string ]bool )
212+
213+ // ensure memory is released if context cancels before processing finishes
236214 go func () {
237215 <- ctx .Done ()
238- // Best-effort cleanup: release all secured memory
216+ // Only release memory for objects that were never dispatched
217+ dispatchedMutex .Lock ()
218+ defer dispatchedMutex .Unlock ()
239219 for _ , o := range objects {
240- s .manager .releaseMemory (aws .ToInt64 (o .Size ))
220+ key := aws .ToString (o .Key )
221+ if ! dispatchedObjects [key ] {
222+ s .manager .releaseMemory (aws .ToInt64 (o .Size ))
223+ }
241224 }
242225 }()
243226
244227 processFunc := func (input types.Object ) HydratedFile {
228+ // Mark as dispatched
229+ dispatchedMutex .Lock ()
230+ dispatchedObjects [aws .ToString (input .Key )] = true
231+ dispatchedMutex .Unlock ()
232+
245233 // Respect cancellation before starting work
246234 select {
247235 case <- ctx .Done ():
@@ -252,7 +240,6 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
252240 buf := bytes .NewBuffer (make ([]byte , 0 , int (* input .Size )))
253241 key := aws .ToString (input .Key )
254242
255- // Prefer context-aware S3 call if available
256243 _ , err := s .GetWithContext (ctx , bucket , key , version , buf )
257244
258245 // If context was cancelled during S3 read, surface that
@@ -267,8 +254,8 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
267254 }
268255 }
269256
270- // Process already accepts a context internally, so pass it through
271- return s .manager .ProcessWithContext (ctx , processFunc , objects )
257+ // Process with a context
258+ return s .manager .Process (ctx , processFunc , objects )
272259}
273260
274261// getWorker retrieves a number of workers from the manager's worker pool.
@@ -327,25 +314,10 @@ func (cm *ConcurrencyManager) releaseMemory(size int64) {
327314 }
328315}
329316
330- // Functions for providing a fan-out/fan-in operation. Workers are taken from the
331- // worker pool and added to a WorkerGroup. All workers are returned to the pool once
332- // the jobs have finished.
333- func (cm * ConcurrencyManager ) Process (asyncProcessor FileProcessor , objects []types.Object ) chan HydratedFile {
334- workerGroup := cm .newWorkerGroup (context .Background (), asyncProcessor , cm .maxWorkersPerRequest ) // 1.
335-
336- go func () {
337- for _ , obj := range objects {
338- workerGroup .addWork (obj )
339- }
340- workerGroup .stopWork () // 9.
341- }()
342- return workerGroup .returnOutput () // 2.
343- }
344-
345317// Functions for providing a fan-out/fan-in operation with provided context. Workers are taken from the
346318// worker pool and added to a WorkerGroup. All workers are returned to the pool once
347319// the jobs have finished.
348- func (cm * ConcurrencyManager ) ProcessWithContext (
320+ func (cm * ConcurrencyManager ) Process (
349321 ctx context.Context ,
350322 asyncProcessor FileProcessor ,
351323 objects []types.Object ,
@@ -354,48 +326,68 @@ func (cm *ConcurrencyManager) ProcessWithContext(
354326 workerGroup := cm .newWorkerGroup (ctx , asyncProcessor , cm .maxWorkersPerRequest )
355327
356328 go func () {
329+ defer func () {
330+ close (workerGroup .reception )
331+ workerGroup .stopWork () // 9.
332+ }()
333+
357334 for _ , obj := range objects {
358335 select {
359336 case <- ctx .Done ():
360- workerGroup .stopWork ()
361337 return
362338 default :
363- workerGroup .addWork (obj )
339+ if ! workerGroup .addWork (ctx , obj ) {
340+ return
341+ }
364342 }
365343 }
366- workerGroup .stopWork ()
367344 }()
368345
369- return workerGroup .returnOutput ()
346+ return workerGroup .returnOutput () // 2.
370347}
371348
372349// start begins a worker's process of making itself available for work, doing the work,
373350// and repeat, until all work is done.
374- func (w * worker ) start (ctx context.Context , processor FileProcessor , roster chan * worker , wg * sync.WaitGroup ) {
351+ func (w * worker ) start ( // 4.
352+ ctx context.Context ,
353+ processor FileProcessor ,
354+ roster chan * worker ,
355+ wg * sync.WaitGroup ,
356+ ) {
375357 go func () {
376358 defer func () {
377359 wg .Done ()
378360
379- // Make sure workers contents have been consumed
380- // before returning to pool.
381- if len (w .input ) > 0 {
382- input := <- w .input
361+ // Process any remaining input before returning to pool
362+ select {
363+ case input := <- w .input :
383364 w .output <- processor (input )
384- w .manager .releaseMemory (int64 (* input .Size ))
365+ w .manager .releaseMemory (aws .ToInt64 (input .Size ))
366+ default :
367+ // No pending input
385368 }
369+
370+ // Wait for output to be consumed
386371 for len (w .output ) > 0 {
387372 time .Sleep (1 * time .Millisecond )
388373 }
389374
390375 w .manager .returnWorker (w ) // 10.
391376 }()
392377 for {
378+ select {
379+ case <- ctx .Done ():
380+ return
381+ default :
382+ // Non-blocking check allows us to add to roster
383+ }
384+
393385 roster <- w // 3., 7.
394386
395387 select {
396388 case input := <- w .input : // 5.
397389 w .output <- processor (input ) // 6.
398- w .manager .releaseMemory (int64 ( * input .Size ))
390+ w .manager .releaseMemory (aws . ToInt64 ( input .Size ))
399391 case <- ctx .Done (): // 9.
400392 return
401393 }
@@ -451,20 +443,26 @@ func (wg *workerGroup) startOutput() {
451443func (wg * workerGroup ) cleanUp (ctx context.Context ) {
452444 <- ctx .Done ()
453445 wg .group .Wait () // 9.
454- close (wg .reception )
446+ // close(wg.reception)
455447 close (wg .roster )
456448}
457449
458450// addWork gets the first available worker from the workerGroup's
459451// roster, and gives it an S3 Object to download. The worker's output
460452// channel is registered to the workerGroup's reception so that
461453// order is retained.
462- func (wg * workerGroup ) addWork (newWork types.Object ) { // 4.
454+ func (wg * workerGroup ) addWork (ctx context. Context , newWork types.Object ) bool {
463455 for w := range wg .roster {
464- w .input <- newWork
465- wg .reception <- w .output
466- break
456+ select {
457+ case <- ctx .Done ():
458+ return false
459+ default :
460+ w .input <- newWork
461+ wg .reception <- w .output
462+ return true
463+ }
467464 }
465+ return false
468466}
469467
470468// returnOutput returns the workerGroup's output channel.
0 commit comments