adding ratelimiting middleware and unit tests

This commit is contained in:
David Goetz 2010-10-04 14:11:48 -07:00
parent 8ed75703fc
commit 72d40bd9f6
6 changed files with 648 additions and 98 deletions

View File

@ -56,3 +56,15 @@ use = egg:swift#memcache
# Default for memcache_servers is below, but you can specify multiple servers # Default for memcache_servers is below, but you can specify multiple servers
# with the format: 10.1.2.3:11211,10.1.2.4:11211 # with the format: 10.1.2.3:11211,10.1.2.4:11211
# memcache_servers = 127.0.0.1:11211 # memcache_servers = 127.0.0.1:11211
[filter:ratelimit]
use = egg:swift#ratelimit
account_ratelimit = 200
account_whitelist = a,b
# account_blacklist =
# with container_limit_x = r
# for containers of size x limit requests per second to r
container_limit_0 = 100
container_limit_10 = 50
container_limit_50 = 10

View File

@ -88,6 +88,7 @@ setup(
'auth=swift.common.middleware.auth:filter_factory', 'auth=swift.common.middleware.auth:filter_factory',
'healthcheck=swift.common.middleware.healthcheck:filter_factory', 'healthcheck=swift.common.middleware.healthcheck:filter_factory',
'memcache=swift.common.middleware.memcache:filter_factory', 'memcache=swift.common.middleware.memcache:filter_factory',
# 'ratelimit=swift.common.middeware.ratelimit:filter_factory',
], ],
}, },
) )

View File

