深入理解 sync.WaitGroup
# sync 包——WaitGroup
WaitGroup等待goroutines的集合完成。主goroutine调用Add来设置要等待的goroutine的数量。然后运行每个goroutine,并在完成时调用Done。同时,Wait可以用来阻塞直到所有goroutine完成。
WaitGroup使用的常见场景:
- 多任务处理,多个
goroutine处理小任务,主goroutine等待所有任务完成后合并这些任务处理的结果 - 主任务需要等待所有小任务(
goroutine)完成后才能进入下一步

# WaitGroup 的使用
下面程序执行 10 个任务(goroutine),主goroutine调用wg.Wait()阻塞等待所有的goroutine完成,完成之后就会执行wg.Wait()后面的代码:
func TestWaiGroup(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
// 每增加一个 goroutine 都要调用 Add 加 1
wg.Add(1)
go func(i int) {
// goroutine 执行完一定要调用 Done,即 Add(-1)
defer wg.Done()
fmt.Println("task", i, "done")
}(i)
}
// 主 goroutine 等待所有子 goroutine 完成
wg.Wait()
fmt.Println("all task done")
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
从例子可以看出,WaitGroup是用于同步多个goroutine之间的工作:
- 要在开启子
goroutine之前先加 1,即执行wg.Add(1) - 每一个小任务完成后,子
goroutine要减 1,即执行wg.Done() - 主
goroutine调用Wait方法来等待所有子任务完成
容易犯错的地方是 +1 和 -1 不匹配(非常不好测试):
- 加多了导致
Wait一直阻塞,引起goroutine泄露 - 减多了直接就
panic
# WaiGroup 的实现
WaitGroup从使用方式来看,就知道要实现类似功能,至少需要:
- 记住当前有多少个任务还没完成
- 记住当前有多少
goroutine调用了wait方法 - 然后需要一个东西来协调
goroutine的行为
所以,按照道理来说,我们需要设计三个字段来承载这个功能,然后搞个锁来维护这三个字段就可以了。

# WaitGroup的定义
noCopy:主要用于告诉编译器说中国东西不能复制。state1:在 64 位下,高 32 位记录了还有多少任务在运行;低 32 位记录了有多少goroutine在等Wait()方法返回state2:信号量,用于挂起或者唤醒goroutine,约等于Mutex里面的sema字段
type WaitGroup struct {
noCopy noCopy
// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
// 64-bit atomic operations require 64-bit alignment, but 32-bit
// compilers only guarantee that 64-bit fields are 32-bit aligned.
// For this reason on 32 bit architectures we need to check in state()
// if state1 is aligned or not, and dynamically "swap" the field order if
// needed.
state1 uint64
state2 uint32
}
1
2
3
4
5
6
7
8
9
10
11
12
2
3
4
5
6
7
8
9
10
11
12
WaitGroup支持的方法:
Add(delta int):将state1的高 32 位自增 1,原子操作Done():将state1的高 32 位自减 1,原子操作,然后看看要不要调用runtime_Semrelease唤醒等待中的goroutine。相当于Add(-1)。Wait():state1的低 32 位自增 1,同时利用state2和runtime_Semacquire调用把当前goroutine挂起
# Add 方法
func (wg *WaitGroup) Add(delta int) {
statep, semap := wg.state() // 解决 32 位对齐
// ...
state := atomic.AddUint64(statep, uint64(delta)<<32) // 操作高 32 位
v := int32(state >> 32)
w := uint32(state)
// ...
// Reset waiters count to 0.
*statep = 0
for ; w != 0; w-- { // 计数降为 0 了就要唤醒等待的 goroutine
runtime_Semrelease(semap, false, 0) // 唤醒 goroutine
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
2
3
4
5
6
7
8
9
10
11
12
13
# Wait 方法
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
// ...
for {
state := atomic.LoadUint64(statep)
v := int32(state >> 32) // 操作低 32 位
w := uint32(state)
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// Increment waiters count.
if atomic.CompareAndSwapUint64(statep, state, state+1) { // CAS 自增 1,可以防止在这个过程中任务计数变了
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(semap))
}
runtime_Semacquire(semap) // 挂起 goroutine
// ...
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
唯一要注意的就是这里使用的是 CAS,因为高 32 位可能也在操作。而前面Add方法可以用原子操作,是因为Add方法不关心等待者的数量。只有在唤醒goroutine的时候才会考虑等待者数量,但是这个数量是从原子操作的返回值里面解析出来。
# runtime_Semrelease和runtime_Semacquire
// Semacquire waits until *s > 0 and then atomically decrements it.
// It is intended as a simple sleep primitive for use by the synchronization
// library and should not be used directly.
func runtime_Semacquire(s *uint32)
// Semrelease atomically increments *s and notifies a waiting goroutine
// if one is blocked in Semacquire.
// It is intended as a simple wakeup primitive for use by the synchronization
// library and should not be used directly.
// If handoff is true, pass count directly to the first waiter.
// skipframes is the number of frames to omit during tracing, counting from
// runtime_Semrelease's caller.
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)
1
2
3
4
5
6
7
8
9
10
11
12
13
2
3
4
5
6
7
8
9
10
11
12
13
# errgroup 包
WaitGroup和errgroup.Group是很相似的,errgroup.Group是对WaitGroup的封装。
- 首先需要引入
golang.org/x/sync/errgroup依赖 errgroup.Group会帮我们保持进行中任务计数- 任何一个任务返回
error,Wait方法就会返回error
Group的定义:
type Group struct {
cancel func()
wg sync.WaitGroup
sem chan token
errOnce sync.Once
err error
}
1
2
3
4
5
6
7
8
9
10
2
3
4
5
6
7
8
9
10
errgroup的使用例子:
import (
"fmt"
"net/http"
"golang.org/x/sync/errgroup"
)
func main() {
g := new(errgroup.Group)
var urls = []string{
"http://www.golang.org/",
"http://www.google.com/",
"http://www.somestupidname.com/",
}
for _, url := range urls {
// Launch a goroutine to fetch the URL.
url := url // https://golang.org/doc/faq#closures_and_goroutines
g.Go(func() error {
// Fetch the URL.
resp, err := http.Get(url)
if err == nil {
resp.Body.Close()
}
return err
})
}
// Wait for all HTTP fetches to complete.
if err := g.Wait(); err == nil {
fmt.Println("Successfully fetched all URLs.")
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 参考
- https://pkg.go.dev/golang.org/x/sync/errgroup
- https://pkg.go.dev/sync@go1.20.1
上次更新: 2024/05/29, 06:25:22