1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use std::{
    cmp::Ordering,
    collections::BinaryHeap,
    fmt::{Debug, Formatter},
};

use log::warn;
use tokio::time::{sleep, Duration, Instant};

use crate::sync::LOG_TARGET;

/// A task paired with the instant at which it becomes due.
#[derive(Clone)]
struct ScheduledTask<T> {
    task: T,
    scheduled_time: Instant,
}

// Equality and ordering consider only `scheduled_time`; the `task` payload is ignored,
// so `T` needs no comparison traits. Tasks with identical times compare equal, and
// `BinaryHeap` makes no promise about which of several equal elements it yields first.
impl<T> PartialEq for ScheduledTask<T> {
    fn eq(&self, other: &Self) -> bool {
        self.scheduled_time == other.scheduled_time
    }
}

impl<T> Eq for ScheduledTask<T> {}

impl<T> Ord for ScheduledTask<T> {
    /// Inverts the natural time order so that a max-heap (`BinaryHeap`) surfaces the
    /// earliest `scheduled_time` first.
    fn cmp(&self, other: &Self) -> Ordering {
        self.scheduled_time.cmp(&other.scheduled_time).reverse()
    }
}

impl<T> PartialOrd for ScheduledTask<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// A queue of tasks ordered by the instant at which each one becomes due.
///
/// `Clone` is available when `T: Clone`; `Default` (an empty queue) is available for
/// every `T`.
#[derive(Clone)]
pub struct TaskQueue<T> {
    queue: BinaryHeap<ScheduledTask<T>>,
}

// Written by hand rather than derived: `#[derive(Default)]` would add a `T: Default`
// bound, even though an empty queue needs nothing from `T`.
impl<T> Default for TaskQueue<T> {
    /// Creates an empty queue.
    fn default() -> Self {
        Self {
            queue: BinaryHeap::new(),
        }
    }
}

impl<T> Debug for TaskQueue<T> {
    /// Reports only the number of pending tasks, so `T` does not have to implement `Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let pending = self.queue.len();
        let mut builder = f.debug_struct("TaskQueue");
        builder.field("task count", &pending);
        builder.finish()
    }
}

/// Scheduling API: tasks are queued for a point in the future and handed back once due.
///
/// The queue never runs a task itself; it only orders tasks in time.
impl<T> TaskQueue<T> {
    /// Returns a queue with no scheduled tasks.
    pub fn new() -> Self {
        let queue = BinaryHeap::new();
        Self { queue }
    }

    /// Queues `task` to become due once `delay` has elapsed from now.
    ///
    /// If `now + delay` is not representable as an [`Instant`], a warning is logged and
    /// `task` is discarded instead of being queued.
    pub fn schedule_in(&mut self, task: T, delay: Duration) {
        match Instant::now().checked_add(delay) {
            Some(scheduled_time) => self.queue.push(ScheduledTask {
                task,
                scheduled_time,
            }),
            None => {
                warn!(
                    target: LOG_TARGET,
                    "Could not schedule task in {:?}. Instant out of bound.", delay
                );
            }
        }
    }

    /// Waits until the earliest-scheduled task is due, then removes and returns it.
    /// Returns `None` right away when the queue is empty.
    ///
    /// # Cancel safety
    ///
    /// This method is cancellation safe.
    /// The only suspension point is a sleep that borrows the queue immutably; the removal
    /// happens after it without yielding. So if this future is dropped (for example, when
    /// another branch of a `tokio::select!` wins), the `TaskQueue` is left unchanged.
    pub async fn pop(&mut self) -> Option<T> {
        self.sleep_until_the_next_task_is_ready().await;
        let due = self.queue.pop()?;
        Some(due.task)
    }

    /// Sleeps until the earliest task is due. Returns immediately when the queue is empty
    /// or the earliest task is already overdue.
    /// Takes `&self`, so cancelling it cannot modify the queue.
    async fn sleep_until_the_next_task_is_ready(&self) {
        let remaining = match self.queue.peek() {
            Some(next) => next.scheduled_time.saturating_duration_since(Instant::now()),
            None => return,
        };
        // Skip the timer entirely when nothing is left to wait for.
        if remaining > Duration::ZERO {
            sleep(remaining).await;
        }
    }
}

#[cfg(test)]
mod tests {
    use tokio::time::{timeout, Duration};

    use super::TaskQueue;

    // `test_scheduling` runs on real time. Upper-bound timeouts are deliberately generous:
    // they only bound how long a failing run waits. This keeps scheduler jitter on a
    // loaded machine from causing spurious failures. Lower bounds are checked separately
    // with short timeouts that are expected to expire.

    #[tokio::test]
    async fn test_scheduling() {
        let mut q = TaskQueue::new();
        q.schedule_in(2, Duration::from_millis(50));
        q.schedule_in(1, Duration::from_millis(20));

        // Nothing is due yet, so a short wait must time out (and leave the queue intact).
        assert!(timeout(Duration::from_millis(5), q.pop()).await.is_err());
        // The earlier task is returned first once its delay has elapsed.
        assert_eq!(
            timeout(Duration::from_millis(200), q.pop()).await,
            Ok(Some(1))
        );
        // The later task is still roughly 30ms away.
        assert!(timeout(Duration::from_millis(10), q.pop()).await.is_err());
        assert_eq!(
            timeout(Duration::from_millis(200), q.pop()).await,
            Ok(Some(2))
        );
    }

    #[tokio::test]
    async fn test_pop_on_empty_queue_returns_none() {
        let mut q: TaskQueue<u32> = TaskQueue::new();
        assert_eq!(q.pop().await, None);
    }

    #[tokio::test]
    async fn test_task_with_zero_delay_is_returned_immediately() {
        let mut q = TaskQueue::new();
        q.schedule_in("now", Duration::ZERO);
        assert_eq!(q.pop().await, Some("now"));
        assert_eq!(q.pop().await, None);
    }

    #[tokio::test]
    async fn test_unrepresentable_delay_drops_task() {
        let mut q = TaskQueue::new();
        q.schedule_in(1, Duration::MAX);
        assert_eq!(q.pop().await, None);
    }
}