@ -0,0 +1,198 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from webob import Request, Response
from ConfigParser import ConfigParser, NoOptionError
from swift.common.utils import split_path, cache_from_env, get_logger
from swift.proxy.server import get_container_memcache_key
class MaxSleepTimeHit(Exception):
pass
class RateLimitMiddleware(object):
"""
Rate limiting middleware
"""
def __init__(self, app, conf, logger=None):
self.app = app
self.logger = logger
if logger is None:
self.logger = get_logger(conf)
else:
self.logger = logger
self.account_rate_limit = float(conf.get('account_ratelimit', 1))#200.0))
self.max_sleep_time_seconds = int(conf.get('max_sleep_time_seconds',
2))#60))
self.clock_accuracy = int(conf.get('clock_accuracy', 1000))
self.rate_limit_whitelist = [acc.strip() for acc in
conf.get('account_whitelist', '').split(',')
if acc.strip()]
self.rate_limit_blacklist = [acc.strip() for acc in
conf.get('account_blacklist', '').split(',')
if acc.strip()]
self.memcache_client = None
conf_limits = []
for conf_key in conf.keys():
if conf_key.startswith('container_limit_'):
cont_size = int(conf_key[len('container_limit_'):])
rate = float(conf[conf_key])
conf_limits.append((cont_size,rate))
conf_limits.sort()
self.container_limits = []
while conf_limits:
cur_size, cur_rate = conf_limits.pop(0)
if conf_limits:
# figure out slope for function between this point and next
next_size, next_rate = conf_limits[0]
slope = (float(next_rate) - float(cur_rate)) \
/ (next_size - cur_size)
def new_scope(cur_size, slope, cur_rate):
# making new scope for variables
return lambda x: (x - cur_size) * slope + cur_rate
line_func = new_scope(cur_size, slope, cur_rate)
else:
# don't have to worry about scope here- this is the last
# element in the list
line_func = lambda x : cur_rate
self.container_limits.append((cur_size, cur_rate, line_func))
def get_container_maxrate(self, container_size):
"""
Will figure out the max_rate for a container size
"""
last_func = None
if container_size:
for size, rate, func in self.container_limits:
if container_size < size:
break
last_func = func
if last_func:
return last_func(container_size)
return None
def _generate_key_rate_tuples(self, account_name, container_name, obj_name):
"""
Returns a list of keys (to be used in memcache) that can be
generated given a path. Keys should be checked in order.
:param path: path from request
"""
keys = []
if account_name:
keys.append(("ratelimit/%s" % account_name,
self.account_rate_limit))
if account_name and container_name and not obj_name:
container_size = None
memcache_key = get_container_memcache_key(account_name,
container_name)
container_info = self.memcache_client.get(memcache_key)
if type(container_info) == dict:
container_size = container_info.get('container_size')
container_rate = self.get_container_maxrate(container_size)
if container_rate:
keys.append(("ratelimit/%s/%s" % (account_name,
container_name),
container_rate))
return keys
def _get_sleep_time(self, key, max_rate):
now_m = int(round(time.time() * self.clock_accuracy))
time_per_request_m = int(round(self.clock_accuracy / max_rate))
running_time_m = self.memcache_client.incr(key,
delta=time_per_request_m)
need_to_sleep_m = 0
request_time_limit = now_m + (time_per_request_m * max_rate)
if running_time_m < now_m:
next_avail_time = int(now_m + time_per_request_m)
self.memcache_client.set(key, str(next_avail_time),
serialize=False)
elif running_time_m - now_m - time_per_request_m > 0:
#running_time_m > request_time_limit:
need_to_sleep_m = running_time_m - now_m - time_per_request_m
max_sleep_m = self.max_sleep_time_seconds * self.clock_accuracy
if max_sleep_m - need_to_sleep_m <= self.clock_accuracy * 0.01:
# make it accurate to 1% of clock accuracy
# treat as no-op decrement time
self.memcache_client.decr(key, delta=time_per_request_m)
raise MaxSleepTimeHit("Max Sleep Time Exceeded: %s" %
need_to_sleep_m)
return float(need_to_sleep_m) / self.clock_accuracy
def handle_rate_limit(self, req, account_name, container_name, obj_name,
name=None):
if account_name in self.rate_limit_blacklist:
self.logger.error('Returning 497 because of blacklisting')
return Response(status='497 Blacklisted',
body='Your account has been blacklisted', request=req)
if account_name in self.rate_limit_whitelist:
return None
for key, max_rate in self._generate_key_rate_tuples(account_name,
container_name,
obj_name):
try:
need_to_sleep = self._get_sleep_time(key,
max_rate)
if need_to_sleep > 0:
time.sleep(need_to_sleep)
except MaxSleepTimeHit, e:
self.logger.error('Returning 498 because of ops ' + \
'rate limiting (Max Sleep) %s' % e)
error_resp = Response(status='498 Rate Limited',
body='Slow down', request=req)
return error_resp
return None
def __call__(self, env, start_response, name=None):
req = Request(env)
if self.memcache_client is None:
self.memcache_client = cache_from_env(env)
version, account, container, obj = split_path(req.path, 1, 4, True)
rate_limit_resp = self.handle_rate_limit(req, account, container,
obj, name=name)
if rate_limit_resp is None:
return self.app(env, start_response)
else:
return rate_limit_resp(env, start_response)
def filter_factory(global_conf, **local_conf):
conf = global_conf.copy()
conf.update(local_conf)
def limit_filter(app):
return RateLimitMiddleware(app, conf)
return limit_filter

View File

