深入理解 sync.WaitGroup
# sync 包——WaitGroup
WaitGroup
等待goroutines
的集合完成。主goroutine
调用Add
来设置要等待的goroutine
的数量。然后运行每个goroutine
,并在完成时调用Done
。同时,Wait
可以用来阻塞直到所有goroutine
完成。
WaitGroup
使用的常见场景:
- 多任务处理,多个
goroutine
处理小任务,主goroutine
等待所有任务完成后合并这些任务处理的结果 - 主任务需要等待所有小任务(
goroutine
)完成后才能进入下一步
# WaitGroup 的使用
下面程序执行 10 个任务(goroutine
),主goroutine
调用wg.Wait()
阻塞等待所有的goroutine
完成,完成之后就会执行wg.Wait()
后面的代码:
func TestWaiGroup(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
// 每增加一个 goroutine 都要调用 Add 加 1
wg.Add(1)
go func(i int) {
// goroutine 执行完一定要调用 Done,即 Add(-1)
defer wg.Done()
fmt.Println("task", i, "done")
}(i)
}
// 主 goroutine 等待所有子 goroutine 完成
wg.Wait()
fmt.Println("all task done")
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
从例子可以看出,WaitGroup
是用于同步多个goroutine
之间的工作:
- 要在开启子
goroutine
之前先加 1,即执行wg.Add(1)
- 每一个小任务完成后,子
goroutine
要减 1,即执行wg.Done()
- 主
goroutine
调用Wait
方法来等待所有子任务完成
容易犯错的地方是 +1 和 -1 不匹配(非常不好测试):
- 加多了导致
Wait
一直阻塞,引起goroutine
泄露 - 减多了直接就
panic
# WaiGroup 的实现
WaitGroup
从使用方式来看,就知道要实现类似功能,至少需要:
- 记住当前有多少个任务还没完成
- 记住当前有多少
goroutine
调用了wait
方法 - 然后需要一个东西来协调
goroutine
的行为
所以,按照道理来说,我们需要设计三个字段来承载这个功能,然后搞个锁来维护这三个字段就可以了。
# WaitGroup
的定义
noCopy
:主要用于告诉编译器说中国东西不能复制。state1
:在 64 位下,高 32 位记录了还有多少任务在运行;低 32 位记录了有多少goroutine
在等Wait()
方法返回state2
:信号量,用于挂起或者唤醒goroutine
,约等于Mutex
里面的sema
字段
type WaitGroup struct {
noCopy noCopy
// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
// 64-bit atomic operations require 64-bit alignment, but 32-bit
// compilers only guarantee that 64-bit fields are 32-bit aligned.
// For this reason on 32 bit architectures we need to check in state()
// if state1 is aligned or not, and dynamically "swap" the field order if
// needed.
state1 uint64
state2 uint32
}
1
2
3
4
5
6
7
8
9
10
11
12
2
3
4
5
6
7
8
9
10
11
12
WaitGroup
支持的方法:
Add(delta int)
:将state1
的高 32 位自增 1,原子操作Done()
:将state1
的高 32 位自减 1,原子操作,然后看看要不要调用runtime_Semrelease
唤醒等待中的goroutine
。相当于Add(-1)
。Wait()
:state1
的低 32 位自增 1,同时利用state2
和runtime_Semacquire
调用把当前goroutine
挂起
# Add 方法
func (wg *WaitGroup) Add(delta int) {
statep, semap := wg.state() // 解决 32 位对齐
// ...
state := atomic.AddUint64(statep, uint64(delta)<<32) // 操作高 32 位
v := int32(state >> 32)
w := uint32(state)
// ...
// Reset waiters count to 0.
*statep = 0
for ; w != 0; w-- { // 计数降为 0 了就要唤醒等待的 goroutine
runtime_Semrelease(semap, false, 0) // 唤醒 goroutine
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
2
3
4
5
6
7
8
9
10
11
12
13
# Wait 方法
// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
// ...
for {
state := atomic.LoadUint64(statep)
v := int32(state >> 32) // 操作低 32 位
w := uint32(state)
if v == 0 {
// Counter is 0, no need to wait.
if race.Enabled {
race.Enable()
race.Acquire(unsafe.Pointer(wg))
}
return
}
// Increment waiters count.
if atomic.CompareAndSwapUint64(statep, state, state+1) { // CAS 自增 1,可以防止在这个过程中任务计数变了
if race.Enabled && w == 0 {
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(semap))
}
runtime_Semacquire(semap) // 挂起 goroutine
// ...
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
唯一要注意的就是这里使用的是 CAS,因为高 32 位可能也在操作。而前面Add
方法可以用原子操作,是因为Add
方法不关心等待者的数量。只有在唤醒goroutine
的时候才会考虑等待者数量,但是这个数量是从原子操作的返回值里面解析出来。
# runtime_Semrelease
和runtime_Semacquire
// Semacquire waits until *s > 0 and then atomically decrements it.
// It is intended as a simple sleep primitive for use by the synchronization
// library and should not be used directly.
func runtime_Semacquire(s *uint32)
// Semrelease atomically increments *s and notifies a waiting goroutine
// if one is blocked in Semacquire.
// It is intended as a simple wakeup primitive for use by the synchronization
// library and should not be used directly.
// If handoff is true, pass count directly to the first waiter.
// skipframes is the number of frames to omit during tracing, counting from
// runtime_Semrelease's caller.
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)
1
2
3
4
5
6
7
8
9
10
11
12
13
2
3
4
5
6
7
8
9
10
11
12
13
# errgroup 包
WaitGroup
和errgroup.Group
是很相似的,errgroup.Group
是对WaitGroup
的封装。
- 首先需要引入
golang.org/x/sync/errgroup
依赖 errgroup.Group
会帮我们保持进行中任务计数- 任何一个任务返回
error
,Wait
方法就会返回error
Group
的定义:
type Group struct {
cancel func()
wg sync.WaitGroup
sem chan token
errOnce sync.Once
err error
}
1
2
3
4
5
6
7
8
9
10
2
3
4
5
6
7
8
9
10
errgroup
的使用例子:
import (
"fmt"
"net/http"
"golang.org/x/sync/errgroup"
)
func main() {
g := new(errgroup.Group)
var urls = []string{
"http://www.golang.org/",
"http://www.google.com/",
"http://www.somestupidname.com/",
}
for _, url := range urls {
// Launch a goroutine to fetch the URL.
url := url // https://golang.org/doc/faq#closures_and_goroutines
g.Go(func() error {
// Fetch the URL.
resp, err := http.Get(url)
if err == nil {
resp.Body.Close()
}
return err
})
}
// Wait for all HTTP fetches to complete.
if err := g.Wait(); err == nil {
fmt.Println("Successfully fetched all URLs.")
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 参考
- https://pkg.go.dev/golang.org/x/sync/errgroup
- https://pkg.go.dev/sync@go1.20.1