我的学习日志 我的学习日志
首页
  • Go

    • Go基础知识
  • Python

    • Python进阶
  • 操作系统
  • 计算机网络
  • MySQL
  • 学习笔记
  • 常用到的算法
其他技术
  • 友情链接
  • 收藏
关于
  • 分类
  • 标签
  • 归档
GitHub (opens new window)

xuqil

一介帆夫
首页
  • Go

    • Go基础知识
  • Python

    • Python进阶
  • 操作系统
  • 计算机网络
  • MySQL
  • 学习笔记
  • 常用到的算法
其他技术
  • 友情链接
  • 收藏
关于
  • 分类
  • 标签
  • 归档
GitHub (opens new window)
  • 环境部署

  • 测试

  • 反射

  • 数据库操作

  • 并发编程

    • 解密 go Context 包
    • 深入了解Mutex和RWMutex
    • atomic 的使用
    • sync.Once 的使用
    • 深入理解 sync.Pool
    • 深入理解 sync.WaitGroup
      • WaitGroup 的使用
      • WaiGroup 的实现
        • WaitGroup的定义
        • Add 方法
        • Wait 方法
        • runtime_Semrelease和runtime_Semacquire
      • errgroup 包
      • 参考
    • 深入理解 channel
  • 内存管理

  • Go 技巧

  • 《go基础知识》
  • 并发编程
Xu Qil
2023-03-01
0
目录

深入理解 sync.WaitGroup

# sync 包——WaitGroup

WaitGroup等待goroutines的集合完成。主goroutine调用Add来设置要等待的goroutine的数量。然后运行每个goroutine,并在完成时调用Done。同时,Wait可以用来阻塞直到所有goroutine完成。

WaitGroup使用的常见场景:

  • 多任务处理,多个goroutine处理小任务,主goroutine等待所有任务完成后合并这些任务处理的结果
  • 主任务需要等待所有小任务(goroutine)完成后才能进入下一步

image-20230301214329132

# WaitGroup 的使用

下面程序执行 10 个任务(goroutine),主goroutine调用wg.Wait()阻塞等待所有的goroutine完成,完成之后就会执行wg.Wait()后面的代码:

func TestWaiGroup(t *testing.T) {
	var wg sync.WaitGroup

	for i := 0; i < 10; i++ {
		// 每增加一个 goroutine 都要调用 Add 加 1
		wg.Add(1)
		go func(i int) {
			// goroutine 执行完一定要调用 Done,即 Add(-1)
			defer wg.Done()
			fmt.Println("task", i, "done")
		}(i)
	}
	// 主 goroutine 等待所有子 goroutine 完成
	wg.Wait()
	fmt.Println("all task done")
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

从例子可以看出,WaitGroup是用于同步多个goroutine之间的工作:

  • 要在开启子goroutine之前先加 1,即执行wg.Add(1)
  • 每一个小任务完成后,子goroutine要减 1,即执行wg.Done()
  • 主goroutine调用Wait方法来等待所有子任务完成

容易犯错的地方是 +1 和 -1 不匹配(非常不好测试):

  • 加多了导致Wait一直阻塞,引起goroutine泄露
  • 减多了直接就panic

# WaiGroup 的实现

WaitGroup从使用方式来看,就知道要实现类似功能,至少需要:

  • 记住当前有多少个任务还没完成
  • 记住当前有多少goroutine调用了wait方法
  • 然后需要一个东西来协调goroutine的行为

所以,按照道理来说,我们需要设计三个字段来承载这个功能,然后搞个锁来维护这三个字段就可以了。

image-20230301214329132

# WaitGroup的定义

  • noCopy:主要用于告诉编译器说中国东西不能复制。
  • state1:在 64 位下,高 32 位记录了还有多少任务在运行;低 32 位记录了有多少goroutine在等Wait()方法返回
  • state2:信号量,用于挂起或者唤醒goroutine,约等于Mutex里面的sema字段
type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers only guarantee that 64-bit fields are 32-bit aligned.
	// For this reason on 32 bit architectures we need to check in state()
	// if state1 is aligned or not, and dynamically "swap" the field order if
	// needed.
	state1 uint64
	state2 uint32
}
1
2
3
4
5
6
7
8
9
10
11
12

WaitGroup支持的方法:

  • Add(delta int):将state1的高 32 位自增 1,原子操作
  • Done():将state1的高 32 位自减 1,原子操作,然后看看要不要调用runtime_Semrelease唤醒等待中的goroutine。相当于Add(-1)。
  • Wait():state1的低 32 位自增 1,同时利用state2和runtime_Semacquire调用把当前goroutine挂起