@ -88,6 +88,10 @@ def delay_denial(func):
return func(*a, **kw) return func(*a, **kw)
return wrapped return wrapped
def get_container_memcache_key(account, container):
path = '/%s/%s' % (account, container)
return 'container%s' % path
class Controller(object): class Controller(object):
"""Base WSGI controller class for the proxy""" """Base WSGI controller class for the proxy"""
@ -228,15 +232,22 @@ class Controller(object):
""" """
partition, nodes = self.app.container_ring.get_nodes( partition, nodes = self.app.container_ring.get_nodes(
account, container) account, container)
path = '/%s/%s' % (account, container) path = '/%s/%s' % (account, container)
cache_key = 'container%s' % path cache_key = get_container_memcache_key(account, container)
# Older memcache values (should be treated as if they aren't there): # Older memcache values (should be treated as if they aren't there):
# 0 = no responses, 200 = found, 404 = not found, -1 = mixed responses # 0 = no responses, 200 = found, 404 = not found, -1 = mixed responses
# Newer memcache values: # Newer memcache values:
# [older status value from above, read acl, write acl] # [older status value from above, read acl, write acl]
cache_value = self.app.memcache.get(cache_key) cache_value = self.app.memcache.get(cache_key)
if hasattr(cache_value, '__iter__'): if hasattr(cache_value, '__iter__'):
status, read_acl, write_acl = cache_value if type(cache_value) == dict:
status = cache_value['status']
read_acl = cache_value['read_acl']
write_acl = cache_value['write_acl']
else:
status, read_acl, write_acl = cache_value
if status == 200: if status == 200:
return partition, nodes, read_acl, write_acl return partition, nodes, read_acl, write_acl
if not self.account_info(account)[1]: if not self.account_info(account)[1]:
@ -244,6 +255,7 @@ class Controller(object):
result_code = 0 result_code = 0
read_acl = None read_acl = None
write_acl = None write_acl = None
container_size = None
attempts_left = self.app.container_ring.replica_count attempts_left = self.app.container_ring.replica_count
headers = {'x-cf-trans-id': self.trans_id} headers = {'x-cf-trans-id': self.trans_id}
for node in self.iter_nodes(partition, nodes, self.app.container_ring): for node in self.iter_nodes(partition, nodes, self.app.container_ring):
@ -260,6 +272,8 @@ class Controller(object):
result_code = 200 result_code = 200
read_acl = resp.getheader('x-container-read') read_acl = resp.getheader('x-container-read')
write_acl = resp.getheader('x-container-write') write_acl = resp.getheader('x-container-write')
container_size = \
resp.getheader('X-Container-Object-Count')
break break
elif resp.status == 404: elif resp.status == 404:
result_code = 404 if not result_code else -1 result_code = 404 if not result_code else -1
@ -278,7 +292,10 @@ class Controller(object):
cache_timeout = self.app.recheck_container_existence cache_timeout = self.app.recheck_container_existence
else: else:
cache_timeout = self.app.recheck_container_existence * 0.1 cache_timeout = self.app.recheck_container_existence * 0.1
self.app.memcache.set(cache_key, (result_code, read_acl, write_acl), self.app.memcache.set(cache_key, {'status': result_code,
'read_acl': read_acl,
'write_acl': write_acl,
'container_size': container_size},
timeout=cache_timeout) timeout=cache_timeout)
if result_code == 200: if result_code == 200:
return partition, nodes, read_acl, write_acl return partition, nodes, read_acl, write_acl
@ -941,6 +958,8 @@ class ContainerController(Controller):
statuses.append(503) statuses.append(503)
reasons.append('') reasons.append('')
bodies.append('') bodies.append('')
#TODO : David - does this need to be using the
# get_container_memcache_key function????
self.app.memcache.delete('container%s' % req.path_info.rstrip('/')) self.app.memcache.delete('container%s' % req.path_info.rstrip('/'))
return self.best_response(req, statuses, reasons, bodies, return self.best_response(req, statuses, reasons, bodies,
'Container PUT') 'Container PUT')
@ -1214,14 +1233,6 @@ class BaseApplication(object):
self.account_ring = account_ring or \ self.account_ring = account_ring or \
Ring(os.path.join(swift_dir, 'account.ring.gz')) Ring(os.path.join(swift_dir, 'account.ring.gz'))
self.memcache = memcache self.memcache = memcache
self.rate_limit = float(conf.get('rate_limit', 20000.0))
self.account_rate_limit = float(conf.get('account_rate_limit', 200.0))
self.rate_limit_whitelist = [x.strip() for x in
conf.get('rate_limit_account_whitelist', '').split(',')
if x.strip()]
self.rate_limit_blacklist = [x.strip() for x in
conf.get('rate_limit_account_blacklist', '').split(',')
if x.strip()]
def get_controller(self, path): def get_controller(self, path):
""" """
@ -1302,10 +1313,6 @@ class BaseApplication(object):
return HTTPPreconditionFailed(request=req, body='Invalid UTF8') return HTTPPreconditionFailed(request=req, body='Invalid UTF8')
if not controller: if not controller:
return HTTPPreconditionFailed(request=req, body='Bad URL') return HTTPPreconditionFailed(request=req, body='Bad URL')
rate_limit_allowed_err_resp = \
self.check_rate_limit(req, path_parts)
if rate_limit_allowed_err_resp is not None:
return rate_limit_allowed_err_resp
controller = controller(self, **path_parts) controller = controller(self, **path_parts)
controller.trans_id = req.headers.get('x-cf-trans-id', '-') controller.trans_id = req.headers.get('x-cf-trans-id', '-')
@ -1339,10 +1346,6 @@ class BaseApplication(object):
self.logger.exception('ERROR Unhandled exception in request') self.logger.exception('ERROR Unhandled exception in request')
return HTTPServerError(request=req) return HTTPServerError(request=req)
def check_rate_limit(self, req, path_parts):
"""Check for rate limiting."""
return None
class Application(BaseApplication): class Application(BaseApplication):
"""WSGI application for the proxy server.""" """WSGI application for the proxy server."""
@ -1395,46 +1398,6 @@ class Application(BaseApplication):
trans_time, trans_time,
))) )))
def check_rate_limit(self, req, path_parts):
"""
Check for rate limiting.
:param req: webob.Request object
:param path_parts: parsed path dictionary
"""
if path_parts['account_name'] in self.rate_limit_blacklist:
self.logger.error('Returning 497 because of blacklisting')
return Response(status='497 Blacklisted',
body='Your account has been blacklisted', request=req)
if path_parts['account_name'] not in self.rate_limit_whitelist:
current_second = time.strftime('%x%H%M%S')
general_rate_limit_key = '%s%s' % (path_parts['account_name'],
current_second)
ops_count = self.memcache.incr(general_rate_limit_key, timeout=2)
if ops_count > self.rate_limit:
self.logger.error(
'Returning 498 because of ops rate limiting')
return Response(status='498 Rate Limited',
body='Slow down', request=req)
elif (path_parts['container_name']
and not path_parts['object_name']) \
or \
(path_parts['account_name']
and not path_parts['container_name']):
# further limit operations on a single account or container
rate_limit_key = '%s%s%s' % (path_parts['account_name'],
path_parts['container_name'] or '-',
current_second)
ops_count = self.memcache.incr(rate_limit_key, timeout=2)
if ops_count > self.account_rate_limit:
self.logger.error(
'Returning 498 because of account and container'
' rate limiting')
return Response(status='498 Rate Limited',
body='Slow down', request=req)
return None
def app_factory(global_conf, **local_conf): def app_factory(global_conf, **local_conf):
"""paste.deploy app factory for creating WSGI proxy apps.""" """paste.deploy app factory for creating WSGI proxy apps."""
conf = global_conf.copy() conf = global_conf.copy()

