adding ratelimiting middleware and unit tests
This commit is contained in:
parent
8ed75703fc
commit
72d40bd9f6
@ -56,3 +56,15 @@ use = egg:swift#memcache
|
|||||||
# Default for memcache_servers is below, but you can specify multiple servers
|
# Default for memcache_servers is below, but you can specify multiple servers
|
||||||
# with the format: 10.1.2.3:11211,10.1.2.4:11211
|
# with the format: 10.1.2.3:11211,10.1.2.4:11211
|
||||||
# memcache_servers = 127.0.0.1:11211
|
# memcache_servers = 127.0.0.1:11211
|
||||||
|
|
||||||
|
[filter:ratelimit]
|
||||||
|
use = egg:swift#ratelimit
|
||||||
|
account_ratelimit = 200
|
||||||
|
account_whitelist = a,b
|
||||||
|
# account_blacklist =
|
||||||
|
|
||||||
|
# with container_limit_x = r
|
||||||
|
# for containers of size x limit requests per second to r
|
||||||
|
container_limit_0 = 100
|
||||||
|
container_limit_10 = 50
|
||||||
|
container_limit_50 = 10
|
||||||
|
1
setup.py
1
setup.py
@ -88,6 +88,7 @@ setup(
|
|||||||
'auth=swift.common.middleware.auth:filter_factory',
|
'auth=swift.common.middleware.auth:filter_factory',
|
||||||
'healthcheck=swift.common.middleware.healthcheck:filter_factory',
|
'healthcheck=swift.common.middleware.healthcheck:filter_factory',
|
||||||
'memcache=swift.common.middleware.memcache:filter_factory',
|
'memcache=swift.common.middleware.memcache:filter_factory',
|
||||||
|
# 'ratelimit=swift.common.middeware.ratelimit:filter_factory',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
198
swift/common/middleware/ratelimit.py
Normal file
198
swift/common/middleware/ratelimit.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import time
|
||||||
|
from webob import Request, Response
|
||||||
|
from ConfigParser import ConfigParser, NoOptionError
|
||||||
|
|
||||||
|
from swift.common.utils import split_path, cache_from_env, get_logger
|
||||||
|
from swift.proxy.server import get_container_memcache_key
|
||||||
|
|
||||||
|
class MaxSleepTimeHit(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class RateLimitMiddleware(object):
|
||||||
|
"""
|
||||||
|
Rate limiting middleware
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app, conf, logger=None):
|
||||||
|
self.app = app
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
if logger is None:
|
||||||
|
self.logger = get_logger(conf)
|
||||||
|
else:
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
self.account_rate_limit = float(conf.get('account_ratelimit', 1))#200.0))
|
||||||
|
self.max_sleep_time_seconds = int(conf.get('max_sleep_time_seconds',
|
||||||
|
2))#60))
|
||||||
|
self.clock_accuracy = int(conf.get('clock_accuracy', 1000))
|
||||||
|
|
||||||
|
self.rate_limit_whitelist = [acc.strip() for acc in
|
||||||
|
conf.get('account_whitelist', '').split(',')
|
||||||
|
if acc.strip()]
|
||||||
|
self.rate_limit_blacklist = [acc.strip() for acc in
|
||||||
|
conf.get('account_blacklist', '').split(',')
|
||||||
|
if acc.strip()]
|
||||||
|
self.memcache_client = None
|
||||||
|
conf_limits = []
|
||||||
|
for conf_key in conf.keys():
|
||||||
|
if conf_key.startswith('container_limit_'):
|
||||||
|
cont_size = int(conf_key[len('container_limit_'):])
|
||||||
|
rate = float(conf[conf_key])
|
||||||
|
conf_limits.append((cont_size,rate))
|
||||||
|
|
||||||
|
conf_limits.sort()
|
||||||
|
self.container_limits = []
|
||||||
|
while conf_limits:
|
||||||
|
cur_size, cur_rate = conf_limits.pop(0)
|
||||||
|
if conf_limits:
|
||||||
|
# figure out slope for function between this point and next
|
||||||
|
next_size, next_rate = conf_limits[0]
|
||||||
|
slope = (float(next_rate) - float(cur_rate)) \
|
||||||
|
/ (next_size - cur_size)
|
||||||
|
def new_scope(cur_size, slope, cur_rate):
|
||||||
|
# making new scope for variables
|
||||||
|
return lambda x: (x - cur_size) * slope + cur_rate
|
||||||
|
line_func = new_scope(cur_size, slope, cur_rate)
|
||||||
|
else:
|
||||||
|
# don't have to worry about scope here- this is the last
|
||||||
|
# element in the list
|
||||||
|
line_func = lambda x : cur_rate
|
||||||
|
|
||||||
|
self.container_limits.append((cur_size, cur_rate, line_func))
|
||||||
|
|
||||||
|
def get_container_maxrate(self, container_size):
|
||||||
|
"""
|
||||||
|
Will figure out the max_rate for a container size
|
||||||
|
"""
|
||||||
|
last_func = None
|
||||||
|
if container_size:
|
||||||
|
for size, rate, func in self.container_limits:
|
||||||
|
if container_size < size:
|
||||||
|
break
|
||||||
|
last_func = func
|
||||||
|
|
||||||
|
if last_func:
|
||||||
|
return last_func(container_size)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_key_rate_tuples(self, account_name, container_name, obj_name):
|
||||||
|
"""
|
||||||
|
Returns a list of keys (to be used in memcache) that can be
|
||||||
|
generated given a path. Keys should be checked in order.
|
||||||
|
|
||||||
|
:param path: path from request
|
||||||
|
"""
|
||||||
|
keys = []
|
||||||
|
if account_name:
|
||||||
|
keys.append(("ratelimit/%s" % account_name,
|
||||||
|
self.account_rate_limit))
|
||||||
|
if account_name and container_name and not obj_name:
|
||||||
|
container_size = None
|
||||||
|
memcache_key = get_container_memcache_key(account_name,
|
||||||
|
container_name)
|
||||||
|
container_info = self.memcache_client.get(memcache_key)
|
||||||
|
if type(container_info) == dict:
|
||||||
|
container_size = container_info.get('container_size')
|
||||||
|
|
||||||
|
container_rate = self.get_container_maxrate(container_size)
|
||||||
|
if container_rate:
|
||||||
|
keys.append(("ratelimit/%s/%s" % (account_name,
|
||||||
|
container_name),
|
||||||
|
container_rate))
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def _get_sleep_time(self, key, max_rate):
|
||||||
|
now_m = int(round(time.time() * self.clock_accuracy))
|
||||||
|
time_per_request_m = int(round(self.clock_accuracy / max_rate))
|
||||||
|
running_time_m = self.memcache_client.incr(key,
|
||||||
|
delta=time_per_request_m)
|
||||||
|
|
||||||
|
need_to_sleep_m = 0
|
||||||
|
request_time_limit = now_m + (time_per_request_m * max_rate)
|
||||||
|
|
||||||
|
if running_time_m < now_m:
|
||||||
|
next_avail_time = int(now_m + time_per_request_m)
|
||||||
|
self.memcache_client.set(key, str(next_avail_time),
|
||||||
|
serialize=False)
|
||||||
|
|
||||||
|
elif running_time_m - now_m - time_per_request_m > 0:
|
||||||
|
#running_time_m > request_time_limit:
|
||||||
|
need_to_sleep_m = running_time_m - now_m - time_per_request_m
|
||||||
|
|
||||||
|
|
||||||
|
max_sleep_m = self.max_sleep_time_seconds * self.clock_accuracy
|
||||||
|
if max_sleep_m - need_to_sleep_m <= self.clock_accuracy * 0.01:
|
||||||
|
# make it accurate to 1% of clock accuracy
|
||||||
|
# treat as no-op decrement time
|
||||||
|
self.memcache_client.decr(key, delta=time_per_request_m)
|
||||||
|
raise MaxSleepTimeHit("Max Sleep Time Exceeded: %s" %
|
||||||
|
need_to_sleep_m)
|
||||||
|
|
||||||
|
return float(need_to_sleep_m) / self.clock_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
def handle_rate_limit(self, req, account_name, container_name, obj_name,
|
||||||
|
name=None):
|
||||||
|
|
||||||
|
if account_name in self.rate_limit_blacklist:
|
||||||
|
self.logger.error('Returning 497 because of blacklisting')
|
||||||
|
|
||||||
|
return Response(status='497 Blacklisted',
|
||||||
|
body='Your account has been blacklisted', request=req)
|
||||||
|
if account_name in self.rate_limit_whitelist:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for key, max_rate in self._generate_key_rate_tuples(account_name,
|
||||||
|
container_name,
|
||||||
|
obj_name):
|
||||||
|
try:
|
||||||
|
need_to_sleep = self._get_sleep_time(key,
|
||||||
|
max_rate)
|
||||||
|
if need_to_sleep > 0:
|
||||||
|
time.sleep(need_to_sleep)
|
||||||
|
|
||||||
|
except MaxSleepTimeHit, e:
|
||||||
|
self.logger.error('Returning 498 because of ops ' + \
|
||||||
|
'rate limiting (Max Sleep) %s' % e)
|
||||||
|
error_resp = Response(status='498 Rate Limited',
|
||||||
|
body='Slow down', request=req)
|
||||||
|
return error_resp
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def __call__(self, env, start_response, name=None):
|
||||||
|
req = Request(env)
|
||||||
|
if self.memcache_client is None:
|
||||||
|
self.memcache_client = cache_from_env(env)
|
||||||
|
version, account, container, obj = split_path(req.path, 1, 4, True)
|
||||||
|
|
||||||
|
rate_limit_resp = self.handle_rate_limit(req, account, container,
|
||||||
|
obj, name=name)
|
||||||
|
if rate_limit_resp is None:
|
||||||
|
return self.app(env, start_response)
|
||||||
|
else:
|
||||||
|
return rate_limit_resp(env, start_response)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_factory(global_conf, **local_conf):
|
||||||
|
conf = global_conf.copy()
|
||||||
|
conf.update(local_conf)
|
||||||
|
def limit_filter(app):
|
||||||
|
return RateLimitMiddleware(app, conf)
|
||||||
|
return limit_filter
|
@ -88,6 +88,10 @@ def delay_denial(func):
|
|||||||
return func(*a, **kw)
|
return func(*a, **kw)
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
def get_container_memcache_key(account, container):
|
||||||
|
path = '/%s/%s' % (account, container)
|
||||||
|
return 'container%s' % path
|
||||||
|
|
||||||
|
|
||||||
class Controller(object):
|
class Controller(object):
|
||||||
"""Base WSGI controller class for the proxy"""
|
"""Base WSGI controller class for the proxy"""
|
||||||
@ -228,15 +232,22 @@ class Controller(object):
|
|||||||
"""
|
"""
|
||||||
partition, nodes = self.app.container_ring.get_nodes(
|
partition, nodes = self.app.container_ring.get_nodes(
|
||||||
account, container)
|
account, container)
|
||||||
|
|
||||||
path = '/%s/%s' % (account, container)
|
path = '/%s/%s' % (account, container)
|
||||||
cache_key = 'container%s' % path
|
cache_key = get_container_memcache_key(account, container)
|
||||||
|
|
||||||
# Older memcache values (should be treated as if they aren't there):
|
# Older memcache values (should be treated as if they aren't there):
|
||||||
# 0 = no responses, 200 = found, 404 = not found, -1 = mixed responses
|
# 0 = no responses, 200 = found, 404 = not found, -1 = mixed responses
|
||||||
# Newer memcache values:
|
# Newer memcache values:
|
||||||
# [older status value from above, read acl, write acl]
|
# [older status value from above, read acl, write acl]
|
||||||
cache_value = self.app.memcache.get(cache_key)
|
cache_value = self.app.memcache.get(cache_key)
|
||||||
if hasattr(cache_value, '__iter__'):
|
if hasattr(cache_value, '__iter__'):
|
||||||
status, read_acl, write_acl = cache_value
|
if type(cache_value) == dict:
|
||||||
|
status = cache_value['status']
|
||||||
|
read_acl = cache_value['read_acl']
|
||||||
|
write_acl = cache_value['write_acl']
|
||||||
|
else:
|
||||||
|
status, read_acl, write_acl = cache_value
|
||||||
if status == 200:
|
if status == 200:
|
||||||
return partition, nodes, read_acl, write_acl
|
return partition, nodes, read_acl, write_acl
|
||||||
if not self.account_info(account)[1]:
|
if not self.account_info(account)[1]:
|
||||||
@ -244,6 +255,7 @@ class Controller(object):
|
|||||||
result_code = 0
|
result_code = 0
|
||||||
read_acl = None
|
read_acl = None
|
||||||
write_acl = None
|
write_acl = None
|
||||||
|
container_size = None
|
||||||
attempts_left = self.app.container_ring.replica_count
|
attempts_left = self.app.container_ring.replica_count
|
||||||
headers = {'x-cf-trans-id': self.trans_id}
|
headers = {'x-cf-trans-id': self.trans_id}
|
||||||
for node in self.iter_nodes(partition, nodes, self.app.container_ring):
|
for node in self.iter_nodes(partition, nodes, self.app.container_ring):
|
||||||
@ -260,6 +272,8 @@ class Controller(object):
|
|||||||
result_code = 200
|
result_code = 200
|
||||||
read_acl = resp.getheader('x-container-read')
|
read_acl = resp.getheader('x-container-read')
|
||||||
write_acl = resp.getheader('x-container-write')
|
write_acl = resp.getheader('x-container-write')
|
||||||
|
container_size = \
|
||||||
|
resp.getheader('X-Container-Object-Count')
|
||||||
break
|
break
|
||||||
elif resp.status == 404:
|
elif resp.status == 404:
|
||||||
result_code = 404 if not result_code else -1
|
result_code = 404 if not result_code else -1
|
||||||
@ -278,7 +292,10 @@ class Controller(object):
|
|||||||
cache_timeout = self.app.recheck_container_existence
|
cache_timeout = self.app.recheck_container_existence
|
||||||
else:
|
else:
|
||||||
cache_timeout = self.app.recheck_container_existence * 0.1
|
cache_timeout = self.app.recheck_container_existence * 0.1
|
||||||
self.app.memcache.set(cache_key, (result_code, read_acl, write_acl),
|
self.app.memcache.set(cache_key, {'status': result_code,
|
||||||
|
'read_acl': read_acl,
|
||||||
|
'write_acl': write_acl,
|
||||||
|
'container_size': container_size},
|
||||||
timeout=cache_timeout)
|
timeout=cache_timeout)
|
||||||
if result_code == 200:
|
if result_code == 200:
|
||||||
return partition, nodes, read_acl, write_acl
|
return partition, nodes, read_acl, write_acl
|
||||||
@ -941,6 +958,8 @@ class ContainerController(Controller):
|
|||||||
statuses.append(503)
|
statuses.append(503)
|
||||||
reasons.append('')
|
reasons.append('')
|
||||||
bodies.append('')
|
bodies.append('')
|
||||||
|
#TODO : David - does this need to be using the
|
||||||
|
# get_container_memcache_key function????
|
||||||
self.app.memcache.delete('container%s' % req.path_info.rstrip('/'))
|
self.app.memcache.delete('container%s' % req.path_info.rstrip('/'))
|
||||||
return self.best_response(req, statuses, reasons, bodies,
|
return self.best_response(req, statuses, reasons, bodies,
|
||||||
'Container PUT')
|
'Container PUT')
|
||||||
@ -1214,14 +1233,6 @@ class BaseApplication(object):
|
|||||||
self.account_ring = account_ring or \
|
self.account_ring = account_ring or \
|
||||||
Ring(os.path.join(swift_dir, 'account.ring.gz'))
|
Ring(os.path.join(swift_dir, 'account.ring.gz'))
|
||||||
self.memcache = memcache
|
self.memcache = memcache
|
||||||
self.rate_limit = float(conf.get('rate_limit', 20000.0))
|
|
||||||
self.account_rate_limit = float(conf.get('account_rate_limit', 200.0))
|
|
||||||
self.rate_limit_whitelist = [x.strip() for x in
|
|
||||||
conf.get('rate_limit_account_whitelist', '').split(',')
|
|
||||||
if x.strip()]
|
|
||||||
self.rate_limit_blacklist = [x.strip() for x in
|
|
||||||
conf.get('rate_limit_account_blacklist', '').split(',')
|
|
||||||
if x.strip()]
|
|
||||||
|
|
||||||
def get_controller(self, path):
|
def get_controller(self, path):
|
||||||
"""
|
"""
|
||||||
@ -1302,10 +1313,6 @@ class BaseApplication(object):
|
|||||||
return HTTPPreconditionFailed(request=req, body='Invalid UTF8')
|
return HTTPPreconditionFailed(request=req, body='Invalid UTF8')
|
||||||
if not controller:
|
if not controller:
|
||||||
return HTTPPreconditionFailed(request=req, body='Bad URL')
|
return HTTPPreconditionFailed(request=req, body='Bad URL')
|
||||||
rate_limit_allowed_err_resp = \
|
|
||||||
self.check_rate_limit(req, path_parts)
|
|
||||||
if rate_limit_allowed_err_resp is not None:
|
|
||||||
return rate_limit_allowed_err_resp
|
|
||||||
|
|
||||||
controller = controller(self, **path_parts)
|
controller = controller(self, **path_parts)
|
||||||
controller.trans_id = req.headers.get('x-cf-trans-id', '-')
|
controller.trans_id = req.headers.get('x-cf-trans-id', '-')
|
||||||
@ -1339,10 +1346,6 @@ class BaseApplication(object):
|
|||||||
self.logger.exception('ERROR Unhandled exception in request')
|
self.logger.exception('ERROR Unhandled exception in request')
|
||||||
return HTTPServerError(request=req)
|
return HTTPServerError(request=req)
|
||||||
|
|
||||||
def check_rate_limit(self, req, path_parts):
|
|
||||||
"""Check for rate limiting."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class Application(BaseApplication):
|
class Application(BaseApplication):
|
||||||
"""WSGI application for the proxy server."""
|
"""WSGI application for the proxy server."""
|
||||||
@ -1395,46 +1398,6 @@ class Application(BaseApplication):
|
|||||||
trans_time,
|
trans_time,
|
||||||
)))
|
)))
|
||||||
|
|
||||||
def check_rate_limit(self, req, path_parts):
|
|
||||||
"""
|
|
||||||
Check for rate limiting.
|
|
||||||
|
|
||||||
:param req: webob.Request object
|
|
||||||
:param path_parts: parsed path dictionary
|
|
||||||
"""
|
|
||||||
if path_parts['account_name'] in self.rate_limit_blacklist:
|
|
||||||
self.logger.error('Returning 497 because of blacklisting')
|
|
||||||
return Response(status='497 Blacklisted',
|
|
||||||
body='Your account has been blacklisted', request=req)
|
|
||||||
if path_parts['account_name'] not in self.rate_limit_whitelist:
|
|
||||||
current_second = time.strftime('%x%H%M%S')
|
|
||||||
general_rate_limit_key = '%s%s' % (path_parts['account_name'],
|
|
||||||
current_second)
|
|
||||||
ops_count = self.memcache.incr(general_rate_limit_key, timeout=2)
|
|
||||||
if ops_count > self.rate_limit:
|
|
||||||
self.logger.error(
|
|
||||||
'Returning 498 because of ops rate limiting')
|
|
||||||
return Response(status='498 Rate Limited',
|
|
||||||
body='Slow down', request=req)
|
|
||||||
elif (path_parts['container_name']
|
|
||||||
and not path_parts['object_name']) \
|
|
||||||
or \
|
|
||||||
(path_parts['account_name']
|
|
||||||
and not path_parts['container_name']):
|
|
||||||
# further limit operations on a single account or container
|
|
||||||
rate_limit_key = '%s%s%s' % (path_parts['account_name'],
|
|
||||||
path_parts['container_name'] or '-',
|
|
||||||
current_second)
|
|
||||||
ops_count = self.memcache.incr(rate_limit_key, timeout=2)
|
|
||||||
if ops_count > self.account_rate_limit:
|
|
||||||
self.logger.error(
|
|
||||||
'Returning 498 because of account and container'
|
|
||||||
' rate limiting')
|
|
||||||
return Response(status='498 Rate Limited',
|
|
||||||
body='Slow down', request=req)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def app_factory(global_conf, **local_conf):
|
def app_factory(global_conf, **local_conf):
|
||||||
"""paste.deploy app factory for creating WSGI proxy apps."""
|
"""paste.deploy app factory for creating WSGI proxy apps."""
|
||||||
conf = global_conf.copy()
|
conf = global_conf.copy()
|
||||||
|
412
test/unit/common/middleware/test_ratelimit.py
Normal file
412
test/unit/common/middleware/test_ratelimit.py
Normal file
@ -0,0 +1,412 @@
|
|||||||
|
# Copyright (c) 2010 OpenStack, LLC.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
import eventlet
|
||||||
|
from webob import Request
|
||||||
|
|
||||||
|
from swift.common.middleware import ratelimit
|
||||||
|
from swift.proxy.server import get_container_memcache_key
|
||||||
|
|
||||||
|
# mocks
|
||||||
|
#logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMemcache(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.store = {}
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
return self.store.get(key)
|
||||||
|
|
||||||
|
def set(self, key, value, serialize=False, timeout=0):
|
||||||
|
self.store[key] = value
|
||||||
|
return True
|
||||||
|
|
||||||
|
def incr(self, key, delta=1, timeout=0):
|
||||||
|
self.store[key] = int(self.store.setdefault(key, 0)) + delta
|
||||||
|
return int(self.store[key])
|
||||||
|
|
||||||
|
def decr(self, key, delta=1, timeout=0):
|
||||||
|
self.store[key] = int(self.store.setdefault(key, 0)) - delta
|
||||||
|
return int(self.store[key])
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def soft_lock(self, key, timeout=0, retries=5):
|
||||||
|
yield True
|
||||||
|
|
||||||
|
def delete(self, key):
|
||||||
|
try:
|
||||||
|
del self.store[key]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def mock_http_connect(response, headers=None, with_exc=False):
|
||||||
|
class FakeConn(object):
|
||||||
|
def __init__(self, status, headers, with_exc):
|
||||||
|
self.status = status
|
||||||
|
self.reason = 'Fake'
|
||||||
|
self.host = '1.2.3.4'
|
||||||
|
self.port = '1234'
|
||||||
|
self.with_exc = with_exc
|
||||||
|
self.headers = headers
|
||||||
|
if self.headers is None:
|
||||||
|
self.headers = {}
|
||||||
|
def getresponse(self):
|
||||||
|
if self.with_exc:
|
||||||
|
raise Exception('test')
|
||||||
|
return self
|
||||||
|
def getheader(self, header):
|
||||||
|
return self.headers[header]
|
||||||
|
def read(self, amt=None):
|
||||||
|
return ''
|
||||||
|
def close(self):
|
||||||
|
return
|
||||||
|
return lambda *args, **kwargs: FakeConn(response, headers, with_exc)
|
||||||
|
|
||||||
|
class FakeApp(object):
|
||||||
|
def __call__(self, env, start_response):
|
||||||
|
return ['204 No Content']
|
||||||
|
class FakeLogger(object):
|
||||||
|
def error(self, msg):
|
||||||
|
# a thread safe logger
|
||||||
|
pass
|
||||||
|
def start_response(*args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_filter_factory(global_conf, **local_conf):
|
||||||
|
conf = global_conf.copy()
|
||||||
|
conf.update(local_conf)
|
||||||
|
def limit_filter(app):
|
||||||
|
return ratelimit.RateLimitMiddleware(app, conf, logger=FakeLogger())
|
||||||
|
return limit_filter
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimit(unittest.TestCase):
|
||||||
|
|
||||||
|
def _run(self, callable_func, num, rate, extra_sleep=0,
|
||||||
|
total_time=None, check_time=True):
|
||||||
|
begin = time.time()
|
||||||
|
|
||||||
|
for x in range(0, num):
|
||||||
|
result = callable_func()
|
||||||
|
# Extra sleep is here to test with different call intervals.
|
||||||
|
time.sleep(extra_sleep)
|
||||||
|
end = time.time()
|
||||||
|
if total_time is None:
|
||||||
|
total_time = num / rate
|
||||||
|
# Allow for one second of variation in the total time.
|
||||||
|
time_diff = abs(total_time - (end - begin))
|
||||||
|
if check_time:
|
||||||
|
self.assertTrue(time_diff < 1)
|
||||||
|
return time_diff
|
||||||
|
|
||||||
|
def test_get_container_maxrate(self):
|
||||||
|
conf_dict = {'container_limit_10': 200,
|
||||||
|
'container_limit_50': 100,
|
||||||
|
'container_limit_75': 30,}
|
||||||
|
test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
|
||||||
|
|
||||||
|
self.assertEquals(test_ratelimit.get_container_maxrate(0), None)
|
||||||
|
self.assertEquals(test_ratelimit.get_container_maxrate(5), None)
|
||||||
|
self.assertEquals(test_ratelimit.get_container_maxrate(10), 200)
|
||||||
|
self.assertEquals(test_ratelimit.get_container_maxrate(60), 72)
|
||||||
|
self.assertEquals(test_ratelimit.get_container_maxrate(160), 30)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ratelimit(self):
|
||||||
|
current_rate = 13
|
||||||
|
num_calls = 100
|
||||||
|
conf_dict = {'account_ratelimit': current_rate}
|
||||||
|
|
||||||
|
self.test_ratelimit = ratelimit.filter_factory(conf_dict)(FakeApp())
|
||||||
|
|
||||||
|
ratelimit.http_connect = mock_http_connect(204)
|
||||||
|
|
||||||
|
req = Request.blank('/v/a/c')
|
||||||
|
req.environ['swift.cache'] = FakeMemcache()
|
||||||
|
|
||||||
|
make_app_call = lambda: self.test_ratelimit(req.environ, start_response)
|
||||||
|
|
||||||
|
self._run(make_app_call, num_calls, current_rate)
|
||||||
|
|
||||||
|
def test_ratelimit_whitelist(self):
|
||||||
|
current_rate = 2
|
||||||
|
conf_dict = {'account_ratelimit': current_rate,
|
||||||
|
'max_sleep_time_seconds': 2,
|
||||||
|
'account_whitelist': 'a',
|
||||||
|
'account_blacklist': 'b',
|
||||||
|
}
|
||||||
|
|
||||||
|
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
|
||||||
|
ratelimit.http_connect = mock_http_connect(204)
|
||||||
|
req = Request.blank('/v/a/c')
|
||||||
|
req.environ['swift.cache'] = FakeMemcache()
|
||||||
|
|
||||||
|
class rate_caller(Thread):
|
||||||
|
def __init__(self, parent):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.parent = parent
|
||||||
|
def run(self):
|
||||||
|
self.result = self.parent.test_ratelimit(req.environ,
|
||||||
|
start_response)
|
||||||
|
|
||||||
|
nt = 5
|
||||||
|
begin = time.time()
|
||||||
|
threads = []
|
||||||
|
for i in range(nt):
|
||||||
|
rc = rate_caller(self)
|
||||||
|
rc.start()
|
||||||
|
threads.append(rc)
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
the_498s = [t for t in threads if \
|
||||||
|
''.join(t.result).startswith('Slow down')]
|
||||||
|
|
||||||
|
self.assertEquals(len(the_498s), 0)
|
||||||
|
|
||||||
|
time_took = time.time() - begin
|
||||||
|
# the 4th request will happen at 1.5
|
||||||
|
self.assert_(round(time_took, 1) == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ratelimit_blacklist(self):
|
||||||
|
current_rate = 2
|
||||||
|
conf_dict = {'account_ratelimit': current_rate,
|
||||||
|
'max_sleep_time_seconds': 2,
|
||||||
|
'account_whitelist': 'a',
|
||||||
|
'account_blacklist': 'b',
|
||||||
|
}
|
||||||
|
|
||||||
|
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
|
||||||
|
ratelimit.http_connect = mock_http_connect(204)
|
||||||
|
req = Request.blank('/v/b/c')
|
||||||
|
req.environ['swift.cache'] = FakeMemcache()
|
||||||
|
|
||||||
|
class rate_caller(Thread):
|
||||||
|
def __init__(self, parent):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.parent = parent
|
||||||
|
def run(self):
|
||||||
|
self.result = self.parent.test_ratelimit(req.environ,
|
||||||
|
start_response)
|
||||||
|
|
||||||
|
nt = 5
|
||||||
|
begin = time.time()
|
||||||
|
threads = []
|
||||||
|
for i in range(nt):
|
||||||
|
rc = rate_caller(self)
|
||||||
|
rc.start()
|
||||||
|
threads.append(rc)
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
the_497s = [t for t in threads if \
|
||||||
|
''.join(t.result).startswith('Your account')]
|
||||||
|
|
||||||
|
self.assertEquals(len(the_497s), 5)
|
||||||
|
|
||||||
|
time_took = time.time() - begin
|
||||||
|
self.assert_(round(time_took, 1) == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ratelimit_max_rate(self):
|
||||||
|
'''
|
||||||
|
Running 5 threads at rate 2 a sec. and max sleep of 2 seconds
|
||||||
|
Expect threads to be run as follows:
|
||||||
|
t1:0, t2:0, t3:1, t4:1.5, t5:2(Max Rate thrown)
|
||||||
|
'''
|
||||||
|
current_rate = 2
|
||||||
|
conf_dict = {'account_ratelimit': current_rate,
|
||||||
|
'max_sleep_time_seconds': 2}
|
||||||
|
|
||||||
|
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
|
||||||
|
ratelimit.http_connect = mock_http_connect(204)
|
||||||
|
req = Request.blank('/v/a/c')
|
||||||
|
req.environ['swift.cache'] = FakeMemcache()
|
||||||
|
|
||||||
|
class rate_caller(Thread):
|
||||||
|
def __init__(self, parent):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.parent = parent
|
||||||
|
def run(self):
|
||||||
|
self.result = self.parent.test_ratelimit(req.environ,
|
||||||
|
start_response)
|
||||||
|
nt = 5
|
||||||
|
begin = time.time()
|
||||||
|
threads = []
|
||||||
|
for i in range(nt):
|
||||||
|
rc = rate_caller(self)
|
||||||
|
rc.start()
|
||||||
|
threads.append(rc)
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
the_498s = [t for t in threads if \
|
||||||
|
''.join(t.result).startswith('Slow down')]
|
||||||
|
|
||||||
|
self.assertEquals(len(the_498s), 1)
|
||||||
|
time_took = time.time() - begin
|
||||||
|
# the 4th request will happen at 1.5
|
||||||
|
self.assert_(round(time_took, 1) == 1.5)
|
||||||
|
|
||||||
|
def test_ratelimit_max_rate_double(self):
|
||||||
|
current_rate = 2
|
||||||
|
conf_dict = {'account_ratelimit': current_rate,
|
||||||
|
'clock_accuracy': 100,
|
||||||
|
'max_sleep_time_seconds': 4}
|
||||||
|
# making clock less accurate for nosetests running slow
|
||||||
|
|
||||||
|
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
|
||||||
|
ratelimit.http_connect = mock_http_connect(204)
|
||||||
|
req = Request.blank('/v/a/c')
|
||||||
|
req.environ['swift.cache'] = FakeMemcache()
|
||||||
|
begin = time.time()
|
||||||
|
|
||||||
|
class rate_caller(Thread):
|
||||||
|
def __init__(self, parent, name):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.parent = parent
|
||||||
|
self.name = name
|
||||||
|
def run(self):
|
||||||
|
self.result1 = self.parent.test_ratelimit(req.environ,
|
||||||
|
start_response)
|
||||||
|
time.sleep(.1)
|
||||||
|
self.result2 = self.parent.test_ratelimit(req.environ,
|
||||||
|
start_response)
|
||||||
|
nt = 9
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for i in range(nt):
|
||||||
|
rc = rate_caller(self, "thread %s" % i)
|
||||||
|
rc.start()
|
||||||
|
threads.append(rc)
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
all_results = [''.join(t.result1) for t in threads]
|
||||||
|
all_results += [''.join(t.result2) for t in threads]
|
||||||
|
|
||||||
|
the_498s = [t for t in all_results if t.startswith('Slow down')]
|
||||||
|
|
||||||
|
self.assertEquals(len(the_498s), 2)
|
||||||
|
|
||||||
|
time_took = time.time() - begin
|
||||||
|
|
||||||
|
self.assert_(round(time_took, 1) == 7.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ratelimit_max_rate_multiple_acc(self):
|
||||||
|
num_calls = 4
|
||||||
|
current_rate = 2
|
||||||
|
|
||||||
|
conf_dict = {'account_ratelimit': current_rate,
|
||||||
|
'max_sleep_time_seconds': 2}
|
||||||
|
fake_memcache = FakeMemcache()
|
||||||
|
|
||||||
|
the_app = ratelimit.RateLimitMiddleware(None, conf_dict,
|
||||||
|
logger=FakeLogger())
|
||||||
|
the_app.memcache_client = fake_memcache
|
||||||
|
|
||||||
|
class rate_caller(Thread):
|
||||||
|
def __init__(self, name):
|
||||||
|
self.myname = name
|
||||||
|
Thread.__init__(self)
|
||||||
|
def run(self):
|
||||||
|
for j in range(num_calls):
|
||||||
|
self.result = the_app.handle_rate_limit(None, self.myname,
|
||||||
|
None, None)
|
||||||
|
|
||||||
|
nt = 15
|
||||||
|
begin = time.time()
|
||||||
|
threads = []
|
||||||
|
for i in range(nt):
|
||||||
|
rc = rate_caller('a%s' % i)
|
||||||
|
rc.start()
|
||||||
|
threads.append(rc)
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
time_took = time.time() - begin
|
||||||
|
# the all 15 threads still take 1.5 secs
|
||||||
|
self.assert_(round(time_took, 1) == 1.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ratelimit_acc_vrs_container(self):
|
||||||
|
|
||||||
|
conf_dict = {'clock_accuracy': 1000,
|
||||||
|
'account_ratelimit': 10,
|
||||||
|
'max_sleep_time_seconds': 4,
|
||||||
|
'container_limit_10': 6,
|
||||||
|
'container_limit_50': 2,
|
||||||
|
'container_limit_75': 1,}
|
||||||
|
|
||||||
|
self.test_ratelimit = dummy_filter_factory(conf_dict)(FakeApp())
|
||||||
|
ratelimit.http_connect = mock_http_connect(204)
|
||||||
|
req = Request.blank('/v/a/c')
|
||||||
|
req.environ['swift.cache'] = FakeMemcache()
|
||||||
|
|
||||||
|
cont_key = get_container_memcache_key('a','c')
|
||||||
|
|
||||||
|
class rate_caller(Thread):
|
||||||
|
def __init__(self, parent, name):
|
||||||
|
Thread.__init__(self)
|
||||||
|
self.parent = parent
|
||||||
|
self.name = name
|
||||||
|
def run(self):
|
||||||
|
self.result = self.parent.test_ratelimit(req.environ,
|
||||||
|
start_response,
|
||||||
|
name=self.name)
|
||||||
|
|
||||||
|
def runthreads(threads, nt):
|
||||||
|
|
||||||
|
for i in range(nt):
|
||||||
|
rc = rate_caller(self, "thread %s" % i)
|
||||||
|
rc.start()
|
||||||
|
threads.append(rc)
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
begin = time.time()
|
||||||
|
req.environ['swift.cache'].set(cont_key, {'container_size': 20})
|
||||||
|
|
||||||
|
begin = time.time()
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
runthreads(threads,3)
|
||||||
|
|
||||||
|
time_took = time.time() - begin
|
||||||
|
self.assert_(round(time_took, 1) == .4)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -1295,17 +1295,6 @@ class TestObjectController(unittest.TestCase):
|
|||||||
headers = readuntil2crlfs(fd)
|
headers = readuntil2crlfs(fd)
|
||||||
exp = 'HTTP/1.1 404'
|
exp = 'HTTP/1.1 404'
|
||||||
self.assertEquals(headers[:len(exp)], exp)
|
self.assertEquals(headers[:len(exp)], exp)
|
||||||
# Check blacklist
|
|
||||||
prosrv.rate_limit_blacklist = ['a']
|
|
||||||
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
|
||||||
fd = sock.makefile()
|
|
||||||
fd.write('GET /v1/a HTTP/1.1\r\nHost: localhost\r\n'
|
|
||||||
'Connection: close\r\nContent-Length: 0\r\n\r\n')
|
|
||||||
fd.flush()
|
|
||||||
headers = readuntil2crlfs(fd)
|
|
||||||
exp = 'HTTP/1.1 497'
|
|
||||||
self.assertEquals(headers[:len(exp)], exp)
|
|
||||||
prosrv.rate_limit_blacklist = []
|
|
||||||
# Check invalid utf-8
|
# Check invalid utf-8
|
||||||
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
||||||
fd = sock.makefile()
|
fd = sock.makefile()
|
||||||
@ -1326,31 +1315,6 @@ class TestObjectController(unittest.TestCase):
|
|||||||
headers = readuntil2crlfs(fd)
|
headers = readuntil2crlfs(fd)
|
||||||
exp = 'HTTP/1.1 412'
|
exp = 'HTTP/1.1 412'
|
||||||
self.assertEquals(headers[:len(exp)], exp)
|
self.assertEquals(headers[:len(exp)], exp)
|
||||||
# Check rate limiting
|
|
||||||
orig_rate_limit = prosrv.rate_limit
|
|
||||||
prosrv.rate_limit = 0
|
|
||||||
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
|
||||||
fd = sock.makefile()
|
|
||||||
fd.write('GET /v1/a HTTP/1.1\r\nHost: localhost\r\n'
|
|
||||||
'Connection: close\r\nX-Auth-Token: t\r\n'
|
|
||||||
'Content-Length: 0\r\n\r\n')
|
|
||||||
fd.flush()
|
|
||||||
headers = readuntil2crlfs(fd)
|
|
||||||
exp = 'HTTP/1.1 498'
|
|
||||||
self.assertEquals(headers[:len(exp)], exp)
|
|
||||||
prosrv.rate_limit = orig_rate_limit
|
|
||||||
orig_rate_limit = prosrv.account_rate_limit
|
|
||||||
prosrv.account_rate_limit = 0
|
|
||||||
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
|
||||||
fd = sock.makefile()
|
|
||||||
fd.write('PUT /v1/a/c HTTP/1.1\r\nHost: localhost\r\n'
|
|
||||||
'Connection: close\r\nX-Auth-Token: t\r\n'
|
|
||||||
'Content-Length: 0\r\n\r\n')
|
|
||||||
fd.flush()
|
|
||||||
headers = readuntil2crlfs(fd)
|
|
||||||
exp = 'HTTP/1.1 498'
|
|
||||||
self.assertEquals(headers[:len(exp)], exp)
|
|
||||||
prosrv.account_rate_limit = orig_rate_limit
|
|
||||||
# Check bad method
|
# Check bad method
|
||||||
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
||||||
fd = sock.makefile()
|
fd = sock.makefile()
|
||||||
@ -1362,8 +1326,8 @@ class TestObjectController(unittest.TestCase):
|
|||||||
exp = 'HTTP/1.1 405'
|
exp = 'HTTP/1.1 405'
|
||||||
self.assertEquals(headers[:len(exp)], exp)
|
self.assertEquals(headers[:len(exp)], exp)
|
||||||
# Check unhandled exception
|
# Check unhandled exception
|
||||||
orig_rate_limit = prosrv.rate_limit
|
orig_logger = prosrv.logger
|
||||||
del prosrv.rate_limit
|
del prosrv.logger
|
||||||
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
sock = connect_tcp(('localhost', prolis.getsockname()[1]))
|
||||||
fd = sock.makefile()
|
fd = sock.makefile()
|
||||||
fd.write('HEAD /v1/a HTTP/1.1\r\nHost: localhost\r\n'
|
fd.write('HEAD /v1/a HTTP/1.1\r\nHost: localhost\r\n'
|
||||||
@ -1373,7 +1337,7 @@ class TestObjectController(unittest.TestCase):
|
|||||||
headers = readuntil2crlfs(fd)
|
headers = readuntil2crlfs(fd)
|
||||||
exp = 'HTTP/1.1 500'
|
exp = 'HTTP/1.1 500'
|
||||||
self.assertEquals(headers[:len(exp)], exp)
|
self.assertEquals(headers[:len(exp)], exp)
|
||||||
prosrv.rate_limit = orig_rate_limit
|
prosrv.logger = orig_logger
|
||||||
# Okay, back to chunked put testing; Create account
|
# Okay, back to chunked put testing; Create account
|
||||||
ts = normalize_timestamp(time())
|
ts = normalize_timestamp(time())
|
||||||
partition, nodes = prosrv.account_ring.get_nodes('a')
|
partition, nodes = prosrv.account_ring.get_nodes('a')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user