go 分布式限流器
CocoAdapter 人气:0项目中需要对 api 的接口进行限流,但是麻烦的是,api 可能有多个节点,传统的本地限流无法处理这个问题。限流的算法有很多,比如计数器法,漏斗法,令牌桶法,等等。各有利弊,相关博文网上很多,这里不再赘述。
项目的要求主要有以下几点:
- 支持本地/分布式限流,接口统一
- 支持多种限流算法的切换
- 方便配置,配置方式不确定
go 语言不是很支持 OOP,我在实现的时候是按 Java 的思路走的,所以看起来有点不伦不类,希望能抛砖引玉。
1. 接口定义
package ratelimit import "time" // 限流器接口 type Limiter interface { Acquire() error TryAcquire() bool } // 限流定义接口 type Limit interface { Name() string Key() string Period() time.Duration Count() int32 LimitType() LimitType } // 支持 burst type BurstLimit interface { Limit BurstCount() int32 } // 分布式定义的 burst type DistLimit interface { Limit ClusterNum() int32 } type LimitType int32 const ( CUSTOM LimitType = iota IP )
Limiter 接口参考了 Google 的 guava 包里的 Limiter 实现。Acquire 接口是阻塞接口,其实还需要加上 context 来保证调用链安全,因为实际项目中并没有用到 Acquire 接口,所以没有实现完善;同理,超时时间的支持也可以通过添加新接口继承自 Limiter 接口来实现。TryAcquire 会立即返回。
Limit 抽象了一个限流定义,Key() 方法返回这个 Limit 的唯一标识,Name() 仅作辅助,Period() 表示周期,单位是秒,Count() 表示周期内的最大次数,LimitType()表示根据什么来做区分,如 IP,默认是 CUSTOM.
BurstLimit 提供突发的能力,一般是配合令牌桶算法。DistLimit 新增 ClusterNum() 方法,因为 mentor 要求分布式遇到错误的时候,需要退化为单机版本,退化的策略即是:2 节点总共 100QPS,如果出现分区,每个节点需要调整为各 50QPS
2. LocalCounterLimiter
package ratelimit import ( "errors" "fmt" "math" "sync" "sync/atomic" "time" ) // todo timer 需要 stop type localCounterLimiter struct { limit Limit limitCount int32 // 内部使用,对 limit.count 做了 <0 时的转换 ticker *time.Ticker quit chan bool lock sync.Mutex newTerm *sync.Cond count int32 } func (lim *localCounterLimiter) init() { lim.newTerm = sync.NewCond(&lim.lock) lim.limitCount = lim.limit.Count() if lim.limitCount < 0 { lim.limitCount = math.MaxInt32 // count 永远不会大于 limitCount,后面的写法保证溢出也没问题 } else if lim.limitCount == 0 { // 禁止访问, 会无限阻塞 } else { lim.ticker = time.NewTicker(lim.limit.Period()) lim.quit = make(chan bool, 1) go func() { for { select { case <- lim.ticker.C: fmt.Println("ticker .") atomic.StoreInt32(&lim.count, 0) lim.newTerm.Broadcast() //lim.newTerm.L.Unlock() case <- lim.quit: fmt.Println("work well .") lim.ticker.Stop() return } } }() } } // todo 需要机制来防止无限阻塞, 不超时也应该有个极限时间 func (lim *localCounterLimiter) Acquire() error { if lim.limitCount == 0 { return errors.New("rate limit is 0, infinity wait") } lim.newTerm.L.Lock() for lim.count >= lim.limitCount { // block instead of spinning lim.newTerm.Wait() //fmt.Println(count, lim.limitCount) } lim.count++ lim.newTerm.L.Unlock() return nil } func (lim *localCounterLimiter) TryAcquire() bool { count := atomic.AddInt32(&lim.count, 1) if count > lim.limitCount { return false } else { return true } }
代码很简单,就不多说了
3. LocalTokenBucketLimiter
golang 的官方库里提供了一个 ratelimiter,就是采用令牌桶的算法。所以这里并没有重复造轮子,直接代理了 ratelimiter。
package ratelimit import ( "context" "golang.org/x/time/rate" "math" ) type localTokenBucketLimiter struct { limit Limit limiter *rate.Limiter // 直接复用令牌桶的 } func (lim *localTokenBucketLimiter) init() { burstCount := lim.limit.Count() if burstLimit, ok := lim.limit.(BurstLimit); ok { burstCount = burstLimit.BurstCount() } count := lim.limit.Count() if count < 0 { count = math.MaxInt32 } f := float64(count) / lim.limit.Period().Seconds() if f < 0 { f = float64(rate.Inf) // 无限 } else if f == 0 { panic("为 0 的时候,底层实现有问题") } lim.limiter = rate.NewLimiter(rate.Limit(f), int(burstCount)) } func (lim *localTokenBucketLimiter) Acquire() error { err := lim.limiter.Wait(context.TODO()) return err } func (lim *localTokenBucketLimiter) TryAcquire() bool { return lim.limiter.Allow() }
4. RedisCounterLimiter
package ratelimit import ( "math" "sync" "xg-go/log" "xg-go/xg/common" ) type redisCounterLimiter struct { limit DistLimit limitCount int32 // 内部使用,对 limit.count 做了 <0 时的转换 redisClient *common.RedisClient once sync.Once // 退化为本地计数器的时候使用 localLim Limiter //script string } func (lim *redisCounterLimiter) init() { lim.limitCount = lim.limit.Count() if lim.limitCount < 0 { lim.limitCount = math.MaxInt32 } //lim.script = buildScript() } //func buildScript() string { // sb := strings.Builder{} // // sb.WriteString("local c") // sb.WriteString("\nc = redis.call('get',KEYS[1])") // // 调用不超过最大值,则直接返回 // sb.WriteString("\nif c and tonumber(c) > tonumber(ARGV[1]) then") // sb.WriteString("\nreturn c;") // sb.WriteString("\nend") // // 执行计算器自加 // sb.WriteString("\nc = redis.call('incr',KEYS[1])") // sb.WriteString("\nif tonumber(c) == 1 then") // sb.WriteString("\nredis.call('expire',KEYS[1],ARGV[2])") // sb.WriteString("\nend") // sb.WriteString("\nif tonumber(c) == 1 then") // sb.WriteString("\nreturn c;") // // return sb.String() //} func (lim *redisCounterLimiter) Acquire() error { panic("implement me") } func (lim *redisCounterLimiter) TryAcquire() (success bool) { defer func() { // 一般是 redis 连接断了,会触发空指针 if err := recover(); err != nil { //log.Errorw("TryAcquire err", common.ERR, err) //success = lim.degradeTryAcquire() //return success = true } // 没有错误,判断是否开启了 local 如果开启了,把它停掉 //if lim.localLim != nil { // // stop 线程安全 // lim.localLim.Stop() //} }() count, err := lim.redisClient.IncrBy(lim.limit.Key(), 1) //panic("模拟 redis 出错") if err != nil { log.Errorw("TryAcquire err", common.ERR, err) panic(err) } // *2 是为了保留久一点,便于观察 err = lim.redisClient.Expire(lim.limit.Key(), int(2 * lim.limit.Period().Seconds())) if err != nil { log.Errorw("TryAcquire error", common.ERR, err) panic(err) } // 业务正确的情况下 确认超限 if int32(count) > lim.limitCount { return false } return true //keys := []string{lim.limit.Key()} // //log.Errorw("TryAcquire ", keys, lim.limit.Count(), lim.limit.Period().Seconds()) //count, err := lim.redisClient.Eval(lim.script, keys, lim.limit.Count(), lim.limit.Period().Seconds()) //if err != nil { // log.Errorw("TryAcquire error", common.ERR, err) // return false //} // // //typeName := reflect.TypeOf(count).Name() //log.Errorw(typeName) // //if count != nil && count.(int32) <= lim.limitCount { // // return true //} //return false } func (lim *redisCounterLimiter) Stop() { // 判断是否开启了 local 如果开启了,把它停掉 if lim.localLim != nil { // stop 线程安全 lim.localLim.Stop() } } func (lim *redisCounterLimiter) degradeTryAcquire() bool { lim.once.Do(func() { count := lim.limit.Count() / lim.limit.ClusterNum() limit := LocalLimit { name: lim.limit.Name(), key: lim.limit.Key(), count: count, period: lim.limit.Period(), limitType: lim.limit.LimitType(), } lim.localLim = NewLimiter(&limit) }) return lim.localLim.TryAcquire() }
代码里回退的部分注释了,因为线上为了稳定,实习生的代码毕竟,所以先不跑。
本来原有的思路是直接用 lua 脚本在 redis 上保证原子操作,但是底层封装的库对于直接调 eval 跑的时候,会抛错,而且 source 是 go-redis 里面,赶 ddl 没有时间去 debug,所以只能用 incrBy + expire 分开来。
5. RedisTokenBucketLimiter
令牌桶的状态变量得放在一个 线程安全/一致 的地方,redis 是不二人选。但是令牌桶的算法核心是个延迟计算得到令牌数量,这个是一个很长的临界区,所以要么用分布式锁,要么直接利用 redis 的单线程以原子方式跑。一般业界是后者,即 lua 脚本维护令牌桶的状态变量、计算令牌。代码类似这种
local tokens_key = KEYS[1] local timestamp_key = KEYS[2] --redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key) local rate = tonumber(ARGV[1]) local capacity = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) local requested = tonumber(ARGV[4]) local intval = tonumber(ARGV[5]) local fill_time = capacity/rate local ttl = math.floor(fill_time*2) * intval local last_tokens = tonumber(redis.call("get", tokens_key)) if last_tokens == nil then last_tokens = capacity end local last_refreshed = tonumber(redis.call("get", timestamp_key)) if last_refreshed == nil then last_refreshed = 0 end local delta = math.max(0, now-last_refreshed) local filled_tokens = math.min(capacity, last_tokens+(delta*rate)) local allowed = filled_tokens >= requested local new_tokens = filled_tokens if allowed then new_tokens = filled_tokens - requested end redis.call("setex", tokens_key, ttl, new_tokens) redis.call("setex", timestamp_key, ttl, now) return { allowed, new_tokens }
加载全部内容