# HG changeset patch # User Brian Neal # Date 1314419010 0 # Node ID 5e826e2329322ec8e6c9a57e9642221711ad82bb # Parent 7c3816d76c6c1da32d15735e28290b5e6cad3f0b Fixing #224; make sure we block IP's that have tripped the rate limiter or have been manually blocked. diff -r 7c3816d76c6c -r 5e826e232932 gpp/antispam/decorators.py --- a/gpp/antispam/decorators.py Thu Aug 25 02:23:55 2011 +0000 +++ b/gpp/antispam/decorators.py Sat Aug 27 04:23:30 2011 +0000 @@ -7,7 +7,7 @@ from django.shortcuts import render -from antispam.rate_limit import rate_check +from antispam.rate_limit import RateLimiter, RateLimiterUnavailable def rate_limit(count=10, interval=timedelta(minutes=1), @@ -18,15 +18,23 @@ @wraps(fn) def wrapped(request, *args, **kwargs): + ip = request.META.get('REMOTE_ADDR') + try: + rate_limiter = RateLimiter(ip, count, interval, lockout) + except RateLimiterUnavailable: + # just call the function and return the result + return fn(request, *args, **kwargs) + + if rate_limiter.is_blocked(): + return render(request, 'antispam/blocked.html', status=403) + response = fn(request, *args, **kwargs) if request.method == 'POST': success = (response and response.has_header('location') and response.status_code == 302) - if not success: - ip = request.META.get('REMOTE_ADDR') - if rate_check(ip, count, interval, lockout): - return render(request, 'antispam/blocked.html', status=403) + if not success and rate_limiter.incr(): + return render(request, 'antispam/blocked.html', status=403) return response diff -r 7c3816d76c6c -r 5e826e232932 gpp/antispam/rate_limit.py --- a/gpp/antispam/rate_limit.py Thu Aug 25 02:23:55 2011 +0000 +++ b/gpp/antispam/rate_limit.py Sat Aug 27 04:23:30 2011 +0000 @@ -17,6 +17,10 @@ DB = getattr(settings, 'RATE_LIMIT_REDIS_DB', 0) +class RateLimiterUnavailable(Exception): + pass + + def _make_key(ip): """ Creates and returns a key string from a given IP address. @@ -31,11 +35,11 @@ """ try: conn = redis.Redis(host=HOST, port=PORT, db=DB) - return conn except redis.RedisError, e: logger.error("rate limit: %s" % e) + raise RateLimiterUnavailable - return None + return conn def _to_seconds(interval): @@ -46,60 +50,74 @@ return interval.days * 24 * 3600 + interval.seconds -def rate_check(ip, set_point, interval, lockout): - """ - This function performs a rate limit check. - One is added to a counter associated with the IP address. If the - counter exceeds set_point per interval, True is returned, and False - otherwise. If the set_point is exceeded for the first time, the counter - associated with the IP is set to expire according to the lockout parameter. - This locks the IP address as this function will then return True for the - period specified by lockout. - - """ - if not ip: - logger.error("rate_limit.rate_check could not get IP") - return False - key = _make_key(ip) - - conn = _get_connection() - if not conn: - return False - - val = conn.incr(key) - - # Set expire time, if necessary. - # If this is the first time, set it according to interval. - # If the set_point has just been exceeded, set it according to lockout. - if val == 1: - conn.expire(key, _to_seconds(interval)) - elif val == set_point: - conn.expire(key, _to_seconds(lockout)) - - tripped = val >= set_point - - if tripped: - logger.info("Rate limiter tripped for %s; counter = %d", ip, val) - - return tripped - - def block_ip(ip, count=1000000, interval=datetime.timedelta(weeks=2)): """ This function jams the rate limit record for the given IP so that the IP is blocked for the given interval. If the record doesn't exist, it is created. This is useful for manually blocking an IP after detecting suspicious behavior. + This function may throw RateLimiterUnavailable. """ - if not ip: - logger.error("rate_limit.block_ip could not get IP") - return - key = _make_key(ip) conn = _get_connection() - if not conn: - return conn.setex(key, count, _to_seconds(interval)) logger.info("Rate limiter blocked IP %s; %d / %s", ip, count, interval) + + +class RateLimiter(object): + """ + This class encapsulates the rate limiting logic for a given IP address. + + """ + def __init__(self, ip, set_point, interval, lockout): + self.ip = ip + self.set_point = set_point + self.interval = interval + self.lockout = lockout + self.key = _make_key(ip) + self.conn = _get_connection() + + def is_blocked(self): + """ + Return True if the IP is blocked, and false otherwise. + + """ + val = self.conn.get(self.key) + try: + val = int(val) if val else 0 + except ValueError: + return False + + blocked = val >= self.set_point + if blocked: + logger.info("Rate limiter blocking %s", self.ip) + + return blocked + + def incr(self): + """ + One is added to a counter associated with the IP address. If the + counter exceeds set_point per interval, True is returned, and False + otherwise. If the set_point is exceeded for the first time, the counter + associated with the IP is set to expire according to the lockout + parameter. + + """ + val = self.conn.incr(self.key) + + # Set expire time, if necessary. + # If this is the first time, set it according to interval. + # If the set_point has just been exceeded, set it according to lockout. + if val == 1: + self.conn.expire(self.key, _to_seconds(self.interval)) + elif val == self.set_point: + self.conn.expire(self.key, _to_seconds(self.lockout)) + + tripped = val >= self.set_point + + if tripped: + logger.info("Rate limiter tripped for %s; counter = %d", self.ip, val) + + return tripped