diff --git a/src/types.rs b/src/types.rs index deab6c3f..1a77c65c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -279,7 +279,7 @@ impl RateLimit { last_tick: std::time::Instant::now(), timer: crate::timer::PosixTimer::new_with_signal( std::time::Duration::from_secs(0), - std::time::Duration::from_secs(1), + std::time::Duration::from_millis(millis), nix::sys::signal::Signal::SIGALRM, ) .unwrap(), @@ -298,16 +298,14 @@ impl RateLimit { pub fn tick(&mut self) -> bool { let now = std::time::Instant::now(); - self.last_tick += self.rate; - if self.last_tick < now { - self.last_tick = now + self.rate; - } else if self.last_tick > now + self.millis { + if self.last_tick + self.rate > now { + self.active = false; + } else { self.timer.rearm(); + self.last_tick = now; self.active = true; - return false; } - self.active = false; - true + self.active } #[inline(always)] @@ -315,6 +313,121 @@ impl RateLimit { self.timer.si_value } } +#[test] +fn test_rate_limit() { + use std::sync::{Arc, Condvar, Mutex}; + /* RateLimit sends a SIGALRM with its timer value in siginfo_t. */ + let pair = Arc::new((Mutex::new(None), Condvar::new())); + let pair2 = pair.clone(); + + /* self-pipe trick: + * since we can only use signal-safe functions in the signal handler, make a pipe and + * write one byte to it from the handler. Counting the number of bytes in the pipe can tell + * us how many times the handler was called */ + let (alarm_pipe_r, alarm_pipe_w) = nix::unistd::pipe().unwrap(); + nix::fcntl::fcntl( + alarm_pipe_r, + nix::fcntl::FcntlArg::F_SETFL(nix::fcntl::OFlag::O_NONBLOCK), + ) + .expect("Could not set pipe to NONBLOCK?"); + + let alarm_handler = move |info: &nix::libc::siginfo_t| { + let value = unsafe { info.si_value().sival_ptr as u8 }; + let (lock, cvar) = &*pair2; + let mut started = lock.lock().unwrap(); + /* set mutex to timer value */ + *started = Some(value); + /* notify condvar in order to wake up the test thread */ + cvar.notify_all(); + nix::unistd::write(alarm_pipe_w, &[value]).expect("Could not write inside alarm handler?"); + }; + unsafe { + signal_hook_registry::register_sigaction(signal_hook::SIGALRM, alarm_handler).unwrap(); + } + /* Accept at most one request per 3 milliseconds */ + let mut rt = RateLimit::new(1, 3); + std::thread::sleep(std::time::Duration::from_millis(2000)); + let (lock, cvar) = &*pair; + let mut started = lock.lock().unwrap(); + let result = cvar + .wait_timeout(started, std::time::Duration::from_millis(100)) + .unwrap(); + /* assert that the handler was called with rt's timer id */ + assert_eq!(*result.0, Some(rt.id())); + drop(result); + drop(pair); + + let mut buf = [0; 1]; + nix::unistd::read(alarm_pipe_r, buf.as_mut()).expect("Could not read from self-pipe?"); + /* assert that only one request per 3 milliseconds is accepted */ + for _ in 0..5 { + assert!(rt.tick()); + std::thread::sleep(std::time::Duration::from_millis(1)); + assert!(!rt.tick()); + std::thread::sleep(std::time::Duration::from_millis(1)); + assert!(!rt.tick()); + std::thread::sleep(std::time::Duration::from_millis(1)); + /* How many times was the signal handler called? We've slept for at least 3 + * milliseconds, so it should have been called once */ + let mut ctr = 0; + while nix::unistd::read(alarm_pipe_r, buf.as_mut()) + .map(|s| s > 0) + .unwrap_or(false) + { + ctr += 1; + } + assert_eq!(ctr, 1); + } + /* next, test at most 100 requests per second */ + let mut rt = RateLimit::new(100, 1000); + for _ in 0..5 { + let mut ctr = 0; + for _ in 0..500 { + if rt.tick() { + ctr += 1; + } + std::thread::sleep(std::time::Duration::from_millis(2)); + } + /* around 100 requests should succeed. might be 99 if in first loop, since + * RateLimit::new() has a delay */ + assert!(ctr > 97 && ctr < 103); + /* alarm should expire in 1 second */ + std::thread::sleep(std::time::Duration::from_millis(1000)); + /* How many times was the signal handler called? */ + ctr = 0; + while nix::unistd::read(alarm_pipe_r, buf.as_mut()) + .map(|s| s > 0) + .unwrap_or(false) + { + ctr += 1; + } + assert_eq!(ctr, 1); + } + /* next, test at most 500 requests per second */ + let mut rt = RateLimit::new(500, 1000); + for _ in 0..5 { + let mut ctr = 0; + for _ in 0..500 { + if rt.tick() { + ctr += 1; + } + std::thread::sleep(std::time::Duration::from_millis(2)); + } + /* all requests should succeed. */ + assert!(ctr < 503 && ctr > 497); + /* alarm should expire in 1 second */ + std::thread::sleep(std::time::Duration::from_millis(1000)); + /* How many times was the signal handler called? */ + ctr = 0; + while nix::unistd::read(alarm_pipe_r, buf.as_mut()) + .map(|s| s > 0) + .unwrap_or(false) + { + ctr += 1; + } + assert_eq!(ctr, 1); + } +} #[derive(Debug)] pub enum ContactEvent {