changeset 473:5e826e232932

Fixing #224; make sure we block IP's that have tripped the rate limiter or have been manually blocked.
author Brian Neal <bgneal@gmail.com>
date Sat, 27 Aug 2011 04:23:30 +0000
parents 7c3816d76c6c
children 501dfb88035d
files gpp/antispam/decorators.py gpp/antispam/rate_limit.py
diffstat 2 files changed, 77 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/gpp/antispam/decorators.py	Thu Aug 25 02:23:55 2011 +0000
+++ b/gpp/antispam/decorators.py	Sat Aug 27 04:23:30 2011 +0000
@@ -7,7 +7,7 @@
 
 from django.shortcuts import render
 
-from antispam.rate_limit import rate_check
+from antispam.rate_limit import RateLimiter, RateLimiterUnavailable
 
 
 def rate_limit(count=10, interval=timedelta(minutes=1),
@@ -18,15 +18,23 @@
         @wraps(fn)
         def wrapped(request, *args, **kwargs):
 
+            ip = request.META.get('REMOTE_ADDR')
+            try:
+                rate_limiter = RateLimiter(ip, count, interval, lockout)
+            except RateLimiterUnavailable:
+                # just call the function and return the result
+                return fn(request, *args, **kwargs)
+
+            if rate_limiter.is_blocked():
+                return render(request, 'antispam/blocked.html', status=403)
+
             response = fn(request, *args, **kwargs)
 
             if request.method == 'POST':
                 success = (response and response.has_header('location') and
                         response.status_code == 302)
-                if not success:
-                    ip = request.META.get('REMOTE_ADDR')
-                    if rate_check(ip, count, interval, lockout):
-                        return render(request, 'antispam/blocked.html', status=403)
+                if not success and rate_limiter.incr():
+                    return render(request, 'antispam/blocked.html', status=403)
 
             return response
 
--- a/gpp/antispam/rate_limit.py	Thu Aug 25 02:23:55 2011 +0000
+++ b/gpp/antispam/rate_limit.py	Sat Aug 27 04:23:30 2011 +0000
@@ -17,6 +17,10 @@
 DB = getattr(settings, 'RATE_LIMIT_REDIS_DB', 0)
 
 
+class RateLimiterUnavailable(Exception):
+    pass
+
+
 def _make_key(ip):
     """
     Creates and returns a key string from a given IP address.
@@ -31,11 +35,11 @@
     """
     try:
         conn = redis.Redis(host=HOST, port=PORT, db=DB)
-        return conn
     except redis.RedisError, e:
         logger.error("rate limit: %s" % e)
+        raise RateLimiterUnavailable
 
-    return None
+    return conn
 
 
 def _to_seconds(interval):
@@ -46,60 +50,74 @@
     return interval.days * 24 * 3600 + interval.seconds
 
 
-def rate_check(ip, set_point, interval, lockout):
-    """
-    This function performs a rate limit check.
-    One is added to a counter associated with the IP address. If the
-    counter exceeds set_point per interval, True is returned, and False
-    otherwise. If the set_point is exceeded for the first time, the counter
-    associated with the IP is set to expire according to the lockout parameter.
-    This locks the IP address as this function will then return True for the
-    period specified by lockout.
-
-    """
-    if not ip:
-        logger.error("rate_limit.rate_check could not get IP")
-        return False
-    key = _make_key(ip)
-
-    conn = _get_connection()
-    if not conn:
-        return False
-
-    val = conn.incr(key)
-
-    # Set expire time, if necessary.
-    # If this is the first time, set it according to interval.
-    # If the set_point has just been exceeded, set it according to lockout.
-    if val == 1:
-        conn.expire(key, _to_seconds(interval))
-    elif val == set_point:
-        conn.expire(key, _to_seconds(lockout))
-
-    tripped = val >= set_point
-
-    if tripped:
-        logger.info("Rate limiter tripped for %s; counter = %d", ip, val)
-
-    return tripped
-
-
 def block_ip(ip, count=1000000, interval=datetime.timedelta(weeks=2)):
     """
     This function jams the rate limit record for the given IP so that the IP is
     blocked for the given interval. If the record doesn't exist, it is created.
     This is useful for manually blocking an IP after detecting suspicious
     behavior.
+    This function may throw RateLimiterUnavailable.
 
     """
-    if not ip:
-        logger.error("rate_limit.block_ip could not get IP")
-        return
-
     key = _make_key(ip)
     conn = _get_connection()
-    if not conn:
-        return
 
     conn.setex(key, count, _to_seconds(interval))
     logger.info("Rate limiter blocked IP %s; %d / %s", ip, count, interval)
+
+
+class RateLimiter(object):
+    """
+    This class encapsulates the rate limiting logic for a given IP address.
+
+    """
+    def __init__(self, ip, set_point, interval, lockout):
+        self.ip = ip
+        self.set_point = set_point
+        self.interval = interval
+        self.lockout = lockout
+        self.key = _make_key(ip)
+        self.conn = _get_connection()
+
+    def is_blocked(self):
+        """
+        Return True if the IP is blocked, and false otherwise.
+
+        """
+        val = self.conn.get(self.key)
+        try:
+            val = int(val) if val else 0
+        except ValueError:
+            return False
+
+        blocked = val >= self.set_point
+        if blocked:
+            logger.info("Rate limiter blocking %s", self.ip)
+
+        return blocked
+
+    def incr(self):
+        """
+        One is added to a counter associated with the IP address. If the
+        counter exceeds set_point per interval, True is returned, and False
+        otherwise. If the set_point is exceeded for the first time, the counter
+        associated with the IP is set to expire according to the lockout
+        parameter.
+
+        """
+        val = self.conn.incr(self.key)
+
+        # Set expire time, if necessary.
+        # If this is the first time, set it according to interval.
+        # If the set_point has just been exceeded, set it according to lockout.
+        if val == 1:
+            self.conn.expire(self.key, _to_seconds(self.interval))
+        elif val == self.set_point:
+            self.conn.expire(self.key, _to_seconds(self.lockout))
+
+        tripped = val >= self.set_point
+
+        if tripped:
+            logger.info("Rate limiter tripped for %s; counter = %d", self.ip, val)
+
+        return tripped