diff --git a/outlet/kafka/scaler.go b/outlet/kafka/scaler.go index 3811dc64..86c1e387 100644 --- a/outlet/kafka/scaler.go +++ b/outlet/kafka/scaler.go @@ -5,6 +5,7 @@ package kafka import ( "context" + "sync" "time" ) @@ -65,6 +66,34 @@ func (s *scalerState) nextWorkerCount(request ScaleRequest, currentWorkers, minW return currentWorkers } +// scaleWhileDraining runs a scaling function while draining incoming signals +// from the channel. It spawns two goroutines: one to discard signals and one to +// run the scaling function. +func scaleWhileDraining(ctx context.Context, ch <-chan ScaleRequest, scaleFn func()) { + var wg sync.WaitGroup + done := make(chan struct{}) + wg.Add(2) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case <-done: + return + case <-ch: + // Discard signal + } + } + }() + go func() { + defer wg.Done() + scaleFn() + close(done) + }() + wg.Wait() +} + // runScaler starts the automatic scaling loop func runScaler(ctx context.Context, config scalerConfiguration) chan<- ScaleRequest { ch := make(chan ScaleRequest, config.maxWorkers) @@ -87,9 +116,11 @@ func runScaler(ctx context.Context, config scalerConfiguration) chan<- ScaleRequ current := config.getWorkerCount() target := state.nextWorkerCount(ScaleIncrease, current, config.minWorkers, config.maxWorkers) if target > current { - config.increaseWorkers(current, target) + scaleWhileDraining(ctx, ch, func() { + config.increaseWorkers(current, target) + }) } - last = now + last = time.Now() decreaseCount = 0 continue } @@ -110,9 +141,11 @@ func runScaler(ctx context.Context, config scalerConfiguration) chan<- ScaleRequ current := config.getWorkerCount() target := state.nextWorkerCount(ScaleDecrease, current, config.minWorkers, config.maxWorkers) if target < current { - config.decreaseWorkers(current, target) + scaleWhileDraining(ctx, ch, func() { + config.decreaseWorkers(current, target) + }) } - last = now + last = time.Now() decreaseCount = 0 } } diff --git a/outlet/kafka/scaler_test.go b/outlet/kafka/scaler_test.go index c6dfcd25..8419ab0e 100644 --- a/outlet/kafka/scaler_test.go +++ b/outlet/kafka/scaler_test.go @@ -274,6 +274,103 @@ func TestScalerRateLimiter(t *testing.T) { }) } +func TestScalerDoesNotBlock(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + var mu sync.Mutex + currentWorkers := 1 + scalingInProgress := false + + config := scalerConfiguration{ + minWorkers: 1, + maxWorkers: 16, + increaseRateLimit: time.Second, + decreaseRateLimit: time.Second, + getWorkerCount: func() int { + mu.Lock() + defer mu.Unlock() + return currentWorkers + }, + increaseWorkers: func(from, to int) { + t.Logf("increaseWorkers(from: %d, to: %d) - start", from, to) + mu.Lock() + scalingInProgress = true + mu.Unlock() + + // Simulate a slow scaling operation + time.Sleep(30 * time.Second) + + mu.Lock() + currentWorkers = to + scalingInProgress = false + mu.Unlock() + t.Logf("increaseWorkers(from: %d, to: %d) - done", from, to) + }, + decreaseWorkers: func(from, to int) { + t.Logf("decreaseWorkers(from: %d, to: %d) - start", from, to) + mu.Lock() + scalingInProgress = true + mu.Unlock() + + // Simulate a slow scaling operation + time.Sleep(30 * time.Second) + + mu.Lock() + currentWorkers = to + scalingInProgress = false + mu.Unlock() + t.Logf("decreaseWorkers(from: %d, to: %d) - done", from, to) + }, + } + + ch := runScaler(ctx, config) + + // Send the first scale request that will trigger a slow scaling operation + ch <- ScaleIncrease + time.Sleep(time.Second) + + // Verify scaling is in progress + mu.Lock() + if !scalingInProgress { + t.Fatal("runScaler(): scaling should be in progress") + } + mu.Unlock() + + // Now send many more signals while scaling is in progress. + // These should not block - they should be discarded. + sendDone := make(chan struct{}) + go func() { + for range 100 { + ch <- ScaleIncrease + } + close(sendDone) + }() + + // Wait for all sends to complete with a timeout + select { + case <-sendDone: + t.Log("runScaler(): all signals sent successfully without blocking") + case <-time.After(5 * time.Second): + t.Fatal("runScaler(): blocked") + } + + // Wait for the scaling operation to complete + time.Sleep(30 * time.Second) + + // Verify scaling completed + mu.Lock() + defer mu.Unlock() + if scalingInProgress { + t.Fatal("runScaler(): still scaling") + } + if currentWorkers != 9 { + t.Fatalf("runScaler(): expected 9 workers, got %d", currentWorkers) + } + }) +} + func TestScalerState(t *testing.T) { tests := []struct { name string