Golang errgroup 設計及實現原理解析

開篇

繼上次學習瞭信號量 semaphore 擴展庫的設計思路和實現之後,今天我們繼續來看 golang.org/x/sync 包下的另一個經常被 Golang 開發者使用的大殺器:errgroup。

業務研發中我們經常會遇到需要調用多個下遊的場景,比如加載一個商品的詳情頁,你可能需要訪問商品服務,庫存服務,券服務,用戶服務等,才能從各個數據源獲取到所需要的信息,經過一些鑒權邏輯後,組裝成前端需要的數據格式下發。

串行調用當然可以,但這樣就潛在的給各個調用預設瞭【順序】,必須執行完 A,B,C 之後才能執行 D 操作。但如果我們對於順序並沒有強需求,從語義上看兩個調用是完全獨立可並發的,那麼我們就可以讓他們並發執行。

這個時候就可以使用 errgroup 來解決問題。一定意義上講,errgroup 是基於 WaitGroup 在錯誤傳遞上進行一些優化而提供出來的能力。它不僅可以支持 context.Context 的相關控制能力,還可以將子任務的 error 向上傳遞。

errgroup 源碼拆解

errgroup 定義在 golang.org/x/sync/errgroup,承載核心能力的結構體是 Group。

Group

type token struct{}
// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid, has no limit on the number of active goroutines,
// and does not cancel on error.
type Group struct {
	cancel func()
	wg sync.WaitGroup
	sem chan token
	errOnce sync.Once
	err     error
}

Group 就是對我們上面提到的一堆子任務執行計劃的抽象。每一個子任務都會有自己對應的 goroutine 來執行。

通過這個結構我們也可以看出來,errgroup 底層實現多個 goroutine 調度,等待的能力還是基於 sync.WaitGroup。

WithContext

我們可以調用 errgroup.WithContext 函數來創建一個 Group。

// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
	ctx, cancel := context.WithCancel(ctx)
	return &Group{cancel: cancel}, ctx
}

這裡可以看到,Group 的 cancel 函數本質就是為瞭支持 context 的 cancel 能力,初始化的 Group 隻有一個 cancel 屬性,其他都是默認的。一旦有一個子任務返回錯誤,或者是 Wait 調用返回,這個新 Context 就會被 cancel。

Wait

本質上和 WaitGroup 的 Wait 方法語義一樣,既然我們是個 group task,就需要等待所有任務都執行完,這個語義就由 Wait 方法提供。如果有多個子任務返回錯誤,它隻會返回第一個出現的錯誤,如果所有的子任務都執行成功,就返回 nil。

// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
	g.wg.Wait()
	if g.cancel != nil {
		g.cancel()
	}
	return g.err
}

Wait 的實現非常簡單。一個前置的 WaitGroup Wait,結束後隻做瞭兩件事:

  • 如果對於公共的 Context 有 cancel 函數,就將其 cancel,因為事情辦完瞭;
  • 返回公共的 Group 結構中的 err 給調用方。

Go

Group 的核心能力就在於能夠並發執行多個子任務,從調用者的角度,我們隻需要傳入要執行的函數,簽名為:func() error即可,非常通用。如果任務執行成功,就返回 nil,否則就返回 error,並且會 cancel 那個新的 Context。底層的調度邏輯由 Group 的 Go 方法實現:

// Go calls the given function in a new goroutine.
// It blocks until the new goroutine can be added without the number of
// active goroutines in the group exceeding the configured limit.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
	if g.sem != nil {
		g.sem <- token{}
	}
	g.wg.Add(1)
	go func() {
		defer g.done()
		if err := f(); err != nil {
			g.errOnce.Do(func() {
				g.err = err
				if g.cancel != nil {
					g.cancel()
				}
			})
		}
	}()
}
func (g *Group) done() {
	if g.sem != nil {
		<-g.sem
	}
	g.wg.Done()
}

我們重點來分析下 Go 這裡發生瞭什麼。

WaitGroup 加 1 用作計數;

啟動一個新的 goroutine 執行調用方傳入的 f() 函數;

  • 若 err 為 nil 說明執行正常;
  • 若 err 不為 nil,說明執行出錯,此時將這個返回的 err 賦值給全局 Group 的變量 err,並將 context cancel 掉。註意,這裡的處理在 once 分支中,也就是隻有第一個來的錯誤會被處理。

在 defer 語句中調用 Group 的 done 方法,底層依賴 WaitGroup 的 Done,說明這一個子任務結束。

這裡也可以看到,其實所謂 errgroup,我們並不是將所有子任務的 error 拼成一個字符串返回,而是直接在 Go 方法那裡將第一個錯誤返回,底層依賴瞭 once 的能力。

SetLimit

其實看到這裡,你有沒有覺得 errgroup 的功能有點雞肋?底層核心技術都是靠 WaitGroup 完成的,自己隻不過是起瞭個 goroutine 執行方法,err 還隻能保留一個。而且 Group 裡面的 sem 那個 chan 是用來幹什麼呢?

這一節我們就來看看,Golang 對 errgroup 能力的一次擴充。

到目前為止,errgroup 是可以做到一開始人們對它的期望的,即並發執行子任務。但問題在於,這裡是每一個子任務都開瞭個goroutine,如果是在高並發的環境裡,頻繁創建大量goroutine 這樣的用法很容易對資源負載產生影響。開發者們提出,希望有辦法限制 errgroup 創建的 goroutine 數量,參照這個 proposal: #27837

// SetLimit limits the number of active goroutines in this group to at most n.
// A negative value indicates no limit.
//
// Any subsequent call to the Go method will block until it can add an active
// goroutine without exceeding the configured limit.
//
// The limit must not be modified while any goroutines in the group are active.
func (g *Group) SetLimit(n int) {
	if n < 0 {
		g.sem = nil
		return
	}
	if len(g.sem) != 0 {
		panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
	}
	g.sem = make(chan token, n)
}

SetLimit 的參數 n 就是我們對這個 Group 期望的最大 goroutine 數量,這裡其實除去校驗邏輯,隻幹瞭一件事:g.sem = make(chan token, n),即創建一個長度為 n 的 channel,賦值給 sem。

回憶一下我們在 Go 方法 defer 調用 done 中的表現,是不是清晰瞭很多?我們來理一下:

首先,Group 結構體就包含瞭 sem 變量,隻作為通信,元素是空結構體,不包含實際意義:

type Group struct {
	cancel func()
	wg sync.WaitGroup
	sem chan token
	errOnce sync.Once
	err     error
}

如果你對整個 Group 的 Limit 沒有要求,which is fine,你就直接忽略這個 SetLimit,errgroup 的原有能力就可以支持你的訴求。

但是,如果你希望保持 errgroup 的 goroutine 在一個上限之內,請在調用 Go 方法前,先 SetLimit,這樣 Group 的 sem 就賦值為一個長度為 n 的 channel。

那麼,當你調用 Go 方法時,註意下面的框架代碼,此時 g.sem 不為 nil,所以我們會塞一個 token 進去,作為占坑,語義上表示我申請一個 goroutine 用。

func (g *Group) Go(f func() error) {
	if g.sem != nil {
		g.sem <- token{}
	}
	g.wg.Add(1)
	go func() {
		defer g.done()
                ...
	}()
}

當然,如果此時 goroutine 數量已經達到上限,這裡就會 block 住,直到別的 goroutine 幹完活,sem 輸出瞭一個 token之後,才能繼續往裡面塞。

在每個 goroutine 執行完畢後 defer 的 g.done 方法,則是完成瞭這一點:

func (g *Group) done() {
	if g.sem != nil {
		<-g.sem
	}
	g.wg.Done()
}

這樣就把 sem 的用法串起來瞭。我們通過創建一個定長的channel來實現對於 goroutine 數量的管控,對於channel實際包含的元素並不關心,所以用一個空結構體省一省空間,這是非常優秀的設計,大傢平常也可以參考。

TryGo

TryGo 和 SetLimit 這套體系其實都是不久前歐長坤大佬提交到 errgroup 的能力。

一如既往,所有帶 TryXXX的函數,都不會阻塞。 其實辦的事情非常簡單,和 Go 方法一樣,傳進來一個 func() error來執行。

Go 方法的區別在於,如果判斷 limit 已經不夠瞭,此時不再阻塞,而是直接 return false,代表執行失敗。其他的部分完全一樣。

// TryGo calls the given function in a new goroutine only if the number of
// active goroutines in the group is currently below the configured limit.
//
// The return value reports whether the goroutine was started.
func (g *Group) TryGo(f func() error) bool {
	if g.sem != nil {
		select {
		case g.sem <- token{}:
			// Note: this allows barging iff channels in general allow barging.
		default:
			return false
		}
	}
	g.wg.Add(1)
	go func() {
		defer g.done()
		if err := f(); err != nil {
			g.errOnce.Do(func() {
				g.err = err
				if g.cancel != nil {
					g.cancel()
				}
			})
		}
	}()
	return true
}

使用方法

這裡我們先看一個最常見的用法,針對一組任務,每一個都單獨起 goroutine 執行,最後獲取 Wait 返回的 err 作為整個 Group 的 err。

package main
import (
    "errors"
    "fmt"
    "time"
    "golang.org/x/sync/errgroup"
)
func main() {
    var g errgroup.Group
    // 啟動第一個子任務,它執行成功
    g.Go(func() error {
        time.Sleep(5 * time.Second)
        fmt.Println("exec #1")
        return nil
    })
    // 啟動第二個子任務,它執行失敗
    g.Go(func() error {
        time.Sleep(10 * time.Second)
        fmt.Println("exec #2")
        return errors.New("failed to exec #2")
    })
    // 啟動第三個子任務,它執行成功
    g.Go(func() error {
        time.Sleep(15 * time.Second)
        fmt.Println("exec #3")
        return nil
    })
    // 等待三個任務都完成
    if err := g.Wait(); err == nil {
        fmt.Println("Successfully exec all")
    } else {
        fmt.Println("failed:", err)
    }
}

你會發現,最後 err 打印出來就是第二個子任務的 err。

當然,上面這個 case 是我們正好隻有一個報錯,但是,如果實際有多個任務都掛瞭呢?

從完備性來考慮,有沒有什麼辦法能夠將多個任務的錯誤都返回呢?

這一點其實 errgroup 庫並沒有提供非常好的支持,需要開發者自行做一些改造。因為 Group 中隻有一個 err 變量,我們不可能基於 Group 來實現這一點。

通常情況下,我們會創建一個 slice 來存儲 f() 執行的 err。

package main
import (
    "errors"
    "fmt"
    "time"
    "golang.org/x/sync/errgroup"
)
func main() {
    var g errgroup.Group
    var result = make([]error, 3)
    // 啟動第一個子任務,它執行成功
    g.Go(func() error {
        time.Sleep(5 * time.Second)
        fmt.Println("exec #1")
        result[0] = nil // 保存成功或者失敗的結果
        return nil
    })
    // 啟動第二個子任務,它執行失敗
    g.Go(func() error {
        time.Sleep(10 * time.Second)
        fmt.Println("exec #2")
        result[1] = errors.New("failed to exec #2") // 保存成功或者失敗的結果
        return result[1]
    })
    // 啟動第三個子任務,它執行成功
    g.Go(func() error {
        time.Sleep(15 * time.Second)
        fmt.Println("exec #3")
        result[2] = nil // 保存成功或者失敗的結果
        return nil
    })
    if err := g.Wait(); err == nil {
        fmt.Printf("Successfully exec all. result: %v\n", result)
    } else {
        fmt.Printf("failed: %v\n", result)
    }
}

可以看到,我們聲明瞭一個 result slice,長度為 3。這裡不用擔心並發問題,因為每個 goroutine 讀寫的位置是確定唯一的。

本質上可以理解為,我們把 f() 返回的 err 不僅僅給瞭 Group 一份,還自己存瞭一份,當出錯的時候,Wait 返回的錯誤我們不一定真的用,而是直接用自己錯的這一個 error slice。

Go 官方文檔中的利用 errgroup 實現 pipeline 的示例也很有參考意義:由一個子任務遍歷文件夾下的文件,然後把遍歷出的文件交給 20 個 goroutine,讓這些 goroutine 並行計算文件的 md5。

這裡貼出來簡化代碼學習一下.

package main
import (
	"context"
	"crypto/md5"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"golang.org/x/sync/errgroup"
)
// Pipeline demonstrates the use of a Group to implement a multi-stage
// pipeline: a version of the MD5All function with bounded parallelism from
// https://blog.golang.org/pipelines.
func main() {
	m, err := MD5All(context.Background(), ".")
	if err != nil {
		log.Fatal(err)
	}
	for k, sum := range m {
		fmt.Printf("%s:\t%x\n", k, sum)
	}
}
type result struct {
	path string
	sum  [md5.Size]byte
}
// MD5All reads all the files in the file tree rooted at root and returns a map
// from file path to the MD5 sum of the file's contents. If the directory walk
// fails or any read operation fails, MD5All returns an error.
func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) {
	// ctx is canceled when g.Wait() returns. When this version of MD5All returns
	// - even in case of error! - we know that all of the goroutines have finished
	// and the memory they were using can be garbage-collected.
	g, ctx := errgroup.WithContext(ctx)
	paths := make(chan string)
	g.Go(func() error {
		defer close(paths)
		return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}
			if !info.Mode().IsRegular() {
				return nil
			}
			select {
			case paths <- path:
			case <-ctx.Done():
				return ctx.Err()
			}
			return nil
		})
	})
	// Start a fixed number of goroutines to read and digest files.
	c := make(chan result)
	const numDigesters = 20
	for i := 0; i < numDigesters; i++ {
		g.Go(func() error {
			for path := range paths {
				data, err := ioutil.ReadFile(path)
				if err != nil {
					return err
				}
				select {
				case c <- result{path, md5.Sum(data)}:
				case <-ctx.Done():
					return ctx.Err()
				}
			}
			return nil
		})
	}
	go func() {
		g.Wait()
		close(c)
	}()
	m := make(map[string][md5.Size]byte)
	for r := range c {
		m[r.path] = r.sum
	}
	// Check whether any of the goroutines failed. Since g is accumulating the
	// errors, we don't need to send them (or check for them) in the individual
	// results sent on the channel.
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return m, nil
}

其實本質上還是 channel發揮瞭至關重要的作用,這裡建議大傢有時間盡量看看源文章:pkg.go.dev/golang.org/…

對於用 errgroup 實現 pipeline 模式有很大幫助。

結束語

今天我們學習瞭 errgroup 的源碼已經新增的 SetLimit 能力,其實看過瞭 sync 相關的這些庫,整體用到的能力其實大差不差,很多設計思想都是非常精巧的。比如 errgroup 中利用定長 channel 實現控制 goroutine 數量,空結構體節省內存空間。

並且 sync 包的實現一般都還是非常簡潔的,比如 once,singleflight,semaphore 等。建議大傢有空的話自己過一遍,對並發和設計模式的理解會更上一個臺階。

errgroup 本身並不復雜,業界也有很多封裝實現,大傢可以對照源碼再思考一下還有什麼可以改進的地方。

以上就是Golang errgroup 設計及實現原理解析的詳細內容,更多關於Golang errgroup 設計原理的資料請關註WalkonNet其它相關文章!

推薦閱讀: