java簡單手寫版本實現時間輪算法

時間輪

關於時間輪的介紹,網上有很多,這裡就不重復瞭

核心思想

  • 一個環形數組存儲時間輪的所有槽(看你的手表),每個槽對應當前時間輪的最小精度
  • 超過當前時間輪最大表示范圍的會被丟到上層時間輪,上層時間輪的最小精度即為下層時間輪能表達的最大時間(時分秒概念)
  • 每個槽對應一個環形鏈表存儲該時間應該被執行的任務
  • 需要一個線程去驅動指針運轉,獲取到期任務

以下給出java 簡單手寫版本實現

代碼實現

時間輪主數據結構

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 19:31
 */
@Slf4j
public class TimeWheel {
  /**
   * 一個槽的時間間隔(時間輪最小刻度)
   */
  private long tickMs;

  /**
   * 時間輪大小(槽的個數)
   */
  private int wheelSize;

  /**
   * 一輪的時間跨度
   */
  private long interval;

  private long currentTime;

  /**
   * 槽
   */
  private TimerTaskList[] buckets;

  /**
   * 上層時間輪
   */
  private volatile TimeWheel overflowWheel;

  /**
   * 一個timer隻有一個delayqueue
   */
  private DelayQueue<TimerTaskList> delayQueue;

  public TimeWheel(long tickMs, int wheelSize, long currentTime, DelayQueue<TimerTaskList> delayQueue) {
    this.currentTime = currentTime;
    this.tickMs = tickMs;
    this.wheelSize = wheelSize;
    this.interval = tickMs * wheelSize;
    this.buckets = new TimerTaskList[wheelSize];
    this.currentTime = currentTime - (currentTime % tickMs);
    this.delayQueue = delayQueue;
    for (int i = 0; i < wheelSize; i++) {
      buckets[i] = new TimerTaskList();
    }
  }

  public boolean add(TimerTaskEntry entry) {
    long expiration = entry.getExpireMs();
    if (expiration < tickMs + currentTime) {
      //到期瞭
      return false;
    } else if (expiration < currentTime + interval) {
      //扔進當前時間輪的某個槽裡,隻有時間大於某個槽,才會放進去
      long virtualId = (expiration / tickMs);
      int index = (int) (virtualId % wheelSize);
      TimerTaskList bucket = buckets[index];
      bucket.addTask(entry);
      //設置bucket 過期時間
      if (bucket.setExpiration(virtualId * tickMs)) {
        //設好過期時間的bucket需要入隊
        delayQueue.offer(bucket);
        return true;
      }
    } else {
      //當前輪不能滿足,需要扔到上一輪
      TimeWheel timeWheel = getOverflowWheel();
      return timeWheel.add(entry);
    }
    return false;
  }


  private TimeWheel getOverflowWheel() {
    if (overflowWheel == null) {
      synchronized (this) {
        if (overflowWheel == null) {
          overflowWheel = new TimeWheel(interval, wheelSize, currentTime, delayQueue);
        }
      }
    }
    return overflowWheel;
  }

  /**
   * 推進指針
   *
   * @param timestamp
   */
  public void advanceLock(long timestamp) {
    if (timestamp > currentTime + tickMs) {
      currentTime = timestamp - (timestamp % tickMs);
      if (overflowWheel != null) {
        this.getOverflowWheel().advanceLock(timestamp);
      }
    }
  }
}

定時器接口

/**
 * 定時器
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 20:30
 */
public interface Timer {

  /**
   * 添加一個新任務
   *
   * @param timerTask
   */
  void add(TimerTask timerTask);


  /**
   * 推動指針
   *
   * @param timeout
   */
  void advanceClock(long timeout);

  /**
   * 等待執行的任務
   *
   * @return
   */
  int size();

  /**
   * 關閉服務,剩下的無法被執行
   */
  void shutdown();
}

定時器實現

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 20:33
 */
@Slf4j
public class SystemTimer implements Timer {
  /**
   * 底層時間輪
   */
  private TimeWheel timeWheel;
  /**
   * 一個Timer隻有一個延時隊列
   */
  private DelayQueue<TimerTaskList> delayQueue = new DelayQueue<>();
  /**
   * 過期任務執行線程
   */
  private ExecutorService workerThreadPool;
  /**
   * 輪詢delayQueue獲取過期任務線程
   */
  private ExecutorService bossThreadPool;


  public SystemTimer() {
    this.timeWheel = new TimeWheel(1, 20, System.currentTimeMillis(), delayQueue);
    this.workerThreadPool = Executors.newFixedThreadPool(100);
    this.bossThreadPool = Executors.newFixedThreadPool(1);
    //20ms推動一次時間輪運轉
    this.bossThreadPool.submit(() -> {
      for (; ; ) {
        this.advanceClock(20);
      }
    });
  }


  public void addTimerTaskEntry(TimerTaskEntry entry) {
    if (!timeWheel.add(entry)) {
      //已經過期瞭
      TimerTask timerTask = entry.getTimerTask();
      log.info("=====任務:{} 已到期,準備執行============",timerTask.getDesc());
      workerThreadPool.submit(timerTask);
    }
  }

  @Override
  public void add(TimerTask timerTask) {
    log.info("=======添加任務開始====task:{}", timerTask.getDesc());
    TimerTaskEntry entry = new TimerTaskEntry(timerTask, timerTask.getDelayMs() + System.currentTimeMillis());
    timerTask.setTimerTaskEntry(entry);
    addTimerTaskEntry(entry);
  }