# Add 方法

func (wg *WaitGroup) Add(delta int) {
	statep, semap := wg.state() // 解决 32 位对齐
	// ...
	state := atomic.AddUint64(statep, uint64(delta)<<32) // 操作高 32 位
	v := int32(state >> 32)
	w := uint32(state)
	// ...
	// Reset waiters count to 0.
	*statep = 0
	for ; w != 0; w-- { // 计数降为 0 了就要唤醒等待的 goroutine
		runtime_Semrelease(semap, false, 0) // 唤醒 goroutine
	}
}
1
2
3
4
5
6
7
8
9
10
11
12
13

# Wait 方法

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
   // ...
   for {
      state := atomic.LoadUint64(statep)
      v := int32(state >> 32) // 操作低 32 位
      w := uint32(state)
      if v == 0 {
         // Counter is 0, no need to wait.
         if race.Enabled {
            race.Enable()
            race.Acquire(unsafe.Pointer(wg))
         }
         return
      }
      // Increment waiters count.
      if atomic.CompareAndSwapUint64(statep, state, state+1) { // CAS 自增 1,可以防止在这个过程中任务计数变了
         if race.Enabled && w == 0 {
            // Wait must be synchronized with the first Add.
            // Need to model this is as a write to race with the read in Add.
            // As a consequence, can do the write only for the first waiter,
            // otherwise concurrent Waits will race with each other.
            race.Write(unsafe.Pointer(semap))
         }
         runtime_Semacquire(semap) // 挂起 goroutine
         // ...
   }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

唯一要注意的就是这里使用的是 CAS,因为高 32 位可能也在操作。而前面Add方法可以用原子操作,是因为Add方法不关心等待者的数量。只有在唤醒goroutine的时候才会考虑等待者数量,但是这个数量是从原子操作的返回值里面解析出来。

# runtime_Semrelease和runtime_Semacquire

// Semacquire waits until *s > 0 and then atomically decrements it.
// It is intended as a simple sleep primitive for use by the synchronization
// library and should not be used directly.
func runtime_Semacquire(s *uint32)

// Semrelease atomically increments *s and notifies a waiting goroutine
// if one is blocked in Semacquire.
// It is intended as a simple wakeup primitive for use by the synchronization
// library and should not be used directly.
// If handoff is true, pass count directly to the first waiter.
// skipframes is the number of frames to omit during tracing, counting from
// runtime_Semrelease's caller.
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)
1
2
3
4
5
6
7
8
9
10
11
12
13

# errgroup 包

WaitGroup和errgroup.Group是很相似的,errgroup.Group是对WaitGroup的封装。

  • 首先需要引入golang.org/x/sync/errgroup依赖
  • errgroup.Group会帮我们保持进行中任务计数
  • 任何一个任务返回error,Wait方法就会返回error

Group的定义:

type Group struct {
   cancel func()

   wg sync.WaitGroup

   sem chan token

   errOnce sync.Once
   err     error
}
1
2
3
4
5
6
7
8
9
10

errgroup的使用例子:

import (
	"fmt"
	"net/http"

	"golang.org/x/sync/errgroup"
)

func main() {
	g := new(errgroup.Group)
	var urls = []string{
		"http://www.golang.org/",
		"http://www.google.com/",
		"http://www.somestupidname.com/",
	}
	for _, url := range urls {
		// Launch a goroutine to fetch the URL.
		url := url // https://golang.org/doc/faq#closures_and_goroutines
		g.Go(func() error {
			// Fetch the URL.
			resp, err := http.Get(url)
			if err == nil {
				resp.Body.Close()
			}
			return err
		})
	}
	// Wait for all HTTP fetches to complete.
	if err := g.Wait(); err == nil {
		fmt.Println("Successfully fetched all URLs.")
	}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31

# 参考

  • https://pkg.go.dev/golang.org/x/sync/errgroup
  • https://pkg.go.dev/sync@go1.20.1

评论

深入理解 sync.Pool
深入理解 channel

← 深入理解 sync.Pool 深入理解 channel→

最近更新
01
Golang 逃逸分析
03-22
02
深入理解 channel
03-04
03
深入理解 sync.Pool
02-28
更多文章>
Theme by Vdoing | Copyright © 2018-2023 FeelingLife | 粤ICP备2022093535号-1
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式