View File

@ -0,0 +1,412 @@
# Copyright (c) 2010 OpenStack, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
import time
from contextlib import contextmanager
from threading import Thread
import eventlet
from webob import Request
from swift.common.middleware import ratelimit
from swift.proxy.server import get_container_memcache_key
# mocks
#logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
class FakeMemcache(object):
def __init__(self):
self.store = {}
def get(self, key):
return self.store.get(key)
def set(self, key, value, serialize=False, timeout=0):
self.store[key] = value
return True
def incr(self, key, delta=1, timeout=0):
self.store[key] = int(self.store.setdefault(key, 0)) + delta
return int(self.store[key])
def decr(self, key, delta=1, timeout=0):
self.store[key] = int(self.store.setdefault(key, 0)) - delta
return int(self.store[key])
@contextmanager
def soft_lock(self, key, timeout=0, retries=5):
yield True
def delete(self, key):
try:
del self.store[key]
except:
pass
return True
def mock_http_connect(response, headers=None, with_exc=False):
class FakeConn(object):
def __init__(self, status, headers, with_exc):
self.status = status
self.reason = 'Fake'
self.host = '1.2.3.4'
self.port = '1234'
self.with_exc = with_exc
self.headers = headers
if self.headers is None:
self.headers = {}
def getresponse(self):
if self.with_exc:
raise Exception('test')
return self
def getheader(self, header):
return self.headers[header]
def read(self, amt=None):
return ''
def close(self):
return
return lambda *args, **kwargs: FakeConn(response, headers, with_exc)
class FakeApp(object):
def __call__(self, env, start_response):
return ['204 No Content']
class FakeLogger(object):
def error(self, msg):
# a thread safe logger
pass
def start_response(*args):
pass
def dummy_filter_factory(global_conf, **local_conf):
conf = global_conf.copy()
conf.update(local_conf)
def limit_filter(app):
return ratelimit.RateLimitMiddleware(app, conf, logger=FakeLogger())
return limit_filter
class TestRateLimit(unittest.TestCase):
def _run(self, callable_func, num, rate, extra_sleep=0,
total_time=None, check_time=True):
begin = time.time()
for x in range(0, num):
result = callable_func()
# Extra sleep is here to test with different call intervals.
time.sleep(extra_sleep)
end = time.time()
if total_time is None:
total_time = num / rate
# Allow for one second of variation in the total time.
time_diff = abs(total_time - (end - begin))
if check_time:
self.assertTrue(time_diff < 1)
return time_diff
def test_get_container_maxrate(self):
conf_dict = {'container_limit_10': 200,
'container_limit_50': 100,
'container_limit_75': 30,}
test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
self.assertEquals(test_ratelimit.get_container_maxrate(0), None)
self.assertEquals(test_ratelimit.get_container_maxrate(5), None)
self.assertEquals(test_ratelimit.get_container_maxrate(10), 200)
self.assertEquals(test_ratelimit.get_container_maxrate(60), 72)
self.assertEquals(test_ratelimit.get_container_maxrate(160), 30)
def test_ratelimit(self):
current_rate = 13
num_calls = 100
conf_dict = {'account_ratelimit': current_rate}
self.test_ratelimit = ratelimit.filter_factory(conf_dict)(FakeApp())
ratelimit.http_connect = mock_http_connect(204)
req = Request.blank('/v/a/c')
req.environ['swift.cache'] = FakeMemcache()
make_app_call = lambda: self.test_ratelimit(req.environ, start_response)
self._run(make_app_call, num_calls, current_rate)
def test_ratelimit_whitelist(self):
current_rate = 2
conf_dict = {'account_ratelimit': current_rate,
'max_sleep_time_seconds': 2,
'account_whitelist': 'a',
'account_blacklist': 'b',
}
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
ratelimit.http_connect = mock_http_connect(204)
req = Request.blank('/v/a/c')
req.environ['swift.cache'] = FakeMemcache()
class rate_caller(Thread):
def __init__(self, parent):
Thread.__init__(self)
self.parent = parent
def run(self):
self.result = self.parent.test_ratelimit(req.environ,
start_response)
nt = 5
begin = time.time()
threads = []
for i in range(nt):
rc = rate_caller(self)
rc.start()
threads.append(rc)
for thread in threads:
thread.join()
the_498s = [t for t in threads if \
''.join(t.result).startswith('Slow down')]
self.assertEquals(len(the_498s), 0)
time_took = time.time() - begin
# the 4th request will happen at 1.5
self.assert_(round(time_took, 1) == 0)
def test_ratelimit_blacklist(self):
current_rate = 2
conf_dict = {'account_ratelimit': current_rate,
'max_sleep_time_seconds': 2,
'account_whitelist': 'a',
'account_blacklist': 'b',
}
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
ratelimit.http_connect = mock_http_connect(204)
req = Request.blank('/v/b/c')
req.environ['swift.cache'] = FakeMemcache()
class rate_caller(Thread):
def __init__(self, parent):
Thread.__init__(self)
self.parent = parent
def run(self):
self.result = self.parent.test_ratelimit(req.environ,
start_response)
nt = 5
begin = time.time()
threads = []
for i in range(nt):
rc = rate_caller(self)
rc.start()
threads.append(rc)
for thread in threads:
thread.join()
the_497s = [t for t in threads if \
''.join(t.result).startswith('Your account')]
self.assertEquals(len(the_497s), 5)
time_took = time.time() - begin
self.assert_(round(time_took, 1) == 0)
def test_ratelimit_max_rate(self):
'''
Running 5 threads at rate 2 a sec. and max sleep of 2 seconds
Expect threads to be run as follows:
t1:0, t2:0, t3:1, t4:1.5, t5:2(Max Rate thrown)
'''
current_rate = 2
conf_dict = {'account_ratelimit': current_rate,
'max_sleep_time_seconds': 2}
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
ratelimit.http_connect = mock_http_connect(204)
req = Request.blank('/v/a/c')
req.environ['swift.cache'] = FakeMemcache()
class rate_caller(Thread):
def __init__(self, parent):
Thread.__init__(self)
self.parent = parent
def run(self):
self.result = self.parent.test_ratelimit(req.environ,
start_response)
nt = 5
begin = time.time()
threads = []
for i in range(nt):
rc = rate_caller(self)
rc.start()
threads.append(rc)
for thread in threads:
thread.join()
the_498s = [t for t in threads if \
''.join(t.result).startswith('Slow down')]
self.assertEquals(len(the_498s), 1)
time_took = time.time() - begin
# the 4th request will happen at 1.5
self.assert_(round(time_took, 1) == 1.5)
def test_ratelimit_max_rate_double(self):
current_rate = 2
conf_dict = {'account_ratelimit': current_rate,
'clock_accuracy': 100,
'max_sleep_time_seconds': 4}
# making clock less accurate for nosetests running slow
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
ratelimit.http_connect = mock_http_connect(204)
req = Request.blank('/v/a/c')
req.environ['swift.cache'] = FakeMemcache()
begin = time.time()
class rate_caller(Thread):
def __init__(self, parent, name):
Thread.__init__(self)
self.parent = parent
self.name = name
def run(self):
self.result1 = self.parent.test_ratelimit(req.environ,
start_response)
time.sleep(.1)
self.result2 = self.parent.test_ratelimit(req.environ,
start_response)
nt = 9
threads = []
for i in range(nt):
rc = rate_caller(self, "thread %s" % i)
rc.start()
threads.append(rc)
for thread in threads:
thread.join()
all_results = [''.join(t.result1) for t in threads]
all_results += [''.join(t.result2) for t in threads]
the_498s = [t for t in all_results if t.startswith('Slow down')]
self.assertEquals(len(the_498s), 2)
time_took = time.time() - begin
self.assert_(round(time_took, 1) == 7.5)
def test_ratelimit_max_rate_multiple_acc(self):
num_calls = 4
current_rate = 2
conf_dict = {'account_ratelimit': current_rate,
'max_sleep_time_seconds': 2}
fake_memcache = FakeMemcache()
the_app = ratelimit.RateLimitMiddleware(None, conf_dict,
logger=FakeLogger())
the_app.memcache_client = fake_memcache
class rate_caller(Thread):
def __init__(self, name):
self.myname = name
Thread.__init__(self)
def run(self):
for j in range(num_calls):
self.result = the_app.handle_rate_limit(None, self.myname,
None, None)
nt = 15
begin = time.time()
threads = []
for i in range(nt):
rc = rate_caller('a%s' % i)
rc.start()
threads.append(rc)
for thread in threads:
thread.join()
time_took = time.time() - begin
# the all 15 threads still take 1.5 secs
self.assert_(round(time_took, 1) == 1.5)
def test_ratelimit_acc_vrs_container(self):
conf_dict = {'clock_accuracy': 1000,
'account_ratelimit': 10,
'max_sleep_time_seconds': 4,
'container_limit_10': 6,
'container_limit_50': 2,
'container_limit_75': 1,}
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
ratelimit.http_connect = mock_http_connect(204)
req = Request.blank('/v/a/c')
req.environ['swift.cache'] = FakeMemcache()
cont_key = get_container_memcache_key('a','c')
class rate_caller(Thread):
def __init__(self, parent, name):
Thread.__init__(self)
self.parent = parent
self.name = name
def run(self):
self.result = self.parent.test_ratelimit(req.environ,
start_response,
name=self.name)
def runthreads(threads, nt):
for i in range(nt):
rc = rate_caller(self, "thread %s" % i)
rc.start()
threads.append(rc)
for thread in threads:
thread.join()
begin = time.time()
req.environ['swift.cache'].set(cont_key, {'container_size': 20})
begin = time.time()
threads = []
runthreads(threads,3)
time_took = time.time() - begin
self.assert_(round(time_took, 1) == .4)
if __name__ == '__main__':
unittest.main()

