Skip to content

Commit 2630db5

Browse files
committed
feat: update GetAllConcurrentlyWithContext to fix test failure
1 parent 6944614 commit 2630db5

File tree

3 files changed

+169
-73
lines changed

3 files changed

+169
-73
lines changed

aws/s3/s3_concurrent.go

Lines changed: 34 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -157,41 +157,13 @@ func newConcurrencyManager(maxWorkers, maxWorkersPerRequest, maxBytes int) *Conc
157157
// containing a single HydratedFile with an error is returned.
158158
// Version can be empty, but must be the same for all objects.
159159
func (s *S3Concurrent) GetAllConcurrently(bucket, version string, objects []types.Object) chan HydratedFile {
160-
161-
if s.manager == nil {
162-
output := make(chan HydratedFile, 1)
163-
output <- HydratedFile{Error: errors.New("error getting files from S3, Concurrency Manager not initialised")}
164-
close(output)
165-
return output
166-
}
167-
168-
if s.manager.memoryTotalSize < s.manager.calculateRequiredMemoryFor(objects) {
169-
output := make(chan HydratedFile, 1)
170-
output <- HydratedFile{Error: fmt.Errorf("error: bytes requested greater than max allowed by server (%v)", s.manager.memoryTotalSize)}
171-
close(output)
172-
return output
173-
}
174-
// Secure memory for all objects upfront.
175-
s.manager.secureMemory(objects) // 0.
176-
177-
processFunc := func(input types.Object) HydratedFile {
178-
buf := bytes.NewBuffer(make([]byte, 0, int(*input.Size)))
179-
key := aws.ToString(input.Key)
180-
err := s.Get(bucket, key, version, buf)
181-
182-
return HydratedFile{
183-
Key: key,
184-
Data: buf.Bytes(),
185-
Error: err,
186-
}
187-
}
188-
return s.manager.Process(processFunc, objects)
160+
return s.GetAllConcurrentlyWithContext(context.Background(), bucket, version, objects)
189161
}
190162

191-
// GetAllConcurrently gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
163+
// GetAllConcurrentlyWithContext gets the objects with provided context, from specified bucket and writes the resulting HydratedFiles
192164
// to the returned output channel. The closure of this channel is handled, however it's the caller's
193165
// responsibility to purge the channel, and handle any errors present in the HydratedFiles.
194-
// If the ConcurrencyManager is not initialised before calling GetAllConcurrently, an output channel
166+
// If the ConcurrencyManager is not initialised before calling GetAllConcurrentlyWithContext, an output channel
195167
// containing a single HydratedFile with an error is returned.
196168
// Version can be empty, but must be the same for all objects.
197169
func (s *S3Concurrent) GetAllConcurrentlyWithContext(
@@ -229,13 +201,14 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
229201
close(output)
230202
return output
231203
}
204+
232205
// Secure memory for all objects upfront.
233206
s.manager.secureMemory(objects) // 0.
234207

235-
// IMPORTANT: ensure memory is released if context cancels before processing finishes
208+
// ensure memory is released if context cancels before processing finishes
236209
go func() {
237210
<-ctx.Done()
238-
// Best-effort cleanup: release all secured memory
211+
// release all secured memory
239212
for _, o := range objects {
240213
s.manager.releaseMemory(aws.ToInt64(o.Size))
241214
}
@@ -252,7 +225,6 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
252225
buf := bytes.NewBuffer(make([]byte, 0, int(*input.Size)))
253226
key := aws.ToString(input.Key)
254227

255-
// Prefer context-aware S3 call if available
256228
_, err := s.GetWithContext(ctx, bucket, key, version, buf)
257229

258230
// If context was cancelled during S3 read, surface that
@@ -267,8 +239,8 @@ func (s *S3Concurrent) GetAllConcurrentlyWithContext(
267239
}
268240
}
269241