  /**
   * 推動指針運轉獲取過期任務
   *
   * @param timeout 時間間隔
   * @return
   */
  @Override
  public synchronized void advanceClock(long timeout) {
    try {
      TimerTaskList bucket = delayQueue.poll(timeout, TimeUnit.MILLISECONDS);
      if (bucket != null) {
        //推進時間
        timeWheel.advanceLock(bucket.getExpiration());
        //執行過期任務(包含降級)
        bucket.clear(this::addTimerTaskEntry);
      }
    } catch (InterruptedException e) {
      log.error("advanceClock error");
    }
  }

  @Override
  public int size() {
    //todo
    return 0;
  }

  @Override
  public void shutdown() {
    this.bossThreadPool.shutdown();
    this.workerThreadPool.shutdown();
    this.timeWheel = null;
  }
}

存儲任務的環形鏈表

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 19:26
 */
@Data
@Slf4j
class TimerTaskList implements Delayed {
  /**
   * TimerTaskList 環形鏈表使用一個虛擬根節點root
   */
  private TimerTaskEntry root = new TimerTaskEntry(null, -1);

  {
    root.next = root;
    root.prev = root;
  }

  /**
   * bucket的過期時間
   */
  private AtomicLong expiration = new AtomicLong(-1L);

  public long getExpiration() {
    return expiration.get();
  }

  /**
   * 設置bucket的過期時間,設置成功返回true
   *
   * @param expirationMs
   * @return
   */
  boolean setExpiration(long expirationMs) {
    return expiration.getAndSet(expirationMs) != expirationMs;
  }

  public boolean addTask(TimerTaskEntry entry) {
    boolean done = false;
    while (!done) {
      //如果TimerTaskEntry已經在別的list中就先移除,同步代碼塊外面移除,避免死鎖,一直到成功為止
      entry.remove();
      synchronized (this) {
        if (entry.timedTaskList == null) {
          //加到鏈表的末尾
          entry.timedTaskList = this;
          TimerTaskEntry tail = root.prev;
          entry.prev = tail;
          entry.next = root;
          tail.next = entry;
          root.prev = entry;
          done = true;
        }
      }
    }
    return true;
  }

  /**
   * 從 TimedTaskList 移除指定的 timerTaskEntry
   *
   * @param entry
   */
  public void remove(TimerTaskEntry entry) {
    synchronized (this) {
      if (entry.getTimedTaskList().equals(this)) {
        entry.next.prev = entry.prev;
        entry.prev.next = entry.next;
        entry.next = null;
        entry.prev = null;
        entry.timedTaskList = null;
      }
    }
  }

  /**
   * 移除所有
   */
  public synchronized void clear(Consumer<TimerTaskEntry> entry) {
    TimerTaskEntry head = root.next;
    while (!head.equals(root)) {
      remove(head);
      entry.accept(head);
      head = root.next;
    }
    expiration.set(-1L);
  }

  @Override
  public long getDelay(TimeUnit unit) {
    return Math.max(0, unit.convert(expiration.get() - System.currentTimeMillis(), TimeUnit.MILLISECONDS));
  }

  @Override
  public int compareTo(Delayed o) {
    if (o instanceof TimerTaskList) {
      return Long.compare(expiration.get(), ((TimerTaskList) o).expiration.get());
    }
    return 0;
  }
}

存儲任務的容器entry

/**
 * @author apdoer
 * @version 1.0
 * @date 2021/3/22 19:26
 */
@Data
class TimerTaskEntry implements Comparable<TimerTaskEntry> {
  private TimerTask timerTask;
  private long expireMs;
  volatile TimerTaskList timedTaskList;
  TimerTaskEntry next;
  TimerTaskEntry prev;

  public TimerTaskEntry(TimerTask timedTask, long expireMs) {
    this.timerTask = timedTask;
    this.expireMs = expireMs;
    this.next = null;
    this.prev = null;
  }

  void remove() {
    TimerTaskList currentList = timedTaskList;
    while (currentList != null) {
      currentList.remove(this);
      currentList = timedTaskList;
    }
  }

  @Override
  public int compareTo(TimerTaskEntry o) {
    return ((int) (this.expireMs - o.expireMs));
  }
}

任務包裝類(這裡也可以將工作任務以線程變量的方式去傳入)

@Data
@Slf4j
class TimerTask implements Runnable {
  /**
   * 延時時間
   */
  private long delayMs;
  /**
   * 任務所在的entry
   */
  private TimerTaskEntry timerTaskEntry;

  private String desc;

  public TimerTask(String desc, long delayMs) {
    this.desc = desc;
    this.delayMs = delayMs;
    this.timerTaskEntry = null;
  }

  public synchronized void setTimerTaskEntry(TimerTaskEntry entry) {
    // 如果這個timetask已經被一個已存在的TimerTaskEntry持有,先移除一個
    if (timerTaskEntry != null && timerTaskEntry != entry) {
      timerTaskEntry.remove();
    }
    timerTaskEntry = entry;
  }

  public TimerTaskEntry getTimerTaskEntry() {
    return timerTaskEntry;
  }

  @Override
  public void run() {
    log.info("============={}任務執行", desc);
  }
}

以上就是本文的全部內容,希望對大傢的學習有所幫助,也希望大傢多多支持WalkonNet。