View File

@ -1295,17 +1295,6 @@ class TestObjectController(unittest.TestCase):
headers = readuntil2crlfs(fd) headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 404' exp = 'HTTP/1.1 404'
self.assertEquals(headers[:len(exp)], exp) self.assertEquals(headers[:len(exp)], exp)
# Check blacklist
prosrv.rate_limit_blacklist = ['a']
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile()
fd.write('GET /v1/a HTTP/1.1\r\nHost: localhost\r\n'
'Connection: close\r\nContent-Length: 0\r\n\r\n')
fd.flush()
headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 497'
self.assertEquals(headers[:len(exp)], exp)
prosrv.rate_limit_blacklist = []
# Check invalid utf-8 # Check invalid utf-8
sock = connect_tcp(('localhost', prolis.getsockname()[1])) sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile() fd = sock.makefile()
@ -1326,31 +1315,6 @@ class TestObjectController(unittest.TestCase):
headers = readuntil2crlfs(fd) headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 412' exp = 'HTTP/1.1 412'
self.assertEquals(headers[:len(exp)], exp) self.assertEquals(headers[:len(exp)], exp)
# Check rate limiting
orig_rate_limit = prosrv.rate_limit
prosrv.rate_limit = 0
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile()
fd.write('GET /v1/a HTTP/1.1\r\nHost: localhost\r\n'
'Connection: close\r\nX-Auth-Token: t\r\n'
'Content-Length: 0\r\n\r\n')
fd.flush()
headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 498'
self.assertEquals(headers[:len(exp)], exp)
prosrv.rate_limit = orig_rate_limit
orig_rate_limit = prosrv.account_rate_limit
prosrv.account_rate_limit = 0
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile()
fd.write('PUT /v1/a/c HTTP/1.1\r\nHost: localhost\r\n'
'Connection: close\r\nX-Auth-Token: t\r\n'
'Content-Length: 0\r\n\r\n')
fd.flush()
headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 498'
self.assertEquals(headers[:len(exp)], exp)
prosrv.account_rate_limit = orig_rate_limit
# Check bad method # Check bad method
sock = connect_tcp(('localhost', prolis.getsockname()[1])) sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile() fd = sock.makefile()
@ -1362,8 +1326,8 @@ class TestObjectController(unittest.TestCase):
exp = 'HTTP/1.1 405' exp = 'HTTP/1.1 405'
self.assertEquals(headers[:len(exp)], exp) self.assertEquals(headers[:len(exp)], exp)
# Check unhandled exception # Check unhandled exception
orig_rate_limit = prosrv.rate_limit orig_logger = prosrv.logger
del prosrv.rate_limit del prosrv.logger
sock = connect_tcp(('localhost', prolis.getsockname()[1])) sock = connect_tcp(('localhost', prolis.getsockname()[1]))
fd = sock.makefile() fd = sock.makefile()
fd.write('HEAD /v1/a HTTP/1.1\r\nHost: localhost\r\n' fd.write('HEAD /v1/a HTTP/1.1\r\nHost: localhost\r\n'
@ -1373,7 +1337,7 @@ class TestObjectController(unittest.TestCase):
headers = readuntil2crlfs(fd) headers = readuntil2crlfs(fd)
exp = 'HTTP/1.1 500' exp = 'HTTP/1.1 500'
self.assertEquals(headers[:len(exp)], exp) self.assertEquals(headers[:len(exp)], exp)
prosrv.rate_limit = orig_rate_limit prosrv.logger = orig_logger
# Okay, back to chunked put testing; Create account # Okay, back to chunked put testing; Create account
ts = normalize_timestamp(time()) ts = normalize_timestamp(time())
partition, nodes = prosrv.account_ring.get_nodes('a') partition, nodes = prosrv.account_ring.get_nodes('a')