270-
// Process already accepts a context internally, so pass it through
271-
return s.manager.ProcessWithContext(ctx, processFunc, objects)
242+
// Process with a context
243+
return s.manager.Process(ctx, processFunc, objects)
272244
}
273245

274246
// getWorker retrieves a number of workers from the manager's worker pool.
@@ -327,25 +299,10 @@ func (cm *ConcurrencyManager) releaseMemory(size int64) {
327299
}
328300
}
329301

330-
// Functions for providing a fan-out/fan-in operation. Workers are taken from the
331-
// worker pool and added to a WorkerGroup. All workers are returned to the pool once
332-
// the jobs have finished.
333-
func (cm *ConcurrencyManager) Process(asyncProcessor FileProcessor, objects []types.Object) chan HydratedFile {
334-
workerGroup := cm.newWorkerGroup(context.Background(), asyncProcessor, cm.maxWorkersPerRequest) // 1.
335-
336-
go func() {
337-
for _, obj := range objects {
338-
workerGroup.addWork(obj)
339-
}
340-
workerGroup.stopWork() // 9.
341-
}()
342-
return workerGroup.returnOutput() // 2.
343-
}
344-
345302
// Functions for providing a fan-out/fan-in operation with provided context. Workers are taken from the
346303
// worker pool and added to a WorkerGroup. All workers are returned to the pool once
347304
// the jobs have finished.
348-
func (cm *ConcurrencyManager) ProcessWithContext(
305+
func (cm *ConcurrencyManager) Process(
349306
ctx context.Context,
350307
asyncProcessor FileProcessor,
351308
objects []types.Object,
@@ -354,24 +311,34 @@ func (cm *ConcurrencyManager) ProcessWithContext(
354311
workerGroup := cm.newWorkerGroup(ctx, asyncProcessor, cm.maxWorkersPerRequest)
355312

356313
go func() {
314+
defer func() {
315+
close(workerGroup.reception)
316+
workerGroup.stopWork()
317+
}()
318+
357319
for _, obj := range objects {
358320
select {
359321
case <-ctx.Done():
360-
workerGroup.stopWork()
361322
return
362323
default:
363-
workerGroup.addWork(obj)
324+
if !workerGroup.addWork(ctx, obj) {
325+
return
326+
}
364327
}
365328
}
366-
workerGroup.stopWork()
367329
}()
368330

369331
return workerGroup.returnOutput()
370332
}
371333

372334
// start begins a worker's process of making itself available for work, doing the work,
373335
// and repeat, until all work is done.
374-
func (w *worker) start(ctx context.Context, processor FileProcessor, roster chan *worker, wg *sync.WaitGroup) {
336+
func (w *worker) start(
337+
ctx context.Context,
338+
processor FileProcessor,
339+
roster chan *worker,
340+
wg *sync.WaitGroup,
341+
) {
375342
go func() {
376343
defer func() {
377344
wg.Done()
@@ -451,20 +418,26 @@ func (wg *workerGroup) startOutput() {
451418
func (wg *workerGroup) cleanUp(ctx context.Context) {
452419
<-ctx.Done()
453420
wg.group.Wait() // 9.
454-
close(wg.reception)
421+
//close(wg.reception)
455422
close(wg.roster)
456423
}
457424

458425
// addWork gets the first available worker from the workerGroup's
459426
// roster, and gives it an S3 Object to download. The worker's output
460427
// channel is registered to the workerGroup's reception so that
461428
// order is retained.
462-
func (wg *workerGroup) addWork(newWork types.Object) { // 4.
429+
func (wg *workerGroup) addWork(ctx context.Context, newWork types.Object) bool {
463430
for w := range wg.roster {
464-
w.input <- newWork
465-
wg.reception <- w.output
466-
break
431+
select {
432+
case <-ctx.Done():
433+
return false
434+
default:
435+
w.input <- newWork
436+
wg.reception <- w.output
437+
return true
438+
}
467439
}
440+
return false
468441
}
469442

470443
// returnOutput returns the workerGroup's output channel.

aws/s3/s3_concurrent_test.go

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func TestS3GetAllConcurrently(t *testing.T) {
9090
}
9191

9292
// ASSERT input and output order is the same.
93-
require.Equal(t, len(outputKeys), total)
93+
require.Equal(t, total, len(outputKeys))
9494
for i := 0; i < total; i++ {
9595
assert.Equal(t, aws.ToString(objects[i].Key), outputKeys[i])
9696
}
@@ -123,6 +123,117 @@ func TestS3GetAllConcurrently(t *testing.T) {
123123
}
124124
}
125125

126+
// go test --run TestS3GetAllConcurrentlyWithContext -v
127+
func TestS3GetAllConcurrentlyWithContext(t *testing.T) {
128+
// ARRANGE
129+
setup()
130+
defer teardown()
131+
132+
// ASSERT parameter errors.
133+
_, err := NewConcurrent(0, 100, 1000)
134+
assert.NotNil(t, err)
135+
_, err = NewConcurrent(100, 0, 1000)
136+
assert.NotNil(t, err)
137+
_, err = NewConcurrent(100, 100, 0)
138+
assert.NotNil(t, err)
139+
_, err = NewConcurrent(100, 10, 99)
140+
assert.NotNil(t, err)
141+
_, err = NewConcurrent(100, 101, 1000)
142+
assert.NotNil(t, err)
143+
144+
client, err := NewConcurrent(100, 10, 1000)
145+
require.Nil(t, err, fmt.Sprintf("error creating s3 client concurrency manager: %v", err))
146+
147+
// ASSERT computed fields.
148+
assert.Equal(t, 100, len(client.manager.workerPool.channel))
149+
assert.Equal(t, 100, len(client.manager.memoryPool.channel))
150+
assert.Equal(t, int64(10), client.manager.memoryChunkSize)
151+
assert.Equal(t, int64(10*100), client.manager.memoryTotalSize)
152+
assert.Equal(t, 10, client.manager.maxWorkersPerRequest)
153+
154+
// ASSERT memory chunk size is correct in memory pool.
155+
chunk := <-client.manager.memoryPool.channel
156+
assert.Equal(t, int64(10), chunk)
157+
client.manager.memoryPool.channel <- chunk
158+
159+
// ASSERT worker get/release methods work expectedly.
160+
w := client.manager.getWorkers(1)
161+
assert.Equal(t, 99, len(client.manager.workerPool.channel))
162+
client.manager.returnWorker(w[0])
163+
assert.Equal(t, 100, len(client.manager.workerPool.channel))
164+
165+
// ASSERT memory get/release methods work expectedly.
166+
elevenByteFile := types.Object{Size: aws.Int64(11)} // requires 2 memory chunks.
167+
client.manager.secureMemory([]types.Object{elevenByteFile})
168+
assert.Equal(t, 98, len(client.manager.memoryPool.channel))
169+
client.manager.releaseMemory(20)
170+
assert.Equal(t, 100, len(client.manager.memoryPool.channel))
171+
172+
// ARRANGE bucket with test objects.
173+
total := 20
174+
keys := make([]string, total)
175+
for i := 0; i < total; i++ {
176+
keys[i] = fmt.Sprintf("%s-%v", testObjectKey, i)
177+
}
178+
awsCmdPutKeys(keys)
179+
180+
// ACTION
181+
objects, _ := client.ListAllObjects(testBucket, "")
182+
tooManyBytes := make([]types.Object, 10*len(objects))
183+
for _, o := range objects {
184+
for i := 0; i < 10; i++ {
185+
tooManyBytes = append(tooManyBytes, o)
186+
}
187+
}
188+
output := client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", tooManyBytes)
189+
190+
// ASSERT error returned
191+
for hf := range output {
192+
assert.NotNil(t, hf.Error)
193+
}
194+
195+
// ACTION
196+
objects, _ = client.ListAllObjects(testBucket, "")
197+
output = client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", objects)
198+
outputKeys := make([]string, 0)
199+
for hf := range output {
200+
outputKeys = append(outputKeys, hf.Key)
201+
}
202+
203+
// ASSERT input and output order is the same.
204+
require.Equal(t, total, len(outputKeys))
205+
for i := 0; i < total; i++ {
206+
assert.Equal(t, aws.ToString(objects[i].Key), outputKeys[i])
207+
}
208+
209+
// ASSERT all workers and memory returned to pools.
210+
time.Sleep(2 * time.Second)
211+
assert.Equal(t, 100, len(client.manager.workerPool.channel))
212+
assert.Equal(t, 100, len(client.manager.memoryPool.channel))
213+
214+
// ASSERT that process blocked when all memory secured.
215+
tenByteFile := types.Object{Size: aws.Int64(10)}
216+
oneThousandBytesOfFiles := make([]types.Object, 100)
217+
for i := 0; i < 100; i++ {
218+
oneThousandBytesOfFiles[i] = tenByteFile
219+
}
220+
client.manager.secureMemory(oneThousandBytesOfFiles)
221+
ch := make(chan chan HydratedFile)
222+
go func() {
223+
ch <- client.GetAllConcurrentlyWithContext(context.Background(), testBucket, "", objects)
224+
}()
225+
226+
for {
227+
select {
228+
case <-ch:
229+
t.Error("process was not blocked")
230+
case <-time.After(time.Second):
231+
// Timed out as expected
232+
return
233+
}
234+
}
235+
}
236+
126237
// go test --run TestS3GetAllConcurrentlyWithContext_Cancel -v
127238
func TestS3GetAllConcurrentlyWithContext_Cancel(t *testing.T) {
128239
// ARRANGE
@@ -175,5 +286,14 @@ func TestS3GetAllConcurrentlyWithContext_Cancel(t *testing.T) {
175286
require.GreaterOrEqual(t, len(collected), cancelAfter)
176287
// But not all objects should be processed
177288
require.Less(t, len(collected), len(objects))
289+
// Pool recovery
290+
require.Eventually(t, func() bool {
291+
return len(client.manager.workerPool.channel) == 100
292+
}, 5*time.Second, 10*time.Millisecond, fmt.Sprintf("workers pool not recovered, expected %d actual %d", 100, len(client.manager.workerPool.channel)))
293+
require.Eventually(t, func() bool {
294+
return len(client.manager.memoryPool.channel) == 100
295+
}, 5*time.Second, 10*time.Millisecond, fmt.Sprintf("memory pool not recovered, expected %d actual %d", 100, len(client.manager.memoryPool.channel)))
296+
178297
})
298+
179299
}

aws/s3/s3_integration_test.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,20 @@ func setup() {
5050
// setup environment variable to run AWS CLI/SDK
5151
setAwsEnv()
5252

53-
// create bucket
54-
cmd := exec.Command( //nolint:gosec
55-
"aws", "s3api",
56-
"create-bucket",
57-
"--bucket", testBucket,
58-
"--create-bucket-configuration", fmt.Sprintf(
59-
"{\"LocationConstraint\": \"%v\"}", testRegion),
60-
)
61-
if output, err := cmd.CombinedOutput(); err != nil {
62-
fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
63-
panic(err)
53+
// check if bucket exists before creating
54+
if !awsCmdBucketExists(testBucket) {
55+
// create bucket
56+
cmd := exec.Command( //nolint:gosec
57+
"aws", "s3api",
58+
"create-bucket",
59+
"--bucket", testBucket,
60+
"--create-bucket-configuration", fmt.Sprintf(
61+
"{\"LocationConstraint\": \"%v\"}", testRegion),
62+
)
63+
if output, err := cmd.CombinedOutput(); err != nil {
64+
fmt.Printf("Command failed: %v\nOutput: %s\n", err, string(output))
65+
panic(err)
66+
}
6467
}
6568
}
6669

0 commit comments

Comments (0)