From 89bf97488edebc14dbc4654a6ba0008d4c5a5158 Mon Sep 17 00:00:00 2001 From: KM Koushik Date: Mon, 19 May 2025 21:55:10 +1000 Subject: [PATCH] Implement per-team API rate limiting with Redis (#164) --- .cursor/rules/general.mdc | 1 + .../migration.sql | 2 + apps/web/prisma/schema.prisma | 1 + .../public-api/api/contacts/add-contact.ts | 2 +- .../public-api/api/contacts/delete-contact.ts | 2 +- .../public-api/api/contacts/get-contact.ts | 2 +- .../public-api/api/contacts/get-contacts.ts | 2 +- .../public-api/api/contacts/update-contact.ts | 2 +- .../public-api/api/contacts/upsert-contact.ts | 2 +- .../public-api/api/domains/create-domain.ts | 2 +- .../public-api/api/domains/get-domains.ts | 2 +- .../public-api/api/domains/verify-domain.ts | 2 - .../public-api/api/emails/batch-email.ts | 2 +- .../public-api/api/emails/cancel-email.ts | 2 +- .../server/public-api/api/emails/get-email.ts | 2 +- .../public-api/api/emails/send-email.ts | 2 +- .../public-api/api/emails/update-email.ts | 2 +- apps/web/src/server/public-api/auth.ts | 25 +--- apps/web/src/server/public-api/hono.ts | 109 +++++++++++++++++- 19 files changed, 125 insertions(+), 41 deletions(-) create mode 100644 apps/web/prisma/migrations/20250517053539_add_team_api_rate_limit/migration.sql diff --git a/.cursor/rules/general.mdc b/.cursor/rules/general.mdc index 8162f5d..98adbbe 100644 --- a/.cursor/rules/general.mdc +++ b/.cursor/rules/general.mdc @@ -16,3 +16,4 @@ You are a Staff Engineer and an Expert in ReactJS, NextJS, JavaScript, TypeScrip - Be concise Minimize any other prose. - If you think there might not be a correct answer, you say so. - If you do not know the answer, say so, instead of guessing. +- never, install any libraries without asking diff --git a/apps/web/prisma/migrations/20250517053539_add_team_api_rate_limit/migration.sql b/apps/web/prisma/migrations/20250517053539_add_team_api_rate_limit/migration.sql new file mode 100644 index 0000000..e2c8e12 --- /dev/null +++ b/apps/web/prisma/migrations/20250517053539_add_team_api_rate_limit/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "Team" ADD COLUMN "apiRateLimit" INTEGER NOT NULL DEFAULT 2; diff --git a/apps/web/prisma/schema.prisma b/apps/web/prisma/schema.prisma index 0e814bf..f53f6e1 100644 --- a/apps/web/prisma/schema.prisma +++ b/apps/web/prisma/schema.prisma @@ -104,6 +104,7 @@ model Team { plan Plan @default(FREE) stripeCustomerId String? @unique isActive Boolean @default(true) + apiRateLimit Int @default(2) billingEmail String? teamUsers TeamUser[] domains Domain[] diff --git a/apps/web/src/server/public-api/api/contacts/add-contact.ts b/apps/web/src/server/public-api/api/contacts/add-contact.ts index 58e9448..17e39d9 100644 --- a/apps/web/src/server/public-api/api/contacts/add-contact.ts +++ b/apps/web/src/server/public-api/api/contacts/add-contact.ts @@ -49,7 +49,7 @@ const route = createRoute({ function addContact(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const contactBook = await getContactBook(c, team.id); diff --git a/apps/web/src/server/public-api/api/contacts/delete-contact.ts b/apps/web/src/server/public-api/api/contacts/delete-contact.ts index 21b0117..4e54f43 100644 --- a/apps/web/src/server/public-api/api/contacts/delete-contact.ts +++ b/apps/web/src/server/public-api/api/contacts/delete-contact.ts @@ -39,7 +39,7 @@ const route = createRoute({ function deleteContactHandler(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; await getContactBook(c, team.id); const contactId = c.req.param("contactId"); diff --git a/apps/web/src/server/public-api/api/contacts/get-contact.ts b/apps/web/src/server/public-api/api/contacts/get-contact.ts index 615a2ba..090816f 100644 --- a/apps/web/src/server/public-api/api/contacts/get-contact.ts +++ b/apps/web/src/server/public-api/api/contacts/get-contact.ts @@ -50,7 +50,7 @@ const route = createRoute({ function getContact(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; await getContactBook(c, team.id); diff --git a/apps/web/src/server/public-api/api/contacts/get-contacts.ts b/apps/web/src/server/public-api/api/contacts/get-contacts.ts index 672b6eb..0d6589f 100644 --- a/apps/web/src/server/public-api/api/contacts/get-contacts.ts +++ b/apps/web/src/server/public-api/api/contacts/get-contacts.ts @@ -51,7 +51,7 @@ const route = createRoute({ function getContacts(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const cb = await getContactBook(c, team.id); diff --git a/apps/web/src/server/public-api/api/contacts/update-contact.ts b/apps/web/src/server/public-api/api/contacts/update-contact.ts index 4cfceec..d1846bf 100644 --- a/apps/web/src/server/public-api/api/contacts/update-contact.ts +++ b/apps/web/src/server/public-api/api/contacts/update-contact.ts @@ -52,7 +52,7 @@ const route = createRoute({ function updateContactInfo(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; await getContactBook(c, team.id); const contactId = c.req.param("contactId"); diff --git a/apps/web/src/server/public-api/api/contacts/upsert-contact.ts b/apps/web/src/server/public-api/api/contacts/upsert-contact.ts index eef2252..958902c 100644 --- a/apps/web/src/server/public-api/api/contacts/upsert-contact.ts +++ b/apps/web/src/server/public-api/api/contacts/upsert-contact.ts @@ -49,7 +49,7 @@ const route = createRoute({ function upsertContact(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const contactBook = await getContactBook(c, team.id); diff --git a/apps/web/src/server/public-api/api/domains/create-domain.ts b/apps/web/src/server/public-api/api/domains/create-domain.ts index 1c5d5d6..6d9b8db 100644 --- a/apps/web/src/server/public-api/api/domains/create-domain.ts +++ b/apps/web/src/server/public-api/api/domains/create-domain.ts @@ -34,7 +34,7 @@ const route = createRoute({ function createDomain(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const body = c.req.valid("json"); const response = await createDomainService(team.id, body.name, body.region); diff --git a/apps/web/src/server/public-api/api/domains/get-domains.ts b/apps/web/src/server/public-api/api/domains/get-domains.ts index 629527f..ccb77c1 100644 --- a/apps/web/src/server/public-api/api/domains/get-domains.ts +++ b/apps/web/src/server/public-api/api/domains/get-domains.ts @@ -21,7 +21,7 @@ const route = createRoute({ function getDomains(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const domains = await db.domain.findMany({ where: { teamId: team.id } }); diff --git a/apps/web/src/server/public-api/api/domains/verify-domain.ts b/apps/web/src/server/public-api/api/domains/verify-domain.ts index 08a83c3..f6f1ac5 100644 --- a/apps/web/src/server/public-api/api/domains/verify-domain.ts +++ b/apps/web/src/server/public-api/api/domains/verify-domain.ts @@ -33,8 +33,6 @@ const route = createRoute({ function verifyDomain(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); - await db.domain.update({ where: { id: c.req.valid("param").id }, data: { isVerifying: true }, diff --git a/apps/web/src/server/public-api/api/emails/batch-email.ts b/apps/web/src/server/public-api/api/emails/batch-email.ts index 237a0ce..422c274 100644 --- a/apps/web/src/server/public-api/api/emails/batch-email.ts +++ b/apps/web/src/server/public-api/api/emails/batch-email.ts @@ -44,7 +44,7 @@ const route = createRoute({ function sendBatch(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const emailPayloads = c.req.valid("json"); // Add teamId and apiKeyId to each email payload diff --git a/apps/web/src/server/public-api/api/emails/cancel-email.ts b/apps/web/src/server/public-api/api/emails/cancel-email.ts index 7573144..8761e0c 100644 --- a/apps/web/src/server/public-api/api/emails/cancel-email.ts +++ b/apps/web/src/server/public-api/api/emails/cancel-email.ts @@ -35,7 +35,7 @@ const route = createRoute({ function cancelScheduledEmail(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const emailId = c.req.param("emailId"); await checkIsValidEmailId(emailId, team.id); diff --git a/apps/web/src/server/public-api/api/emails/get-email.ts b/apps/web/src/server/public-api/api/emails/get-email.ts index 64fc8ad..548c4eb 100644 --- a/apps/web/src/server/public-api/api/emails/get-email.ts +++ b/apps/web/src/server/public-api/api/emails/get-email.ts @@ -57,7 +57,7 @@ const route = createRoute({ function send(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const emailId = c.req.param("emailId"); diff --git a/apps/web/src/server/public-api/api/emails/send-email.ts b/apps/web/src/server/public-api/api/emails/send-email.ts index 6a87481..d4952e0 100644 --- a/apps/web/src/server/public-api/api/emails/send-email.ts +++ b/apps/web/src/server/public-api/api/emails/send-email.ts @@ -31,7 +31,7 @@ const route = createRoute({ function send(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; let html = undefined; diff --git a/apps/web/src/server/public-api/api/emails/update-email.ts b/apps/web/src/server/public-api/api/emails/update-email.ts index 3f4e6cd..8aafe0f 100644 --- a/apps/web/src/server/public-api/api/emails/update-email.ts +++ b/apps/web/src/server/public-api/api/emails/update-email.ts @@ -45,7 +45,7 @@ const route = createRoute({ function updateEmailScheduledAt(app: PublicAPIApp) { app.openapi(route, async (c) => { - const team = await getTeamFromToken(c); + const team = c.var.team; const emailId = c.req.param("emailId"); await checkIsValidEmailId(emailId, team.id); diff --git a/apps/web/src/server/public-api/auth.ts b/apps/web/src/server/public-api/auth.ts index e86d7e2..dc581f2 100644 --- a/apps/web/src/server/public-api/auth.ts +++ b/apps/web/src/server/public-api/auth.ts @@ -1,14 +1,8 @@ -import TTLCache from "@isaacs/ttlcache"; import { Context } from "hono"; import { db } from "../db"; import { UnsendApiError } from "./api-error"; -import { env } from "~/env"; import { getTeamAndApiKey } from "../service/api-service"; - -const rateLimitCache = new TTLCache({ - ttl: 1000, // 1 second - max: 10000, -}); +import { isSelfHosted } from "~/utils/common"; /** * Gets the team from the token. Also will check if the token is valid. @@ -32,8 +26,6 @@ export const getTeamFromToken = async (c: Context) => { }); } - checkRateLimit(token); - const teamAndApiKey = await getTeamAndApiKey(token); if (!teamAndApiKey) { @@ -66,18 +58,3 @@ export const getTeamFromToken = async (c: Context) => { return { ...team, apiKeyId: apiKey.id }; }; - -const checkRateLimit = (token: string) => { - let rateLimit = rateLimitCache.get(token); - - rateLimit = rateLimit ?? 0; - - if (rateLimit >= env.API_RATE_LIMIT) { - throw new UnsendApiError({ - code: "RATE_LIMITED", - message: `Rate limit exceeded, ${env.API_RATE_LIMIT} requests per second`, - }); - } - - rateLimitCache.set(token, rateLimit + 1); -}; diff --git a/apps/web/src/server/public-api/hono.ts b/apps/web/src/server/public-api/hono.ts index b7905ea..5a6186b 100644 --- a/apps/web/src/server/public-api/hono.ts +++ b/apps/web/src/server/public-api/hono.ts @@ -1,13 +1,118 @@ import { OpenAPIHono } from "@hono/zod-openapi"; import { swaggerUI } from "@hono/swagger-ui"; +import { Context, Next } from "hono"; import { handleError } from "./api-error"; import { env } from "~/env"; +import { getRedis } from "~/server/redis"; +import { getTeamFromToken } from "~/server/public-api/auth"; +import { isSelfHosted } from "~/utils/common"; +import { UnsendApiError } from "./api-error"; +import { Team } from "@prisma/client"; + +// Define AppEnv for Hono context +export type AppEnv = { + Variables: { + team: Team & { apiKeyId: number }; + }; +}; export function getApp() { - const app = new OpenAPIHono().basePath("/api"); + const app = new OpenAPIHono().basePath("/api"); app.onError(handleError); + // Auth and Team Middleware (runs before rate limiter) + app.use("*", async (c: Context, next: Next) => { + if ( + c.req.path.startsWith("/api/v1/doc") || + c.req.path.startsWith("/api/v1/ui") || + c.req.path === "/api/health" + ) { + return next(); + } + + try { + const team = await getTeamFromToken(c as any); + c.set("team", team); + } catch (error) { + if (error instanceof UnsendApiError) { + throw error; + } + console.error("Error in getTeamFromToken middleware:", error); + throw new UnsendApiError({ + code: "INTERNAL_SERVER_ERROR", + message: "Authentication failed", + }); + } + await next(); + }); + + // Custom Rate Limiter Middleware + const RATE_LIMIT_WINDOW_SECONDS = 1; + + app.use("*", async (c: Context, next: Next) => { + // Skip for self-hosted, or if team is not set (e.g. for public/doc paths not caught earlier) + // or if the path is one of the explicitly skipped paths for auth. + if ( + isSelfHosted() || + !c.var.team || // Team should be set by auth middleware for protected routes + c.req.path.startsWith("/api/v1/doc") || + c.req.path.startsWith("/api/v1/ui") || + c.req.path === "/api/health" + ) { + return next(); + } + + const team = c.var.team; + const limit = team.apiRateLimit ?? 2; // Default limit from your previous setup + const key = `rl:${team.id}`; // Rate limit key for Redis + const redis = getRedis(); + + let currentRequests: number; + let ttl: number; + let isNewKey = false; + + try { + // Increment the key. If the key does not exist, it is created and set to 1. + currentRequests = await redis.incr(key); + + if (currentRequests === 1) { + // This is the first request in the window, set the expiry. + await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS); + } + // Get the TTL (time to live) of the key to know when it resets. + // If the key does not exist or has no expiry, TTL returns -1 or -2. + // We rely on expire being set for new keys. + ttl = await redis.ttl(key); + } catch (error) { + console.error("Redis error during rate limiting:", error); + // Alternatively, you could fail closed by throwing an error here. + return next(); + } + + const resetTime = + Math.floor(Date.now() / 1000) + + (ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS); + const remainingRequests = Math.max(0, limit - currentRequests); + + c.res.headers.set("X-RateLimit-Limit", String(limit)); + c.res.headers.set("X-RateLimit-Remaining", String(remainingRequests)); + c.res.headers.set("X-RateLimit-Reset", String(resetTime)); + + if (currentRequests > limit) { + c.res.headers.set( + "Retry-After", + String(ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS) + ); + throw new UnsendApiError({ + code: "RATE_LIMITED", + message: `Rate limit exceeded. Try again in ${ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS} seconds.`, + }); + } + + await next(); + }); + // The OpenAPI documentation will be available at /doc app.doc("/v1/doc", (c) => ({ openapi: "3.0.0", @@ -28,4 +133,4 @@ export function getApp() { return app; } -export type PublicAPIApp = ReturnType; +export type PublicAPIApp = OpenAPIHono;