Golang 標準庫 tips之waitgroup詳解

WaitGroup 用於線程同步,很多場景下為瞭提高並發需要開多個協程執行,但是又需要等待多個協程的結果都返回的情況下才進行後續邏輯處理,這種情況下可以通過 WaitGroup 提供的方法阻塞主線程的執行,直到所有的 goroutine 執行完成。
本文目錄結構:

WaitGroup 不能被值拷貝
Add 需要在 Wait 之前調用
使用 channel 實現 WaitGroup 的功能
Add 和 Done 數量問題
WaitGroup 和 channel 控制並發數
WaitGroup 和 channel 實現提前退出
WaitGroup 和 channel 返回錯誤
使用 ErrGroup 返回錯誤
使用 ErrGroup 實現提前退出
改善版的 Errgroup

WaitGroup 不能被值拷貝

wg 作為一個參數傳遞的時候,我們在函數中操作的時候還是操作的一個拷貝的變量,對於原來的 wg 是不會改變。
這一點可以從 WaitGroup 實現的源碼定義的 struct 能能看出來,WaitGroup 的 struct 就兩個字段,第一個字段就是 noCopy,表明這個結構體是不希望直接被復制的。noCopy 是的實現是一個空的 struct{},主要的作用是嵌入到結構體中作為輔助 vet 工具檢查是否通過 copy 賦值這個 WaitGroup 實例,如果有值拷貝的情況,會被檢測出來,我們一般的 lint 工具也都能檢測出來。
在某些情況下,如果 WaitGroup 需要作為參數傳遞到其他的方法中,一定需要使用指針類型進行傳遞。

type WaitGroup struct {
    noCopy noCopy

    // 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
    // 64-bit atomic operations require 64-bit alignment, but 32-bit
    // compilers do not ensure it. So we allocate 12 bytes and then use
    // the aligned 8 bytes in them as state, and the other 4 as storage
    // for the sema.
    state1 [3]uint32
}

可以用以下一個例子來說明:

// 錯誤的用法,函數傳遞 wg 是值拷貝
func main() {
    wg := sync.WaitGroup{}

    wg.Add(10)

    for i := 0; i < 10; i++ {
        go func(i int) {
            do(i, wg)
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

func do(i int, wg sync.WaitGroup) { // wg 值拷貝,會導致程序
    fmt.Println(i)
    wg.Done()
}

// 正確的用法,waitgroup 參數傳遞使用指針的形式
func main() {
    wg := sync.WaitGroup{}

    wg.Add(10)

    for i := 0; i < 10; i++ {
        go func(i int) {
            do(i, &wg)
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

func do(i int, wg *sync.WaitGroup) {
    fmt.Println(i)
    wg.Done()
}

Add 需要在 Wait 之前調用

WaitGroup 結構體提供瞭三個方法,Add、Done、Wait,Add 的作用是用來設置WaitGroup的計數值(子goroutine的數量);Done的作用用來將 WaitGroup 的計數值減 1,其實就是調用Add(-1);Wait 的作用是檢測 WaitGroup 計數器的值是否為 0,如果為 0 表示所有的 goroutine 都運行完成,否則會阻塞等待計數器的值為0(所有的 groutine都執行完成)之後才運行後面的代碼。
所以在 WaitGroup 調用的時候一定要保障 Add 函數在 Wait 函數之前執行,否則可能會導致 Wait 方法沒有等到所有的結果運行完成而被執行完。也就是我們不能在 Grountine 中來執行 Add 和 Done,這樣可能當前 Grountine 來不及運行,外層的 Wait 函數檢測到滿足條件然後退出瞭。

func main() {
    wg := sync.WaitGroup{}
    wg.Wait() // 直接調用 Wait() 方法是不會阻塞的,因為 wg 中 goroutine 計數器的值為 0
    fmt.Println("success")
}
// 錯誤的寫法,在 goroutine 中進行 Add(1) 操作。
// 可能在這些 goroutine 還沒來得及 Add(1) 就已經執行 Wait 操作瞭
func main() {
    wg := sync.WaitGroup{}

    for i := 0; i < 10; i++ {
        go func(i int) {
            wg.Add(1)
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 打印的結果,不是我們預期的打印 10 個元素之後再打印 success,而是會隨機打印其中的一部分
success
1
0
5
2

// 正確的寫法一
func main() {
    wg := sync.WaitGroup{}
    wg.Add(10) // 在 groutine 外層先把需要運行的 goroutine 的數量設置好,保障比 Wait 函數先執行

    for i := 0; i < 10; i++ {
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 正確的寫法二
func main() {
    wg := sync.WaitGroup{}

    for i := 0; i < 10; i++ {
        wg.Add(1) // 保障比 Wait 函數先執行
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

使用 channel 實現 WaitGroup 的功能

如果想要實現主線程中等待多個協程的結果都返回的情況下才進行後續調用,也可以通過帶緩存區的 channel 來實現,實現的思路是需要先知道等待 groutine 的運行的數量,然後初始化一個相同緩存區數量的 channel,在 groutine 運行結束之後往 channel 中放入一個值,並在主線程中阻塞監聽獲取 channel 中的值全部返回。

func main() {
    numGroutine := 10
    ch := make(chan struct{}, numGroutine)

    for i := 0; i < numGroutine; i++ {
        go func(i int) {
            fmt.Println(i)
            ch <- struct{}{}
        }(i)
    }

    for i := 0; i < numGroutine; i++ {
        <-ch
    }

    fmt.Println("success")
}

// 打印結果:
7
5
3
1
9
0
4
2
6
8
success

Add 和 Done 數量問題

需要保障 Add 的數量和 Done 的數量一致,如果 Add 數量小於 Done 數量的情況下,調用 Wait 方法會檢測到計數器的值為負數,程序會報 panic;如果 Add 數量大於 Done 的數量,會導致 Wait 循環阻塞後面的代碼得不到執行。
Add 數量小於 Done 數量:

func main() {
    wg := sync.WaitGroup{}
    wg.Add(1) // Add 數量小於 Done 數量

    for i := 0; i < 10; i++ {
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 運行結果,有兩種結果
結果一:打印部分輸出然後退出,這種情況是因為 Done 執行瞭一個隻會,Wait 檢測到剛好滿足條件然後退出瞭
1
success
9
5

結果二:執行 Wait 函數的時候,計數器的值已經是負數瞭
0
9
3
panic: sync: negative WaitGroup counter

Add 數量大於 Done 數量:

func main() {
    wg := sync.WaitGroup{}
    wg.Add(20)

    for i := 0; i < 10; i++ {
        go func(i int) {
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

// 執行結果:deadlock
0
9
3
7
8
1
4
2
6
5
fatal error: all goroutines are asleep - deadlock!

WaitGroup 和 channel 控制並發數

使用 waitgroup 可以控制一組 groutine 同時運行並等待結果返回之後再進行後續操作,雖然 groutine 對資源消耗比較小,但是大量的 groutine 並發對系統的壓力還是比較大,所以這種情況如果需要控制 waitgroup 中 groutine 並發數量控制,就可以使用緩存的 channel 控制同時並發的 groutine 數量。

func main() {
    wg := sync.WaitGroup{}
    wg.Add(200)

    ch := make(chan struct{}, 10) // 控制最大並發數是 10
 
    for i := 0; i < 200; i++ {
        ch <- struct{}{}
        go func(i int) {
            fmt.Println(i)
            wg.Done()
            <-ch
        }(i)
    }

    wg.Wait()
    fmt.Println("success")
}

根據使用 channel 實現 WaitGroup 的功能的思路,我們上面的代碼也可以通過兩個 channel 進行改造來實現。

func main() {
    numGroutine := 200 // 運行的 groutine 總數量
    numParallel := 10  // 並發的 groutine 數量

    chTotal := make(chan struct{}, numGroutine)
    chParallel := make(chan struct{}, numParallel)

    for i := 0; i < 200; i++ {
        chTotal <- struct{}{}
        go func(i int) {
            fmt.Println(i)
            <-chTotal
            chParallel <- struct{}{}
        }(i)
    }

    for i := 0; i < numGroutine; i++ {
        <-chParallel
    }
    fmt.Println("success")
}

WaitGroup 和 channel 實現提前退出

用 WaitGroup 協調一組並發 goroutine 的做法很常見,但 WaitGroup 本身也有其不足:
WaitGroup 必須要等待控制的一組 goroutine 全部返回結果之後才往下運行,但是有的情況下我們希望能快速失敗,也就是這一組 goroutine 中隻要有一個失敗瞭,那麼就不應該等到所有 goroutine 結束再結束任務,而是提前結束以避免資源浪費,這個時候就可以使用 channel 配合 WaitGroup 實現提前退出的效果。

func main() {
    wg := sync.WaitGroup{}
    wg.Add(10)

    ch := make(chan struct{}) // 使用一個 channel 傳遞退出信號

    for i := 0; i < 10; i++ {
        go func(i int) {
            time.Sleep(time.Duration(i) * time.Second)
            fmt.Println(i)
            if i == 2 { // 檢測到 i==2 則提前退出
                ch <- struct{}{}
            }
            wg.Done()
        }(i)
    }

    go func() {
        wg.Wait()        // wg.Wait 執行之後表示所有的 groutine 都已經執行完成瞭,而且沒有 groutine 往 ch 傳遞退出信號
        ch <- struct{}{} // 需要傳遞一個信號,不然主線程會一直阻塞
    }()

    <-ch // 阻塞等待收到退出信號之後往下執行

    fmt.Println("success")
}

// 打印結果
0
1
2
success

WaitGroup 和 channel 返回錯誤

WaitGroup 除瞭不能快速失敗之外還有一個問題就是不能在主線程中獲取到 groutine 出錯時返回的錯誤,這種情況下就可以用到 channel 進行錯誤傳遞,在主線程中獲取到錯誤。

// 案例一:groutine 中隻要有一個失敗瞭則返回 err 並且回到主協程運行後續代碼
func main() {
    wg := sync.WaitGroup{}
    wg.Add(10)

    ch := make(chan error) // 使用一個 channel 傳遞退出信號

    for i := 0; i < 10; i++ {
        go func(i int) {
            time.Sleep(time.Duration(i) * time.Second)
            if i == 2 { // 檢測到 i==2 則提前退出
                ch <- fmt.Errorf("i can't be 2")
                close(ch)
                return
            }
            fmt.Println(i)
            wg.Done()
        }(i)
    }

    go func() {
        wg.Wait() // wg.Wait 執行之後表示所有的 groutine 都已經執行完成瞭,而且沒有 groutine 往 ch 傳遞退出信號
        ch <- nil // 需要傳遞一個 nil error,不然主線程會一直阻塞
        close(ch)
    }()

    err := <-ch
    fmt.Println(err.Error())
}

// 運行結果:
/*
0
1
i can't be 2
*/

// 案例二:等待所有的 groutine 都運行完成再回到主線程並捕獲所有的 error
func main() {
    wg := sync.WaitGroup{}
    wg.Add(10)

    ch := make(chan error, 10) // 設置和 groutine 數量一致,可以緩沖最多 10 個 error

    for i := 0; i < 10; i++ {
        go func(i int) {
            defer func() {
                wg.Done()
            }()
            time.Sleep(time.Duration(i) * time.Second)
            if i == 2 {
                ch <- fmt.Errorf("i can't be 2")
                return
            }
            if i == 3 {
                ch <- fmt.Errorf("i can't be 3")
                return
            }
            fmt.Println(i)
        }(i)
    }

    wg.Wait() // wg.Wait 執行之後表示所有的 groutine 都已經執行完成瞭
    close(ch) // 需要 close channel,不然主線程會阻塞

    for err := range ch {
        fmt.Println(err.Error())
    }
}

// 打印結果:
0
1
4
5
6
7
8
9
i can't be 2
i can't be 3

使用 ErrGroup 返回錯誤

正是由於 WaitGroup 有以上說的一些缺點,Go 團隊在實驗倉庫(golang.org/x)增加瞭 errgroup.Group 的功能,相比 WaitGroup 增加瞭錯誤傳遞、快速失敗、超時取消等功能,相對於通過 channel 和 WaitGroup 組合實現這些功能更方便,也更加推薦。
errgroup.Group 結構體也比較簡單,在 sync.WaitGroup 的基礎之上包裝瞭一個 error 以及一個 cancel 方法,err 的作用是在 goroutine 出錯的時候能夠返回,cancel 方法的作用是在出錯的時候快速失敗。
errgroup.Group 對外暴露瞭3個方法,WithContext、Go、Wait,沒有瞭 Add、Done 方法,其實 Add 和 Done 是在包裝在瞭 errgroup.Group 的 Go 方法裡面瞭,我們執行的時候不需要關心。

// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
    cancel func()

    wg sync.WaitGroup

    errOnce sync.Once
    err     error
}

func WithContext(ctx context.Context) (*Group, context.Context) {
    ctx, cancel := context.WithCancel(ctx)
    return &Group{cancel: cancel}, ctx
}

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
    g.wg.Wait()
    if g.cancel != nil {
        g.cancel()
    }
    return g.err
}

// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
    g.wg.Add(1)

    go func() {
        defer g.wg.Done()

        if err := f(); err != nil {
            g.errOnce.Do(func() {
                g.err = err
                if g.cancel != nil {
                    g.cancel()
                }
            })
        }
    }()
}

以下是使用 errgroup.Group 來實現返回 goroutine 錯誤的例子:

func main() {
    eg := errgroup.Group{}

    for i := 0; i < 10; i++ {
        i := i // 這裡需要進行賦值操作,不然會有閉包問題,eg.Go 執行的 groutine 會引用 for 循環的 i
        eg.Go(func() error {
            if i == 2 {
                return fmt.Errorf("i can't be 2")
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }
}

// 打印結果
9
6
7
8
3
4
1
5
0
i can't be 2

需要註意的一點是通過 errgroup.Group 來返回 err 隻會返回其中一個 groutine 的錯誤,而且是最先返回 err 的 groutine 的錯誤,這一點是通過 errgroup.Group 的 errOnce 來實現的。

使用 ErrGroup 實現提前退出

使用 errgroup.Group 實現提前退出也比較簡單,調用 errgroup.WithContext 方法獲取 errgroup.Group 對象以及一個可以取消的 WithCancel 的 context,並且將這個 context 方法傳入到所有的 groutine 中,並在 groutine 中使用 select 監聽這個 context 的 Done() 事件,如果監聽到瞭表明接收到瞭 cancel 信號,然後退出 groutine 即可。需要註意的是 eg.Go 一定要返回一個 err 才會觸發 errgroup.Group 執行 cancel 方法。

// 案例一:通過 groutine 顯示返回 err 觸發 errgroup.Group 底層的 cancel 方法
func main() {
    ctx := context.Background()
    eg, ctx := errgroup.WithContext(ctx)

    for i := 0; i < 10; i++ {
        i := i // 這裡需要進行賦值操作,不然會有閉包問題,eg.Go 執行的 groutine 會引用 for 循環的 i
        eg.Go(func() error {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(time.Duration(i) * time.Second):
            }
            if i == 2 {
                return fmt.Errorf("i can't be 2") // 需要返回 err 才會導致 eg 的 cancel 方法
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }
}

// 打印結果:
0
1
i can't be 2

// 案例二:通過顯示調用 cancel 方法通知到各個 groutine 退出
func main() {
    ctx, cancel := context.WithCancel(context.Background())
    eg, ctx := errgroup.WithContext(ctx)

    for i := 0; i < 10; i++ {
        i := i // 這裡需要進行賦值操作,不然會有閉包問題,eg.Go 執行的 groutine 會引用 for 循環的 i
        eg.Go(func() error {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(time.Duration(i) * time.Second):
            }
            if i == 2 {
                cancel()
                return nil // 可以不用返回 err,因為手動觸發瞭 cancel 方法
                //return fmt.Errorf("i can't be 2")
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }
}

// 打印結果:
0
1
context canceled


// 案例三:
// 基於 errgroup 實現一個 http server 的啟動和關閉 ,以及 linux signal 信號的註冊和處理,要保證能夠 一個退出,全部註銷退出
// https://lailin.xyz/post/go-training-week3-errgroup.html
func main() {
    g, ctx := errgroup.WithContext(context.Background())

    mux := http.NewServeMux()
    mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
        w.Write([]byte("pong"))
    })

    // 模擬單個服務錯誤退出
    serverOut := make(chan struct{})
    mux.HandleFunc("/shutdown", func(w http.ResponseWriter, r *http.Request) {
        serverOut <- struct{}{}
    })

    server := http.Server{
        Handler: mux,
        Addr:    ":8080",
    }

    // g1
    // g1 退出瞭所有的協程都能退出麼?
    // g1 退出後, context 將不再阻塞,g2, g3 都會隨之退出
    // 然後 main 函數中的 g.Wait() 退出,所有協程都會退出
    g.Go(func() error {
        return server.ListenAndServe()
    })

    // g2
    // g2 退出瞭所有的協程都能退出麼?
    // g2 退出時,調用瞭 shutdown,g1 會退出
    // g2 退出後, context 將不再阻塞,g3 會隨之退出
    // 然後 main 函數中的 g.Wait() 退出,所有協程都會退出
    g.Go(func() error {
        select {
        case <-ctx.Done():
            log.Println("errgroup exit...")
        case <-serverOut:
            log.Println("server will out...")
        }

        timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
        // 這裡不是必須的,但是如果使用 _ 的話靜態掃描工具會報錯,加上也無傷大雅
        defer cancel()

        log.Println("shutting down server...")
        return server.Shutdown(timeoutCtx)
    })

    // g3
    // g3 捕獲到 os 退出信號將會退出
    // g3 退出瞭所有的協程都能退出麼?
    // g3 退出後, context 將不再阻塞,g2 會隨之退出
    // g2 退出時,調用瞭 shutdown,g1 會退出
    // 然後 main 函數中的 g.Wait() 退出,所有協程都會退出
    g.Go(func() error {
        quit := make(chan os.Signal, 0)
        signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

        select {
        case <-ctx.Done():
            return ctx.Err()
        case sig := <-quit:
            return errors.Errorf("get os signal: %v", sig)
        }
    })

    fmt.Printf("errgroup exiting: %+v\n", g.Wait())
}

改善版的 Errgroup

使用 errgroup.Group 的 WithContext 我們註意到在返回 eg 對象的同時還會返回另外一個可以取消的 context 對象,這個 context 對象的功能就是用來傳遞到 eg 需要同步的 groutine 中有一個發生錯誤時取消整個同步的 groutine,但是有不少同學可能會不經意將這個 context 傳到其他的非 eg 同步的業務代碼groutine 中,這樣會導致非關聯的業務代碼莫名其妙的收到 cancel 信息,類似如下的寫法:

func main() {
    ctx := context.Background()
    eg, ctx := errgroup.WithContext(ctx)

    for i := 0; i < 10; i++ {
        i := i // 這裡需要進行賦值操作,不然會有閉包問題,eg.Go 執行的 groutine 會引用 for 循環的 i
        eg.Go(func() error {
            select {
            case <-ctx.Done():
                return ctx.Err()
            case <-time.After(time.Duration(i) * time.Second):
            }
            if i == 2 {
                return fmt.Errorf("i can't be 2") // 需要返回 err 才會導致 eg 的 cancel 方法
            }
            fmt.Println(i)
            return nil
        })
    }

    if err := eg.Wait(); err != nil {
        fmt.Println(err.Error())
    }

    OtherLogic(ctx)
}

func OtherLogic(ctx context.Context) {
    // 這裡的 context 用瞭創建 eg 返回的 context,這個 context 可能會往後面更多的 func 中傳遞
    // 如果在該方法或者後面的 func 中有對 context 監聽取消型號,會導致這些 context 被取消瞭
}

另外不管是 WaitGroup 還是 errgroup.Group 都不支持控制最大並發限制以及 panic 恢復的功能,因為我們不能保障我們通過創建的 groutine 不會出現異常,如果沒有在創建的協程中捕獲異常,會直接導致整個程序退出,這是非常危險的。
這裡推薦一下 bilbil 開源的微服務框架 go-kratos/kratos 自己實現瞭一個改善版本的 errgroup.Group,其實現的的思路是利用 channel 來控制並發,並且創建 errgroup 的時候不會返回 context 避免 context 往非關聯的業務方法中傳遞。

到此這篇關於Golang 標準庫 tips之waitgroup詳解的文章就介紹到這瞭,更多相關Golang waitgroup內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: