Golang的Fork/Join實現代碼

做過Java開發的同學肯定知道,JDK7加入的Fork/Join是一個非常優秀的設計,到瞭JDK8,又結合並行流中進行瞭優化和增強,是一個非常好的工具。

1、Fork/Join是什麼

Fork/Join本質上是一種任務分解,即:將一個很大的任務分解成若幹個小任務,然後再對小任務進一步分解,直到最小顆粒度,然後並發執行。

這麼做的優點很明顯,就是可以大幅提升計算性能,缺點嘛,也有一點,那就是資源開銷要大一些。

在網上找瞭一張圖,任務分解就是這個意思:

2、Golang中的Fork/Join實現

對於Golang中的Fork/Join的實現,我參考瞭JDK的源碼,利用瞭Goroutine特性,這樣就能充分利用MPG模型,不必自己再處理任務竊取等問題瞭,用起來還是蠻爽的。

廢話不多說,請看代碼:

package like_fork_join
 
import (
    "fmt"
    "github.com/oklog/ulid/v2"
)
 
const defaultPageSize = 10
 
type MyForkJoinTask struct {
    size int
}
 
// NewMyTask 初始化一個任務
func NewMyTask(pageSize int) *MyForkJoinTask {
    var size = defaultPageSize
    if pageSize > size {
        size = pageSize
    }
    return &MyForkJoinTask{
        size: size,
    }
}
 
// Do 執行任務時,傳入一個切片
func (t *MyForkJoinTask) Do(numbers []int) int {
    JoinCh := make(chan bool, 1)
    resultCh := make(chan int, 1)
    t.do(numbers, JoinCh, resultCh, ulid.Make().String())
    result := <-resultCh
    return result
}
 
func (t *MyForkJoinTask) do(numbers []int, joinCh chan bool, resultCh chan int, id string) {
    defer func() {
        joinCh <- true
        close(joinCh)
        close(resultCh)
    }()
    fmt.Printf("id %s numbers %+v\n", id, numbers)
    // 任務小於最小顆粒度時,直接執行邏輯(此處是求和),不再拆分,否則進行分治
    if len(numbers) <= t.size {
        var sum = 0
        for _, number := range numbers {
            sum += number
        }
        resultCh <- sum
        fmt.Printf("id %s numbers %+v, result %+v\n", id, numbers, sum)
        return
    } else {
        start := 0
        end := len(numbers)
        middle := (start + end) / 2
 
        // 左
        leftJoinCh := make(chan bool, 1)
        leftResultCh := make(chan int, 1)
        leftId := ulid.Make().String()
        go t.do(numbers[start:middle], leftJoinCh, leftResultCh, id+"->left->"+leftId)
 
        // 右
        rightJoinCh := make(chan bool, 1)
        rightResultCh := make(chan int, 1)
        rightId := ulid.Make().String()
        go t.do(numbers[middle:], rightJoinCh, rightResultCh, id+"->right->"+rightId)
 
        // 等待左邊和右邊分治子任務結束
        var leftDone, rightDone = false, false
        for {
            select {
            case _, ok := <-leftJoinCh:
                if ok {
                    fmt.Printf("left %s join done\n", leftId)
                    leftDone = true
                }
            case _, ok := <-rightJoinCh:
                if ok {
                    fmt.Printf("right %s join done\n", rightId)
                    rightDone = true
                }
            }
            if leftDone && rightDone {
                break
            }
        }
 
        // 取結果
        var (
            left            = 0
            right           = 0
            leftResultDone  = false
            rightResultDone = false
        )
        for {
            select {
            case l, ok := <-leftResultCh:
                if ok {
                    fmt.Printf("id %s numbers %+v, left %s return: %+v\n", id, numbers, leftId, left)
                    left = l
                    leftResultDone = true
                }
            case r, ok := <-rightResultCh:
                if ok {
                    fmt.Printf("id %s numbers %+v, right %s return: %+v\n", id, numbers, rightId, right)
                    right = r
                    rightResultDone = true
                }
            }
            if leftResultDone && rightResultDone {
                break
            }
        }
 
        resultCh <- left + right
        return
    }
}

代碼也不復雜,有註釋,大傢耐心讀一下就明白瞭。

3、測試驗證

我寫瞭一個比較有壓力的測試用例代碼,請看:

package like_fork_join
 
import (
    "fmt"
    "testing"
)
 
func TestMyTask_Do(t1 *testing.T) {
    type args struct {
        numbers []int
    }
    const max = 10000
    var nums = make([]int, 0, max)
    var want = 0
    for i := 1; i <= max; i++ {
        nums = append(nums, i)
        want += i
    }
    tests := []struct {
        name string
        args args
        want int
    }{
        {name: fmt.Sprintf("sum(1,%d)", max), args: args{numbers: nums}, want: want},
    }
    for _, tt := range tests {
        t1.Run(tt.name, func(t1 *testing.T) {
            for i := 0; i <= 100; i += 5 {
                t := NewMyTask(i)
                if got := t.Do(tt.args.numbers); got != tt.want {
                    t1.Errorf("Do() = %v, want %v", got, tt.want)
                }
            }
        })
    }
}

測試成功:

    --- PASS: TestMyTask_Do/sum(1,10000) (1257.79s)
PASS

4、小優化

刪除所有fmt包的控制臺輸出,再跑單元測試結果:

=== RUN   TestMyTask_Do
— PASS: TestMyTask_Do (60.53s)
=== RUN   TestMyTask_Do/sum(1,10000)
    — PASS: TestMyTask_Do/sum(1,10000) (60.53s)
PASS

20萬次加法計算,長度為1萬的數組的20次計算,60秒搞定,性能巨強,Golang就是棒!

5、後續計劃

計劃後續再研究研究,看能否把執行任務的邏輯做成泛型和函數閉包,給抽象出來,這樣就能單獨形成一個通用型的代碼包,供外部各種應用程序使用瞭,不過考慮到goroutine的上下文等問題,估計會讓代碼比較復雜,眼下這個版本足夠簡單,也能滿足絕大多數場景瞭。

到此這篇關於Golang的Fork/Join實現的文章就介紹到這瞭,更多相關Golang的Fork/Join實現內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: