Golang的Fork/Join实现代码
互联网速递520 人气:0做过Java开发的同学肯定知道,JDK7加入的Fork/Join是一个非常优秀的设计,到了JDK8,又结合并行流中进行了优化和增强,是一个非常好的工具。
1、Fork/Join是什么
Fork/Join本质上是一种任务分解,即:将一个很大的任务分解成若干个小任务,然后再对小任务进一步分解,直到最小颗粒度,然后并发执行。
这么做的优点很明显,就是可以大幅提升计算性能,缺点嘛,也有一点,那就是资源开销要大一些。
在网上找了一张图,任务分解就是这个意思:
2、Golang中的Fork/Join实现
对于Golang中的Fork/Join的实现,我参考了JDK的源码,利用了Goroutine特性,这样就能充分利用MPG模型,不必自己再处理任务窃取等问题了,用起来还是蛮爽的。
废话不多说,请看代码:
package like_fork_join import ( "fmt" "github.com/oklog/ulid/v2" ) const defaultPageSize = 10 type MyForkJoinTask struct { size int } // NewMyTask 初始化一个任务 func NewMyTask(pageSize int) *MyForkJoinTask { var size = defaultPageSize if pageSize > size { size = pageSize } return &MyForkJoinTask{ size: size, } } // Do 执行任务时,传入一个切片 func (t *MyForkJoinTask) Do(numbers []int) int { JoinCh := make(chan bool, 1) resultCh := make(chan int, 1) t.do(numbers, JoinCh, resultCh, ulid.Make().String()) result := <-resultCh return result } func (t *MyForkJoinTask) do(numbers []int, joinCh chan bool, resultCh chan int, id string) { defer func() { joinCh <- true close(joinCh) close(resultCh) }() fmt.Printf("id %s numbers %+v\n", id, numbers) // 任务小于最小颗粒度时,直接执行逻辑(此处是求和),不再拆分,否则进行分治 if len(numbers) <= t.size { var sum = 0 for _, number := range numbers { sum += number } resultCh <- sum fmt.Printf("id %s numbers %+v, result %+v\n", id, numbers, sum) return } else { start := 0 end := len(numbers) middle := (start + end) / 2 // 左 leftJoinCh := make(chan bool, 1) leftResultCh := make(chan int, 1) leftId := ulid.Make().String() go t.do(numbers[start:middle], leftJoinCh, leftResultCh, id+"->left->"+leftId) // 右 rightJoinCh := make(chan bool, 1) rightResultCh := make(chan int, 1) rightId := ulid.Make().String() go t.do(numbers[middle:], rightJoinCh, rightResultCh, id+"->right->"+rightId) // 等待左边和右边分治子任务结束 var leftDone, rightDone = false, false for { select { case _, ok := <-leftJoinCh: if ok { fmt.Printf("left %s join done\n", leftId) leftDone = true } case _, ok := <-rightJoinCh: if ok { fmt.Printf("right %s join done\n", rightId) rightDone = true } } if leftDone && rightDone { break } } // 取结果 var ( left = 0 right = 0 leftResultDone = false rightResultDone = false ) for { select { case l, ok := <-leftResultCh: if ok { fmt.Printf("id %s numbers %+v, left %s return: %+v\n", id, numbers, leftId, left) left = l leftResultDone = true } case r, ok := <-rightResultCh: if ok { fmt.Printf("id %s numbers %+v, right %s return: %+v\n", id, numbers, rightId, right) right = r rightResultDone = true } } if leftResultDone && rightResultDone { break } } resultCh <- left + right return } }
代码也不复杂,有注释,大家耐心读一下就明白了。
3、测试验证
我写了一个比较有压力的测试用例代码,请看:
package like_fork_join import ( "fmt" "testing" ) func TestMyTask_Do(t1 *testing.T) { type args struct { numbers []int } const max = 10000 var nums = make([]int, 0, max) var want = 0 for i := 1; i <= max; i++ { nums = append(nums, i) want += i } tests := []struct { name string args args want int }{ {name: fmt.Sprintf("sum(1,%d)", max), args: args{numbers: nums}, want: want}, } for _, tt := range tests { t1.Run(tt.name, func(t1 *testing.T) { for i := 0; i <= 100; i += 5 { t := NewMyTask(i) if got := t.Do(tt.args.numbers); got != tt.want { t1.Errorf("Do() = %v, want %v", got, tt.want) } } }) } }
测试成功:
--- PASS: TestMyTask_Do/sum(1,10000) (1257.79s) PASS
4、小优化
删除所有fmt包的控制台输出,再跑单元测试结果:
=== RUN TestMyTask_Do
--- PASS: TestMyTask_Do (60.53s)
=== RUN TestMyTask_Do/sum(1,10000)
--- PASS: TestMyTask_Do/sum(1,10000) (60.53s)
PASS
20万次加法计算,长度为1万的数组的20次计算,60秒搞定,性能巨强,Golang就是棒!
5、后续计划
计划后续再研究研究,看能否把执行任务的逻辑做成泛型和函数闭包,给抽象出来,这样就能单独形成一个通用型的代码包,供外部各种应用程序使用了,不过考虑到goroutine的上下文等问题,估计会让代码比较复杂,眼下这个版本足够简单,也能满足绝大多数场景了。
加载全部内容