Implement per-team API rate limiting with Redis (#164)

This commit is contained in:
KM Koushik
2025-05-19 21:55:10 +10:00
committed by GitHub
parent 14557a96ac
commit 89bf97488e
19 changed files with 125 additions and 41 deletions

View File

@@ -16,3 +16,4 @@ You are a Staff Engineer and an Expert in ReactJS, NextJS, JavaScript, TypeScrip
- Be concise Minimize any other prose. - Be concise Minimize any other prose.
- If you think there might not be a correct answer, you say so. - If you think there might not be a correct answer, you say so.
- If you do not know the answer, say so, instead of guessing. - If you do not know the answer, say so, instead of guessing.
- never, install any libraries without asking

View File

@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "Team" ADD COLUMN "apiRateLimit" INTEGER NOT NULL DEFAULT 2;

View File

@@ -104,6 +104,7 @@ model Team {
plan Plan @default(FREE) plan Plan @default(FREE)
stripeCustomerId String? @unique stripeCustomerId String? @unique
isActive Boolean @default(true) isActive Boolean @default(true)
apiRateLimit Int @default(2)
billingEmail String? billingEmail String?
teamUsers TeamUser[] teamUsers TeamUser[]
domains Domain[] domains Domain[]

View File

@@ -49,7 +49,7 @@ const route = createRoute({
function addContact(app: PublicAPIApp) { function addContact(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const contactBook = await getContactBook(c, team.id); const contactBook = await getContactBook(c, team.id);

View File

@@ -39,7 +39,7 @@ const route = createRoute({
function deleteContactHandler(app: PublicAPIApp) { function deleteContactHandler(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
await getContactBook(c, team.id); await getContactBook(c, team.id);
const contactId = c.req.param("contactId"); const contactId = c.req.param("contactId");

View File

@@ -50,7 +50,7 @@ const route = createRoute({
function getContact(app: PublicAPIApp) { function getContact(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
await getContactBook(c, team.id); await getContactBook(c, team.id);

View File

@@ -51,7 +51,7 @@ const route = createRoute({
function getContacts(app: PublicAPIApp) { function getContacts(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const cb = await getContactBook(c, team.id); const cb = await getContactBook(c, team.id);

View File

@@ -52,7 +52,7 @@ const route = createRoute({
function updateContactInfo(app: PublicAPIApp) { function updateContactInfo(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
await getContactBook(c, team.id); await getContactBook(c, team.id);
const contactId = c.req.param("contactId"); const contactId = c.req.param("contactId");

View File

@@ -49,7 +49,7 @@ const route = createRoute({
function upsertContact(app: PublicAPIApp) { function upsertContact(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const contactBook = await getContactBook(c, team.id); const contactBook = await getContactBook(c, team.id);

View File

@@ -34,7 +34,7 @@ const route = createRoute({
function createDomain(app: PublicAPIApp) { function createDomain(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const body = c.req.valid("json"); const body = c.req.valid("json");
const response = await createDomainService(team.id, body.name, body.region); const response = await createDomainService(team.id, body.name, body.region);

View File

@@ -21,7 +21,7 @@ const route = createRoute({
function getDomains(app: PublicAPIApp) { function getDomains(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const domains = await db.domain.findMany({ where: { teamId: team.id } }); const domains = await db.domain.findMany({ where: { teamId: team.id } });

View File

@@ -33,8 +33,6 @@ const route = createRoute({
function verifyDomain(app: PublicAPIApp) { function verifyDomain(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c);
await db.domain.update({ await db.domain.update({
where: { id: c.req.valid("param").id }, where: { id: c.req.valid("param").id },
data: { isVerifying: true }, data: { isVerifying: true },

View File

@@ -44,7 +44,7 @@ const route = createRoute({
function sendBatch(app: PublicAPIApp) { function sendBatch(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const emailPayloads = c.req.valid("json"); const emailPayloads = c.req.valid("json");
// Add teamId and apiKeyId to each email payload // Add teamId and apiKeyId to each email payload

View File

@@ -35,7 +35,7 @@ const route = createRoute({
function cancelScheduledEmail(app: PublicAPIApp) { function cancelScheduledEmail(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const emailId = c.req.param("emailId"); const emailId = c.req.param("emailId");
await checkIsValidEmailId(emailId, team.id); await checkIsValidEmailId(emailId, team.id);

View File

@@ -57,7 +57,7 @@ const route = createRoute({
function send(app: PublicAPIApp) { function send(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const emailId = c.req.param("emailId"); const emailId = c.req.param("emailId");

View File

@@ -31,7 +31,7 @@ const route = createRoute({
function send(app: PublicAPIApp) { function send(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
let html = undefined; let html = undefined;

View File

@@ -45,7 +45,7 @@ const route = createRoute({
function updateEmailScheduledAt(app: PublicAPIApp) { function updateEmailScheduledAt(app: PublicAPIApp) {
app.openapi(route, async (c) => { app.openapi(route, async (c) => {
const team = await getTeamFromToken(c); const team = c.var.team;
const emailId = c.req.param("emailId"); const emailId = c.req.param("emailId");
await checkIsValidEmailId(emailId, team.id); await checkIsValidEmailId(emailId, team.id);

View File

@@ -1,14 +1,8 @@
import TTLCache from "@isaacs/ttlcache";
import { Context } from "hono"; import { Context } from "hono";
import { db } from "../db"; import { db } from "../db";
import { UnsendApiError } from "./api-error"; import { UnsendApiError } from "./api-error";
import { env } from "~/env";
import { getTeamAndApiKey } from "../service/api-service"; import { getTeamAndApiKey } from "../service/api-service";
import { isSelfHosted } from "~/utils/common";
const rateLimitCache = new TTLCache({
ttl: 1000, // 1 second
max: 10000,
});
/** /**
* Gets the team from the token. Also will check if the token is valid. * Gets the team from the token. Also will check if the token is valid.
@@ -32,8 +26,6 @@ export const getTeamFromToken = async (c: Context) => {
}); });
} }
checkRateLimit(token);
const teamAndApiKey = await getTeamAndApiKey(token); const teamAndApiKey = await getTeamAndApiKey(token);
if (!teamAndApiKey) { if (!teamAndApiKey) {
@@ -66,18 +58,3 @@ export const getTeamFromToken = async (c: Context) => {
return { ...team, apiKeyId: apiKey.id }; return { ...team, apiKeyId: apiKey.id };
}; };
const checkRateLimit = (token: string) => {
let rateLimit = rateLimitCache.get<number>(token);
rateLimit = rateLimit ?? 0;
if (rateLimit >= env.API_RATE_LIMIT) {
throw new UnsendApiError({
code: "RATE_LIMITED",
message: `Rate limit exceeded, ${env.API_RATE_LIMIT} requests per second`,
});
}
rateLimitCache.set(token, rateLimit + 1);
};

View File

@@ -1,13 +1,118 @@
import { OpenAPIHono } from "@hono/zod-openapi"; import { OpenAPIHono } from "@hono/zod-openapi";
import { swaggerUI } from "@hono/swagger-ui"; import { swaggerUI } from "@hono/swagger-ui";
import { Context, Next } from "hono";
import { handleError } from "./api-error"; import { handleError } from "./api-error";
import { env } from "~/env"; import { env } from "~/env";
import { getRedis } from "~/server/redis";
import { getTeamFromToken } from "~/server/public-api/auth";
import { isSelfHosted } from "~/utils/common";
import { UnsendApiError } from "./api-error";
import { Team } from "@prisma/client";
// Define AppEnv for Hono context
export type AppEnv = {
Variables: {
team: Team & { apiKeyId: number };
};
};
export function getApp() { export function getApp() {
const app = new OpenAPIHono().basePath("/api"); const app = new OpenAPIHono<AppEnv>().basePath("/api");
app.onError(handleError); app.onError(handleError);
// Auth and Team Middleware (runs before rate limiter)
app.use("*", async (c: Context<AppEnv>, next: Next) => {
if (
c.req.path.startsWith("/api/v1/doc") ||
c.req.path.startsWith("/api/v1/ui") ||
c.req.path === "/api/health"
) {
return next();
}
try {
const team = await getTeamFromToken(c as any);
c.set("team", team);
} catch (error) {
if (error instanceof UnsendApiError) {
throw error;
}
console.error("Error in getTeamFromToken middleware:", error);
throw new UnsendApiError({
code: "INTERNAL_SERVER_ERROR",
message: "Authentication failed",
});
}
await next();
});
// Custom Rate Limiter Middleware
const RATE_LIMIT_WINDOW_SECONDS = 1;
app.use("*", async (c: Context<AppEnv>, next: Next) => {
// Skip for self-hosted, or if team is not set (e.g. for public/doc paths not caught earlier)
// or if the path is one of the explicitly skipped paths for auth.
if (
isSelfHosted() ||
!c.var.team || // Team should be set by auth middleware for protected routes
c.req.path.startsWith("/api/v1/doc") ||
c.req.path.startsWith("/api/v1/ui") ||
c.req.path === "/api/health"
) {
return next();
}
const team = c.var.team;
const limit = team.apiRateLimit ?? 2; // Default limit from your previous setup
const key = `rl:${team.id}`; // Rate limit key for Redis
const redis = getRedis();
let currentRequests: number;
let ttl: number;
let isNewKey = false;
try {
// Increment the key. If the key does not exist, it is created and set to 1.
currentRequests = await redis.incr(key);
if (currentRequests === 1) {
// This is the first request in the window, set the expiry.
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
}
// Get the TTL (time to live) of the key to know when it resets.
// If the key does not exist or has no expiry, TTL returns -1 or -2.
// We rely on expire being set for new keys.
ttl = await redis.ttl(key);
} catch (error) {
console.error("Redis error during rate limiting:", error);
// Alternatively, you could fail closed by throwing an error here.
return next();
}
const resetTime =
Math.floor(Date.now() / 1000) +
(ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS);
const remainingRequests = Math.max(0, limit - currentRequests);
c.res.headers.set("X-RateLimit-Limit", String(limit));
c.res.headers.set("X-RateLimit-Remaining", String(remainingRequests));
c.res.headers.set("X-RateLimit-Reset", String(resetTime));
if (currentRequests > limit) {
c.res.headers.set(
"Retry-After",
String(ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS)
);
throw new UnsendApiError({
code: "RATE_LIMITED",
message: `Rate limit exceeded. Try again in ${ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS} seconds.`,
});
}
await next();
});
// The OpenAPI documentation will be available at /doc // The OpenAPI documentation will be available at /doc
app.doc("/v1/doc", (c) => ({ app.doc("/v1/doc", (c) => ({
openapi: "3.0.0", openapi: "3.0.0",
@@ -28,4 +133,4 @@ export function getApp() {
return app; return app;
} }
export type PublicAPIApp = ReturnType<typeof getApp>; export type PublicAPIApp = OpenAPIHono<AppEnv>;