bgneal@472: """
bgneal@472: This module contains the rate limiting functionality.
bgneal@472: 
bgneal@472: """
bgneal@472: import datetime
bgneal@472: import logging
bgneal@472: 
bgneal@472: import redis
bgneal@472: 
bgneal@508: from core.services import get_redis_connection
bgneal@508: 
bgneal@472: 
bgneal@472: logger = logging.getLogger(__name__)
bgneal@472: 
bgneal@472: 
class RateLimiterUnavailable(Exception):
    """
    Raised by this module whenever a Redis error occurs.

    Client code catches this exception rather than a Redis-specific one, so
    it does not need to know that the rate limiter is backed by Redis. This
    leaves room to switch to a different backend in the future.

    """
bgneal@473: 
bgneal@473: 
def _make_key(ip):
    """
    Return the key under which the rate-limit record for *ip* is stored.

    The key is the IP string with a fixed 'rate-limit-' prefix.

    """
    prefix = 'rate-limit-'
    return prefix + ip
bgneal@472: 
bgneal@472: 
def _get_connection():
    """
    Create and return a Redis connection.

    If get_redis_connection() raises a redis.RedisError, the error is logged
    and RateLimiterUnavailable is raised instead. This function never
    returns None.

    """
    try:
        return get_redis_connection()
    except redis.RedisError as e:
        logger.error("rate limit: %s", e)
        raise RateLimiterUnavailable
bgneal@472: 
bgneal@472: 
def _to_seconds(interval):
    """
    Return the whole number of seconds in the timedelta *interval*.

    Only the days and seconds components are used; any microseconds are
    discarded.

    """
    seconds_per_day = 24 * 60 * 60
    return interval.seconds + seconds_per_day * interval.days
bgneal@472: 
bgneal@472: 
def block_ip(ip, count=1000000, interval=datetime.timedelta(weeks=2)):
    """
    Jam the rate limit record for the given IP so that the IP is blocked for
    the given interval. If the record doesn't exist, it is created.
    This is useful for manually blocking an IP after detecting suspicious
    behavior.

    ip - the IP address string to block
    count - the counter value to store; RateLimiter.is_blocked() treats the
            IP as blocked when this is >= its set_point
    interval - a datetime.timedelta giving how long the block should last

    This function may throw RateLimiterUnavailable.

    """
    key = _make_key(ip)
    conn = _get_connection()

    try:
        # NOTE(review): (name, value, time) is the argument order of the legacy
        # redis-py ``Redis.setex``; ``StrictRedis`` (and ``Redis`` in redis-py
        # >= 3.0) expects (name, time, value). Confirm against the client type
        # returned by get_redis_connection().
        conn.setex(key, count, _to_seconds(interval))
    except redis.RedisError as e:
        logger.error("rate limit (block_ip): %s", e)
        raise RateLimiterUnavailable

    logger.info("Rate limiter blocked IP %s; %d / %s", ip, count, interval)
bgneal@473: 
bgneal@473: 
def unblock_ip(ip):
    """
    Remove the rate limit record for the given IP address, lifting any block.
    Deleting a record that does not exist is not an error.

    ip - the IP address string to unblock

    This function may throw RateLimiterUnavailable.

    """
    key = _make_key(ip)
    conn = _get_connection()

    try:
        conn.delete(key)
    except redis.RedisError as e:
        logger.error("rate limit (unblock_ip): %s", e)
        raise RateLimiterUnavailable

    logger.info("Rate limiter unblocked IP %s", ip)
bgneal@565: 
bgneal@565: 
class RateLimiter(object):
    """
    This class encapsulates the rate limiting logic for a given IP address.

    ip - the IP address being limited
    set_point - counter value at or above which the IP is considered blocked
    interval - datetime.timedelta; lifetime of a freshly created counter
    lockout - datetime.timedelta; lifetime of the counter once it reaches
              set_point

    Construction obtains a connection and may throw RateLimiterUnavailable;
    so may is_blocked() and incr().

    """
    def __init__(self, ip, set_point, interval, lockout):
        self.ip = ip
        self.set_point = set_point
        self.interval = interval
        self.lockout = lockout
        self.key = _make_key(ip)
        self.conn = _get_connection()

    def is_blocked(self):
        """
        Return True if the IP is blocked, and False otherwise.

        The IP is blocked when its stored counter is >= set_point. A missing
        record counts as 0; a stored value that int() cannot parse is treated
        as not blocked.

        """
        try:
            val = self.conn.get(self.key)
        except redis.RedisError as e:
            logger.error("RateLimiter (is_blocked): %s", e)
            raise RateLimiterUnavailable

        try:
            val = int(val) if val else 0
        except ValueError:
            return False

        blocked = val >= self.set_point
        if blocked:
            logger.info("Rate limiter blocking %s", self.ip)

        return blocked

    def incr(self):
        """
        Add one to the counter associated with the IP address.

        Return True if the counter has reached set_point (counter >=
        set_point), and False otherwise.

        When the counter first reaches set_point, it is set to expire after
        the lockout period. Otherwise, a freshly created counter (value 1) is
        set to expire after interval.

        """
        try:
            val = self.conn.incr(self.key)

            # Check set_point before the "first increment" case so that a
            # set_point of 1 still applies the lockout period. If the order
            # were reversed, the counter would trip but expire after only
            # interval.
            if val == self.set_point:
                self.conn.expire(self.key, _to_seconds(self.lockout))
            elif val == 1:
                self.conn.expire(self.key, _to_seconds(self.interval))
        except redis.RedisError as e:
            logger.error("RateLimiter (incr): %s", e)
            raise RateLimiterUnavailable

        tripped = val >= self.set_point
        if tripped:
            logger.info("Rate limiter tripped for %s; counter = %d", self.ip, val)
        return tripped