diff options
-rw-r--r-- | AUTHORS | 3 | ||||
-rw-r--r-- | CONTRIBUTORS | 3 | ||||
-rw-r--r-- | METADATA | 8 | ||||
-rw-r--r-- | errgroup/errgroup.go | 74 | ||||
-rw-r--r-- | errgroup/errgroup_test.go | 86 | ||||
-rw-r--r-- | singleflight/singleflight.go | 11 | ||||
-rw-r--r-- | syncmap/map_test.go | 2 |
7 files changed, 163 insertions, 24 deletions
diff --git a/AUTHORS b/AUTHORS deleted file mode 100644 index 15167cd..0000000 --- a/AUTHORS +++ /dev/null @@ -1,3 +0,0 @@ -# This source code refers to The Go Authors for copyright purposes. -# The master list of authors is in the main Go distribution, -# visible at http://tip.golang.org/AUTHORS. diff --git a/CONTRIBUTORS b/CONTRIBUTORS deleted file mode 100644 index 1c4577e..0000000 --- a/CONTRIBUTORS +++ /dev/null @@ -1,3 +0,0 @@ -# This source code was written by the Go contributors. -# The master list of contributors is in the main Go distribution, -# visible at http://tip.golang.org/CONTRIBUTORS. @@ -9,11 +9,11 @@ third_party { type: GIT value: "https://go.googlesource.com/sync/" } - version: "036812b2e83c0ddf193dd5a34e034151da389d09" + version: "v0.1.0" license_type: NOTICE last_upgrade_date { - year: 2021 - month: 8 - day: 27 + year: 2023 + month: 3 + day: 15 } } diff --git a/errgroup/errgroup.go b/errgroup/errgroup.go index 9857fe5..cbee7a4 100644 --- a/errgroup/errgroup.go +++ b/errgroup/errgroup.go @@ -8,22 +8,35 @@ package errgroup import ( "context" + "fmt" "sync" ) +type token struct{} + // A Group is a collection of goroutines working on subtasks that are part of // the same overall task. // -// A zero Group is valid and does not cancel on error. +// A zero Group is valid, has no limit on the number of active goroutines, +// and does not cancel on error. type Group struct { cancel func() wg sync.WaitGroup + sem chan token + errOnce sync.Once err error } +func (g *Group) done() { + if g.sem != nil { + <-g.sem + } + g.wg.Done() +} + // WithContext returns a new Group and an associated Context derived from ctx. // // The derived Context is canceled the first time a function passed to Go @@ -45,14 +58,48 @@ func (g *Group) Wait() error { } // Go calls the given function in a new goroutine. +// It blocks until the new goroutine can be added without the number of +// active goroutines in the group exceeding the configured limit. // -// The first call to return a non-nil error cancels the group; its error will be -// returned by Wait. +// The first call to return a non-nil error cancels the group's context, if the +// group was created by calling WithContext. The error will be returned by Wait. func (g *Group) Go(f func() error) { + if g.sem != nil { + g.sem <- token{} + } + g.wg.Add(1) + go func() { + defer g.done() + + if err := f(); err != nil { + g.errOnce.Do(func() { + g.err = err + if g.cancel != nil { + g.cancel() + } + }) + } + }() +} + +// TryGo calls the given function in a new goroutine only if the number of +// active goroutines in the group is currently below the configured limit. +// +// The return value reports whether the goroutine was started. +func (g *Group) TryGo(f func() error) bool { + if g.sem != nil { + select { + case g.sem <- token{}: + // Note: this allows barging iff channels in general allow barging. + default: + return false + } + } + g.wg.Add(1) go func() { - defer g.wg.Done() + defer g.done() if err := f(); err != nil { g.errOnce.Do(func() { @@ -63,4 +110,23 @@ func (g *Group) Go(f func() error) { }) } }() + return true +} + +// SetLimit limits the number of active goroutines in this group to at most n. +// A negative value indicates no limit. +// +// Any subsequent call to the Go method will block until it can add an active +// goroutine without exceeding the configured limit. +// +// The limit must not be modified while any goroutines in the group are active. +func (g *Group) SetLimit(n int) { + if n < 0 { + g.sem = nil + return + } + if len(g.sem) != 0 { + panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem))) + } + g.sem = make(chan token, n) } diff --git a/errgroup/errgroup_test.go b/errgroup/errgroup_test.go index 5a0b9cb..0358842 100644 --- a/errgroup/errgroup_test.go +++ b/errgroup/errgroup_test.go @@ -10,7 +10,9 @@ import ( "fmt" "net/http" "os" + "sync/atomic" "testing" + "time" "golang.org/x/sync/errgroup" ) @@ -174,3 +176,87 @@ func TestWithContext(t *testing.T) { } } } + +func TestTryGo(t *testing.T) { + g := &errgroup.Group{} + n := 42 + g.SetLimit(42) + ch := make(chan struct{}) + fn := func() error { + ch <- struct{}{} + return nil + } + for i := 0; i < n; i++ { + if !g.TryGo(fn) { + t.Fatalf("TryGo should succeed but got fail at %d-th call.", i) + } + } + if g.TryGo(fn) { + t.Fatalf("TryGo is expected to fail but succeeded.") + } + go func() { + for i := 0; i < n; i++ { + <-ch + } + }() + g.Wait() + + if !g.TryGo(fn) { + t.Fatalf("TryGo should success but got fail after all goroutines.") + } + go func() { <-ch }() + g.Wait() + + // Switch limit. + g.SetLimit(1) + if !g.TryGo(fn) { + t.Fatalf("TryGo should success but got failed.") + } + if g.TryGo(fn) { + t.Fatalf("TryGo should fail but succeeded.") + } + go func() { <-ch }() + g.Wait() + + // Block all calls. + g.SetLimit(0) + for i := 0; i < 1<<10; i++ { + if g.TryGo(fn) { + t.Fatalf("TryGo should fail but got succeded.") + } + } + g.Wait() +} + +func TestGoLimit(t *testing.T) { + const limit = 10 + + g := &errgroup.Group{} + g.SetLimit(limit) + var active int32 + for i := 0; i <= 1<<10; i++ { + g.Go(func() error { + n := atomic.AddInt32(&active, 1) + if n > limit { + return fmt.Errorf("saw %d active goroutines; want ≤ %d", n, limit) + } + time.Sleep(1 * time.Microsecond) // Give other goroutines a chance to increment active. + atomic.AddInt32(&active, -1) + return nil + }) + } + if err := g.Wait(); err != nil { + t.Fatal(err) + } +} + +func BenchmarkGo(b *testing.B) { + fn := func() {} + g := &errgroup.Group{} + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + g.Go(func() error { fn(); return nil }) + } + g.Wait() +} diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index 690eb85..8473fb7 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -52,10 +52,6 @@ type call struct { val interface{} err error - // forgotten indicates whether Forget was called with this call's key - // while the call was still in flight. - forgotten bool - // These fields are read and written with the singleflight // mutex held before the WaitGroup is done, and are read but // not written after the WaitGroup is done. @@ -148,10 +144,10 @@ func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { c.err = errGoexit } - c.wg.Done() g.mu.Lock() defer g.mu.Unlock() - if !c.forgotten { + c.wg.Done() + if g.m[key] == c { delete(g.m, key) } @@ -204,9 +200,6 @@ func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { // an earlier call to complete. func (g *Group) Forget(key string) { g.mu.Lock() - if c, ok := g.m[key]; ok { - c.forgotten = true - } delete(g.m, key) g.mu.Unlock() } diff --git a/syncmap/map_test.go b/syncmap/map_test.go index c883f17..bf69f50 100644 --- a/syncmap/map_test.go +++ b/syncmap/map_test.go @@ -115,7 +115,7 @@ func TestConcurrentRange(t *testing.T) { m := new(syncmap.Map) for n := int64(1); n <= mapSize; n++ { - m.Store(n, int64(n)) + m.Store(n, n) } done := make(chan struct{}) |