Mercurial > public > sg101
comparison gpp/antispam/rate_limit.py @ 473:5e826e232932
Fixing #224; make sure we block IPs that have tripped the rate limiter or have been manually blocked.
author | Brian Neal <bgneal@gmail.com> |
---|---|
date | Sat, 27 Aug 2011 04:23:30 +0000 |
parents | 7c3816d76c6c |
children | 32cec6cd8808 |
comparison
equal
deleted
inserted
replaced
472:7c3816d76c6c | 473:5e826e232932 |
---|---|
15 HOST = getattr(settings, 'RATE_LIMIT_REDIS_HOST', 'localhost') | 15 HOST = getattr(settings, 'RATE_LIMIT_REDIS_HOST', 'localhost') |
16 PORT = getattr(settings, 'RATE_LIMIT_REDIS_PORT', 6379) | 16 PORT = getattr(settings, 'RATE_LIMIT_REDIS_PORT', 6379) |
17 DB = getattr(settings, 'RATE_LIMIT_REDIS_DB', 0) | 17 DB = getattr(settings, 'RATE_LIMIT_REDIS_DB', 0) |
18 | 18 |
19 | 19 |
20 class RateLimiterUnavailable(Exception): | |
21 pass | |
22 | |
23 | |
20 def _make_key(ip): | 24 def _make_key(ip): |
21 """ | 25 """ |
22 Creates and returns a key string from a given IP address. | 26 Creates and returns a key string from a given IP address. |
23 | 27 |
24 """ | 28 """ |
29 """ | 33 """ |
30 Create and return a Redis connection. Returns None on failure. | 34 Create and return a Redis connection. Returns None on failure. |
31 """ | 35 """ |
32 try: | 36 try: |
33 conn = redis.Redis(host=HOST, port=PORT, db=DB) | 37 conn = redis.Redis(host=HOST, port=PORT, db=DB) |
34 return conn | |
35 except redis.RedisError, e: | 38 except redis.RedisError, e: |
36 logger.error("rate limit: %s" % e) | 39 logger.error("rate limit: %s" % e) |
40 raise RateLimiterUnavailable | |
37 | 41 |
38 return None | 42 return conn |
39 | 43 |
40 | 44 |
41 def _to_seconds(interval): | 45 def _to_seconds(interval): |
42 """ | 46 """ |
43 Converts the timedelta interval object into a count of seconds. | 47 Converts the timedelta interval object into a count of seconds. |
44 | 48 |
45 """ | 49 """ |
46 return interval.days * 24 * 3600 + interval.seconds | 50 return interval.days * 24 * 3600 + interval.seconds |
47 | 51 |
48 | 52 |
49 def rate_check(ip, set_point, interval, lockout): | |
50 """ | |
51 This function performs a rate limit check. | |
52 One is added to a counter associated with the IP address. If the | |
53 counter exceeds set_point per interval, True is returned, and False | |
54 otherwise. If the set_point is exceeded for the first time, the counter | |
55 associated with the IP is set to expire according to the lockout parameter. | |
56 This locks the IP address as this function will then return True for the | |
57 period specified by lockout. | |
58 | |
59 """ | |
60 if not ip: | |
61 logger.error("rate_limit.rate_check could not get IP") | |
62 return False | |
63 key = _make_key(ip) | |
64 | |
65 conn = _get_connection() | |
66 if not conn: | |
67 return False | |
68 | |
69 val = conn.incr(key) | |
70 | |
71 # Set expire time, if necessary. | |
72 # If this is the first time, set it according to interval. | |
73 # If the set_point has just been exceeded, set it according to lockout. | |
74 if val == 1: | |
75 conn.expire(key, _to_seconds(interval)) | |
76 elif val == set_point: | |
77 conn.expire(key, _to_seconds(lockout)) | |
78 | |
79 tripped = val >= set_point | |
80 | |
81 if tripped: | |
82 logger.info("Rate limiter tripped for %s; counter = %d", ip, val) | |
83 | |
84 return tripped | |
85 | |
86 | |
87 def block_ip(ip, count=1000000, interval=datetime.timedelta(weeks=2)): | 53 def block_ip(ip, count=1000000, interval=datetime.timedelta(weeks=2)): |
88 """ | 54 """ |
89 This function jams the rate limit record for the given IP so that the IP is | 55 This function jams the rate limit record for the given IP so that the IP is |
90 blocked for the given interval. If the record doesn't exist, it is created. | 56 blocked for the given interval. If the record doesn't exist, it is created. |
91 This is useful for manually blocking an IP after detecting suspicious | 57 This is useful for manually blocking an IP after detecting suspicious |
92 behavior. | 58 behavior. |
59 This function may throw RateLimiterUnavailable. | |
93 | 60 |
94 """ | 61 """ |
95 if not ip: | |
96 logger.error("rate_limit.block_ip could not get IP") | |
97 return | |
98 | |
99 key = _make_key(ip) | 62 key = _make_key(ip) |
100 conn = _get_connection() | 63 conn = _get_connection() |
101 if not conn: | |
102 return | |
103 | 64 |
104 conn.setex(key, count, _to_seconds(interval)) | 65 conn.setex(key, count, _to_seconds(interval)) |
105 logger.info("Rate limiter blocked IP %s; %d / %s", ip, count, interval) | 66 logger.info("Rate limiter blocked IP %s; %d / %s", ip, count, interval) |
67 | |
68 | |
69 class RateLimiter(object): | |
70 """ | |
71 This class encapsulates the rate limiting logic for a given IP address. | |
72 | |
73 """ | |
74 def __init__(self, ip, set_point, interval, lockout): | |
75 self.ip = ip | |
76 self.set_point = set_point | |
77 self.interval = interval | |
78 self.lockout = lockout | |
79 self.key = _make_key(ip) | |
80 self.conn = _get_connection() | |
81 | |
82 def is_blocked(self): | |
83 """ | |
84 Return True if the IP is blocked, and false otherwise. | |
85 | |
86 """ | |
87 val = self.conn.get(self.key) | |
88 try: | |
89 val = int(val) if val else 0 | |
90 except ValueError: | |
91 return False | |
92 | |
93 blocked = val >= self.set_point | |
94 if blocked: | |
95 logger.info("Rate limiter blocking %s", self.ip) | |
96 | |
97 return blocked | |
98 | |
99 def incr(self): | |
100 """ | |
101 One is added to a counter associated with the IP address. If the | |
102 counter exceeds set_point per interval, True is returned, and False | |
103 otherwise. If the set_point is exceeded for the first time, the counter | |
104 associated with the IP is set to expire according to the lockout | |
105 parameter. | |
106 | |
107 """ | |
108 val = self.conn.incr(self.key) | |
109 | |
110 # Set expire time, if necessary. | |
111 # If this is the first time, set it according to interval. | |
112 # If the set_point has just been exceeded, set it according to lockout. | |
113 if val == 1: | |
114 self.conn.expire(self.key, _to_seconds(self.interval)) | |
115 elif val == self.set_point: | |
116 self.conn.expire(self.key, _to_seconds(self.lockout)) | |
117 | |
118 tripped = val >= self.set_point | |
119 | |
120 if tripped: | |
121 logger.info("Rate limiter tripped for %s; counter = %d", self.ip, val) | |
122 | |
123 return tripped |