package timingwheel
import (
"context"
"sync/atomic"
"time"
"unsafe"
)
const (
	// delayQueueBufferSize is passed by Run to delayQueue.channel;
	// presumably it is the buffer capacity of the returned bucket channel.
	delayQueueBufferSize = 10
)
// TimingWheel is one level of a hierarchical timing wheel. Timers whose
// expiration falls beyond this level's span are handed to a coarser
// overflow wheel, created on demand, which shares the same delay queue.
type TimingWheel struct {
	// tick is the span of a single bucket in milliseconds, wheelSize is the
	// number of buckets, and interval (tick*wheelSize) is the span of the
	// whole level.
	tick, wheelSize, interval int64

	// currentTime is this level's clock in Unix milliseconds, truncated to a
	// multiple of tick. It is accessed only through sync/atomic. Only int64
	// fields precede it, so its offset is 8-byte aligned, which sync/atomic
	// requires for 64-bit words on 32-bit platforms. Keep this field order.
	currentTime int64

	buckets []*bucket
	queue   *delayQueue

	// overflowWheel points to the next, coarser *TimingWheel. It is nil until
	// first needed and is accessed only through sync/atomic.
	overflowWheel unsafe.Pointer
}
// New creates a TimingWheel that starts at the current wall-clock time.
// Each bucket spans tick milliseconds and the wheel has wheelSize buckets.
// Timers expiring beyond tick*wheelSize milliseconds go to overflow wheels,
// which are created lazily.
//
// New panics if tick or wheelSize is not positive. Such values would
// otherwise cause a divide-by-zero in truncate or a panic on the first
// timer insertion.
func New(tick, wheelSize int64) *TimingWheel {
	if tick <= 0 {
		panic("timingwheel: tick must be positive")
	}
	if wheelSize <= 0 {
		panic("timingwheel: wheelSize must be positive")
	}
	return newTimingWheel(tick, wheelSize, time.Now().UnixMilli(), newDelayQueue())
}
// newTimingWheel builds a single wheel level. Its clock starts at currentTime
// truncated to tick, and it has wheelSize freshly allocated buckets. Levels
// of one hierarchy share the given queue.
func newTimingWheel(tick, wheelSize, currentTime int64, queue *delayQueue) *TimingWheel {
	// Evaluate truncate before allocating, as the original literal did.
	start := truncate(currentTime, tick)
	buckets := make([]*bucket, wheelSize)
	for i := range buckets {
		buckets[i] = newBucket()
	}
	return &TimingWheel{
		tick:        tick,
		wheelSize:   wheelSize,
		interval:    tick * wheelSize,
		currentTime: start,
		buckets:     buckets,
		queue:       queue,
	}
}
// Run drives the wheel until ctx is cancelled. It receives buckets from the
// delay queue, advances the wheel clock to each bucket's expiration, and
// flushes the bucket's timers back through addOrRun. That call re-inserts a
// timer into a finer slot, or starts its task once it is due.
//
// Run blocks; call it on its own goroutine.
func (tw *TimingWheel) Run(ctx context.Context) {
	bucketChan := tw.queue.channel(ctx, delayQueueBufferSize, func() int64 {
		return time.Now().UnixMilli()
	})
	for {
		select {
		case b, ok := <-bucketChan:
			if !ok {
				// NOTE(review): guards the case where the queue closes the
				// channel on shutdown. A closed channel yields the zero value
				// and b.expiration would otherwise panic.
				return
			}
			tw.advance(b.expiration)
			b.flush(tw.addOrRun)
		case <-ctx.Done():
			return
		}
	}
}
// AfterFunc arranges for f to be called once delay has elapsed and returns
// the Timer tracking the call. The expiration is recorded in Unix
// milliseconds and is bucketed at the wheel's tick granularity.
//
// A timer already due relative to the wheel clock is started immediately on
// a new goroutine instead of being dropped. This covers zero, negative, and
// sub-tick delays.
func (tw *TimingWheel) AfterFunc(delay time.Duration, f func()) *Timer {
	t := &Timer{
		expiration: time.Now().Add(delay).UnixMilli(),
		task:       f,
	}
	// Use addOrRun, not add: add returns false for due timers without
	// scheduling them, which would silently lose f.
	tw.addOrRun(t)
	return t
}
// Scheduler supplies the run times for ScheduleFunc. Next receives the
// previous run time and returns the next one. A zero time.Time means no
// further runs.
type Scheduler interface {
	Next(prev time.Time) time.Time
}
// ScheduleFunc runs f repeatedly at the times produced by s, starting from
// s.Next(time.Now()). It returns nil when s yields no first run time (a zero
// time.Time).
//
// Each run asks s for the following time and re-queues the timer before
// calling f. Scheduling stops when s returns a zero time.
func (tw *TimingWheel) ScheduleFunc(s Scheduler, f func()) *Timer {
	first := s.Next(time.Now())
	if first.IsZero() {
		return nil
	}

	t := &Timer{expiration: first.UnixMilli()}
	t.task = func() {
		next := s.Next(time.UnixMilli(t.expiration))
		if !next.IsZero() {
			t.expiration = next.UnixMilli()
			tw.addOrRun(t)
		}
		f()
	}

	tw.addOrRun(t)
	return t
}
// add inserts t into this wheel or one of its overflow wheels and reports
// whether it did so. It returns false, without storing t, when t expires
// before the end of the current tick, meaning it is already due.
func (tw *TimingWheel) add(t *Timer) bool {
	now := atomic.LoadInt64(&tw.currentTime)

	switch {
	case t.expiration < now+tw.tick:
		// Already due: the caller decides what to do with it.
		return false

	case t.expiration < now+tw.interval:
		// Fits within this level: place it in the slot for its tick.
		slot := t.expiration / tw.tick
		b := tw.buckets[slot%tw.wheelSize]
		b.add(t)
		// Push the bucket only when its expiration changed. A bucket whose
		// expiration is unchanged does not need to be queued again.
		if b.setExpiration(slot * tw.tick) {
			tw.queue.push(b)
		}
		return true

	default:
		// Beyond this level's span: delegate to the coarser wheel,
		// creating it first if needed.
		next := atomic.LoadPointer(&tw.overflowWheel)
		if next == nil {
			tw.setOverflowWheel(now)
			next = atomic.LoadPointer(&tw.overflowWheel)
		}
		return (*TimingWheel)(next).add(t)
	}
}
// addOrRun inserts t into the wheel. If t is already due, it starts t's
// task on a new goroutine instead.
func (tw *TimingWheel) addOrRun(t *Timer) {
	if added := tw.add(t); added {
		return
	}
	go t.task()
}
// advance moves this wheel's clock forward to expiration, truncated to tick,
// and cascades the new time to the overflow wheel if one exists. Times less
// than one tick past the current clock are ignored.
func (tw *TimingWheel) advance(expiration int64) {
	if expiration < atomic.LoadInt64(&tw.currentTime)+tw.tick {
		return
	}

	now := truncate(expiration, tw.tick)
	atomic.StoreInt64(&tw.currentTime, now)

	if next := atomic.LoadPointer(&tw.overflowWheel); next != nil {
		(*TimingWheel)(next).advance(now)
	}
}
// setOverflowWheel installs a coarser wheel whose tick equals this wheel's
// interval and which shares this wheel's queue. The install happens only if
// no overflow wheel is set yet.
func (tw *TimingWheel) setOverflowWheel(currentTime int64) {
	next := unsafe.Pointer(newTimingWheel(tw.interval, tw.wheelSize, currentTime, tw.queue))
	// Ignoring the CAS result is deliberate. If another goroutine won the
	// race, its wheel is kept and this one is simply discarded.
	_ = atomic.CompareAndSwapPointer(&tw.overflowWheel, nil, next)
}
// truncate drops the remainder of ms modulo tick. For non-negative ms this
// rounds down to a multiple of tick. Because Go's % truncates, negative ms
// rounds toward zero.
func truncate(ms, tick int64) int64 {
	rem := ms % tick
	return ms - rem
}