Implement per-team API rate limiting with Redis (#164)
This commit is contained in:
@@ -16,3 +16,4 @@ You are a Staff Engineer and an Expert in ReactJS, NextJS, JavaScript, TypeScrip
|
||||
- Be concise Minimize any other prose.
|
||||
- If you think there might not be a correct answer, you say so.
|
||||
- If you do not know the answer, say so, instead of guessing.
|
||||
- never, install any libraries without asking
|
||||
|
@@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "Team" ADD COLUMN "apiRateLimit" INTEGER NOT NULL DEFAULT 2;
|
@@ -104,6 +104,7 @@ model Team {
|
||||
plan Plan @default(FREE)
|
||||
stripeCustomerId String? @unique
|
||||
isActive Boolean @default(true)
|
||||
apiRateLimit Int @default(2)
|
||||
billingEmail String?
|
||||
teamUsers TeamUser[]
|
||||
domains Domain[]
|
||||
|
@@ -49,7 +49,7 @@ const route = createRoute({
|
||||
|
||||
function addContact(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
const contactBook = await getContactBook(c, team.id);
|
||||
|
||||
|
@@ -39,7 +39,7 @@ const route = createRoute({
|
||||
|
||||
function deleteContactHandler(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
await getContactBook(c, team.id);
|
||||
const contactId = c.req.param("contactId");
|
||||
|
@@ -50,7 +50,7 @@ const route = createRoute({
|
||||
|
||||
function getContact(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
await getContactBook(c, team.id);
|
||||
|
||||
|
@@ -51,7 +51,7 @@ const route = createRoute({
|
||||
|
||||
function getContacts(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
const cb = await getContactBook(c, team.id);
|
||||
|
||||
|
@@ -52,7 +52,7 @@ const route = createRoute({
|
||||
|
||||
function updateContactInfo(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
await getContactBook(c, team.id);
|
||||
const contactId = c.req.param("contactId");
|
||||
|
@@ -49,7 +49,7 @@ const route = createRoute({
|
||||
|
||||
function upsertContact(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
const contactBook = await getContactBook(c, team.id);
|
||||
|
||||
|
@@ -34,7 +34,7 @@ const route = createRoute({
|
||||
|
||||
function createDomain(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
const body = c.req.valid("json");
|
||||
const response = await createDomainService(team.id, body.name, body.region);
|
||||
|
||||
|
@@ -21,7 +21,7 @@ const route = createRoute({
|
||||
|
||||
function getDomains(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
const domains = await db.domain.findMany({ where: { teamId: team.id } });
|
||||
|
||||
|
@@ -33,8 +33,6 @@ const route = createRoute({
|
||||
|
||||
function verifyDomain(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
|
||||
await db.domain.update({
|
||||
where: { id: c.req.valid("param").id },
|
||||
data: { isVerifying: true },
|
||||
|
@@ -44,7 +44,7 @@ const route = createRoute({
|
||||
|
||||
function sendBatch(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
const emailPayloads = c.req.valid("json");
|
||||
|
||||
// Add teamId and apiKeyId to each email payload
|
||||
|
@@ -35,7 +35,7 @@ const route = createRoute({
|
||||
|
||||
function cancelScheduledEmail(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
const emailId = c.req.param("emailId");
|
||||
await checkIsValidEmailId(emailId, team.id);
|
||||
|
||||
|
@@ -57,7 +57,7 @@ const route = createRoute({
|
||||
|
||||
function send(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
const emailId = c.req.param("emailId");
|
||||
|
||||
|
@@ -31,7 +31,7 @@ const route = createRoute({
|
||||
|
||||
function send(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
|
||||
let html = undefined;
|
||||
|
||||
|
@@ -45,7 +45,7 @@ const route = createRoute({
|
||||
|
||||
function updateEmailScheduledAt(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = await getTeamFromToken(c);
|
||||
const team = c.var.team;
|
||||
const emailId = c.req.param("emailId");
|
||||
|
||||
await checkIsValidEmailId(emailId, team.id);
|
||||
|
@@ -1,14 +1,8 @@
|
||||
import TTLCache from "@isaacs/ttlcache";
|
||||
import { Context } from "hono";
|
||||
import { db } from "../db";
|
||||
import { UnsendApiError } from "./api-error";
|
||||
import { env } from "~/env";
|
||||
import { getTeamAndApiKey } from "../service/api-service";
|
||||
|
||||
const rateLimitCache = new TTLCache({
|
||||
ttl: 1000, // 1 second
|
||||
max: 10000,
|
||||
});
|
||||
import { isSelfHosted } from "~/utils/common";
|
||||
|
||||
/**
|
||||
* Gets the team from the token. Also will check if the token is valid.
|
||||
@@ -32,8 +26,6 @@ export const getTeamFromToken = async (c: Context) => {
|
||||
});
|
||||
}
|
||||
|
||||
checkRateLimit(token);
|
||||
|
||||
const teamAndApiKey = await getTeamAndApiKey(token);
|
||||
|
||||
if (!teamAndApiKey) {
|
||||
@@ -66,18 +58,3 @@ export const getTeamFromToken = async (c: Context) => {
|
||||
|
||||
return { ...team, apiKeyId: apiKey.id };
|
||||
};
|
||||
|
||||
const checkRateLimit = (token: string) => {
|
||||
let rateLimit = rateLimitCache.get<number>(token);
|
||||
|
||||
rateLimit = rateLimit ?? 0;
|
||||
|
||||
if (rateLimit >= env.API_RATE_LIMIT) {
|
||||
throw new UnsendApiError({
|
||||
code: "RATE_LIMITED",
|
||||
message: `Rate limit exceeded, ${env.API_RATE_LIMIT} requests per second`,
|
||||
});
|
||||
}
|
||||
|
||||
rateLimitCache.set(token, rateLimit + 1);
|
||||
};
|
||||
|
@@ -1,13 +1,118 @@
|
||||
import { OpenAPIHono } from "@hono/zod-openapi";
|
||||
import { swaggerUI } from "@hono/swagger-ui";
|
||||
import { Context, Next } from "hono";
|
||||
import { handleError } from "./api-error";
|
||||
import { env } from "~/env";
|
||||
import { getRedis } from "~/server/redis";
|
||||
import { getTeamFromToken } from "~/server/public-api/auth";
|
||||
import { isSelfHosted } from "~/utils/common";
|
||||
import { UnsendApiError } from "./api-error";
|
||||
import { Team } from "@prisma/client";
|
||||
|
||||
// Define AppEnv for Hono context
|
||||
export type AppEnv = {
|
||||
Variables: {
|
||||
team: Team & { apiKeyId: number };
|
||||
};
|
||||
};
|
||||
|
||||
export function getApp() {
|
||||
const app = new OpenAPIHono().basePath("/api");
|
||||
const app = new OpenAPIHono<AppEnv>().basePath("/api");
|
||||
|
||||
app.onError(handleError);
|
||||
|
||||
// Auth and Team Middleware (runs before rate limiter)
|
||||
app.use("*", async (c: Context<AppEnv>, next: Next) => {
|
||||
if (
|
||||
c.req.path.startsWith("/api/v1/doc") ||
|
||||
c.req.path.startsWith("/api/v1/ui") ||
|
||||
c.req.path === "/api/health"
|
||||
) {
|
||||
return next();
|
||||
}
|
||||
|
||||
try {
|
||||
const team = await getTeamFromToken(c as any);
|
||||
c.set("team", team);
|
||||
} catch (error) {
|
||||
if (error instanceof UnsendApiError) {
|
||||
throw error;
|
||||
}
|
||||
console.error("Error in getTeamFromToken middleware:", error);
|
||||
throw new UnsendApiError({
|
||||
code: "INTERNAL_SERVER_ERROR",
|
||||
message: "Authentication failed",
|
||||
});
|
||||
}
|
||||
await next();
|
||||
});
|
||||
|
||||
// Custom Rate Limiter Middleware
|
||||
const RATE_LIMIT_WINDOW_SECONDS = 1;
|
||||
|
||||
app.use("*", async (c: Context<AppEnv>, next: Next) => {
|
||||
// Skip for self-hosted, or if team is not set (e.g. for public/doc paths not caught earlier)
|
||||
// or if the path is one of the explicitly skipped paths for auth.
|
||||
if (
|
||||
isSelfHosted() ||
|
||||
!c.var.team || // Team should be set by auth middleware for protected routes
|
||||
c.req.path.startsWith("/api/v1/doc") ||
|
||||
c.req.path.startsWith("/api/v1/ui") ||
|
||||
c.req.path === "/api/health"
|
||||
) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const team = c.var.team;
|
||||
const limit = team.apiRateLimit ?? 2; // Default limit from your previous setup
|
||||
const key = `rl:${team.id}`; // Rate limit key for Redis
|
||||
const redis = getRedis();
|
||||
|
||||
let currentRequests: number;
|
||||
let ttl: number;
|
||||
let isNewKey = false;
|
||||
|
||||
try {
|
||||
// Increment the key. If the key does not exist, it is created and set to 1.
|
||||
currentRequests = await redis.incr(key);
|
||||
|
||||
if (currentRequests === 1) {
|
||||
// This is the first request in the window, set the expiry.
|
||||
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
|
||||
}
|
||||
// Get the TTL (time to live) of the key to know when it resets.
|
||||
// If the key does not exist or has no expiry, TTL returns -1 or -2.
|
||||
// We rely on expire being set for new keys.
|
||||
ttl = await redis.ttl(key);
|
||||
} catch (error) {
|
||||
console.error("Redis error during rate limiting:", error);
|
||||
// Alternatively, you could fail closed by throwing an error here.
|
||||
return next();
|
||||
}
|
||||
|
||||
const resetTime =
|
||||
Math.floor(Date.now() / 1000) +
|
||||
(ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS);
|
||||
const remainingRequests = Math.max(0, limit - currentRequests);
|
||||
|
||||
c.res.headers.set("X-RateLimit-Limit", String(limit));
|
||||
c.res.headers.set("X-RateLimit-Remaining", String(remainingRequests));
|
||||
c.res.headers.set("X-RateLimit-Reset", String(resetTime));
|
||||
|
||||
if (currentRequests > limit) {
|
||||
c.res.headers.set(
|
||||
"Retry-After",
|
||||
String(ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS)
|
||||
);
|
||||
throw new UnsendApiError({
|
||||
code: "RATE_LIMITED",
|
||||
message: `Rate limit exceeded. Try again in ${ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS} seconds.`,
|
||||
});
|
||||
}
|
||||
|
||||
await next();
|
||||
});
|
||||
|
||||
// The OpenAPI documentation will be available at /doc
|
||||
app.doc("/v1/doc", (c) => ({
|
||||
openapi: "3.0.0",
|
||||
@@ -28,4 +133,4 @@ export function getApp() {
|
||||
return app;
|
||||
}
|
||||
|
||||
export type PublicAPIApp = ReturnType<typeof getApp>;
|
||||
export type PublicAPIApp = OpenAPIHono<AppEnv>;
|
||||
|
Reference in New Issue
Block a user