Add rate limiting middleware example #11969
Open
rodrigobnogueira
wants to merge
5
commits into
aio-libs:master
Choose a base branch
from
rodrigobnogueira:add-rate-limit-middleware-example
base: master
+239 −0
Changes from all commits (5 commits):
- 1e98b2f  Add rate limiting middleware example
- 8c0f0c1  Add changelog fragment for PR #11969
- 84696ad  Add myself to CONTRIBUTORS.txt
- a5e2364  Add explanatory comment for empty except block
- 9f7a307  Fix token bucket logic to correctly consume tokens after waiting
Changelog fragment:

    Added a rate limiting middleware example demonstrating how to limit request
    rate using the token bucket algorithm with support for per-domain rate limits
    and ``Retry-After`` header handling -- by :user:`rodrigobnogueira`.
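For orientation before the full file, here is a minimal usage sketch showing how such a middleware would be attached to a `ClientSession`. It is not part of the PR: it assumes the `RateLimitMiddleware` class from the example below is importable, and the module name, URL, and rate/burst values are placeholders.

```python
# Minimal usage sketch (not part of the PR).
# from rate_limit_middleware import RateLimitMiddleware  # hypothetical module name
import asyncio

from aiohttp import ClientSession


async def fetch_politely() -> None:
    # Allow roughly 2 requests/second per domain, with a burst of 1.
    limiter = RateLimitMiddleware(rate=2.0, burst=1, per_domain=True)
    async with ClientSession(middlewares=(limiter,)) as session:
        async with session.get("https://example.com/") as resp:
            print(resp.status)


# asyncio.run(fetch_politely())
```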
New example file (235 lines):

```python
#!/usr/bin/env python3
"""
Example of using rate limiting middleware with aiohttp client.

This example shows how to implement a middleware that limits request rate
to avoid overwhelming servers or hitting API rate limits. The implementation
uses a token bucket algorithm that allows for burst traffic while maintaining
an average rate limit.

Features:
- Token bucket rate limiting with configurable rate and burst size
- Per-domain rate limiting for multi-host scenarios
- Automatic Retry-After header handling
- Support for both global and per-domain limits
"""

import asyncio
import logging
import time
from collections import defaultdict
from http import HTTPStatus

from aiohttp import ClientHandlerType, ClientRequest, ClientResponse, ClientSession, web

logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)


class TokenBucket:
    """Token bucket rate limiter implementation."""

    def __init__(self, rate: float, burst: int) -> None:
        self.rate = rate
        self.burst = burst
        self.tokens = float(burst)
        self.last_refill = time.monotonic()
        self._lock = asyncio.Lock()

    async def acquire(self) -> None:
        """Acquire a token, waiting if necessary."""
        while True:
            async with self._lock:
                now = time.monotonic()
                self._refill(now)
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
                wait_time = (1 - self.tokens) / self.rate

            await asyncio.sleep(wait_time)

    def _refill(self, now: float) -> None:
        elapsed = now - self.last_refill
        self.tokens = min(self.burst, self.tokens + elapsed * self.rate)
        self.last_refill = now


class RateLimitMiddleware:
    """Middleware that rate limits requests using token bucket algorithm."""

    def __init__(
        self,
        rate: float = 10.0,
        burst: int = 10,
        per_domain: bool = False,
        respect_retry_after: bool = True,
    ) -> None:
        self.rate = rate
        self.burst = burst
        self.per_domain = per_domain
        self.respect_retry_after = respect_retry_after
        self._global_bucket = TokenBucket(rate, burst)
        self._domain_buckets: dict[str, TokenBucket] = defaultdict(
            lambda: TokenBucket(rate, burst)
        )

    def _get_bucket(self, request: ClientRequest) -> TokenBucket:
        if self.per_domain:
            domain = request.url.host or "unknown"
            return self._domain_buckets[domain]
        return self._global_bucket

    async def _handle_retry_after(self, response: ClientResponse) -> None:
        if response.status != HTTPStatus.TOO_MANY_REQUESTS:
            return
        retry_after = response.headers.get("Retry-After")
        if retry_after:
            try:
                wait_seconds = float(retry_after)
                _LOGGER.info("Server requested Retry-After: %ss", wait_seconds)
                await asyncio.sleep(wait_seconds)
            except ValueError:
                pass  # Retry-After may be an HTTP-date; ignore if not a number

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Execute request with rate limiting."""
        bucket = self._get_bucket(request)
        await bucket.acquire()

        response = await handler(request)

        if self.respect_retry_after:
            await self._handle_retry_after(response)

        return response


class TestServer:
    """Test server that simulates rate limiting."""

    def __init__(self) -> None:
        self.request_times: list[float] = []
        self.rate_limit_counter = 0

    async def handle_api(self, request: web.Request) -> web.Response:
        """Normal API endpoint that tracks request timing."""
        self.request_times.append(time.monotonic())
        return web.json_response(
            {
                "message": "Success",
                "request_count": len(self.request_times),
            }
        )

    async def handle_rate_limited(self, request: web.Request) -> web.Response:
        """Endpoint simulating server-side rate limiting."""
        self.rate_limit_counter += 1
        if self.rate_limit_counter <= 2:
            return web.Response(
                status=429,
                text="Too Many Requests",
                headers={"Retry-After": "1"},
            )
        return web.json_response({"message": "Rate limit cleared"})

    async def handle_stats(self, request: web.Request) -> web.Response:
        """Return request timing statistics."""
        if len(self.request_times) < 2:
            return web.json_response({"intervals": [], "average_rate": 0})
        intervals = [
            self.request_times[i] - self.request_times[i - 1]
            for i in range(1, len(self.request_times))
        ]
        avg_rate = 1.0 / (sum(intervals) / len(intervals)) if intervals else 0
        return web.json_response(
            {
                "intervals": [round(i, 3) for i in intervals],
                "average_rate": round(avg_rate, 2),
            }
        )

    async def handle_reset(self, request: web.Request) -> web.Response:
        """Reset server state."""
        self.request_times = []
        self.rate_limit_counter = 0
        return web.Response(text="Reset")


async def run_test_server() -> web.AppRunner:
    """Run a test server with rate limiting simulation."""
    app = web.Application()
    server = TestServer()

    app.router.add_get("/api", server.handle_api)
    app.router.add_get("/rate-limited", server.handle_rate_limited)
    app.router.add_get("/stats", server.handle_stats)
    app.router.add_post("/reset", server.handle_reset)

    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, "localhost", 8080)
    await site.start()
    return runner


async def run_tests() -> None:
    """Run rate limit middleware tests."""
    rate_limit = RateLimitMiddleware(rate=5.0, burst=2, per_domain=False)

    async with ClientSession(middlewares=(rate_limit,)) as session:
        await session.post("http://localhost:8080/reset")

        print("=== Test 1: Burst requests (limit: 5/s, burst: 2) ===")
        print("Sending 5 requests rapidly...")
        start = time.monotonic()

        for i in range(5):
            async with session.get("http://localhost:8080/api") as resp:
                data = await resp.json()
                elapsed = time.monotonic() - start
                print(f"Request {i + 1}: {elapsed:.2f}s - {data['message']}")

        print("\n=== Test 2: Check actual request rate ===")
        async with session.get("http://localhost:8080/stats") as resp:
            stats = await resp.json()
            print(f"Request intervals: {stats['intervals']}")
            print(f"Average rate: {stats['average_rate']} req/s")

        print("\n=== Test 3: Server-side 429 with Retry-After ===")
        await session.post("http://localhost:8080/reset")
        for i in range(3):
            async with session.get("http://localhost:8080/rate-limited") as resp:
                text = await resp.text() if resp.status == 429 else (await resp.json())
                print(f"Request {i + 1}: Status {resp.status} - {text}")

    print("\n=== Test 4: Per-domain rate limiting ===")
    per_domain_limit = RateLimitMiddleware(rate=10.0, burst=1, per_domain=True)

    async with ClientSession(middlewares=(per_domain_limit,)) as session:
        await session.post("http://localhost:8080/reset")
        print("Simulating requests to different 'domains' (same server)...")
        print("(In real usage, different domains get separate rate limits)")

        start = time.monotonic()
        for i in range(3):
            async with session.get("http://localhost:8080/api") as resp:
                elapsed = time.monotonic() - start
                print(f"Request {i + 1} to localhost: {elapsed:.2f}s")


async def main() -> None:
    server = await run_test_server()

    try:
        await run_tests()
    finally:
        await server.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
```
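The example's `_handle_retry_after` deliberately ignores `Retry-After` values that are not plain numbers (see the `except ValueError` comment). If HTTP-date values were to be honoured as well, one possible approach, not part of the PR and using only the standard library, is a small helper like the following sketch:

```python
# Hedged sketch: converts a Retry-After header value (delta-seconds or HTTP-date)
# into a non-negative delay in seconds. Not part of the PR.
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime


def retry_after_seconds(value: str) -> float:
    """Return how long to wait, in seconds, for a Retry-After header value."""
    try:
        return max(0.0, float(value))  # delta-seconds form, e.g. "120"
    except ValueError:
        pass
    try:
        when = parsedate_to_datetime(value)  # HTTP-date form, e.g. "Wed, 21 Oct 2026 07:28:00 GMT"
    except (TypeError, ValueError):
        return 0.0  # unparseable value: treat as "no delay"
    if when.tzinfo is None:
        when = when.replace(tzinfo=timezone.utc)
    return max(0.0, (when - datetime.now(timezone.utc)).total_seconds())
```

The result could be fed to `asyncio.sleep()` in place of the `float(retry_after)` branch, keeping the rest of the middleware unchanged.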