Implement per-team API rate limiting with Redis (#164)
This commit is contained in:
@@ -16,3 +16,4 @@ You are a Staff Engineer and an Expert in ReactJS, NextJS, JavaScript, TypeScrip
|
|||||||
- Be concise Minimize any other prose.
|
- Be concise Minimize any other prose.
|
||||||
- If you think there might not be a correct answer, you say so.
|
- If you think there might not be a correct answer, you say so.
|
||||||
- If you do not know the answer, say so, instead of guessing.
|
- If you do not know the answer, say so, instead of guessing.
|
||||||
|
- never, install any libraries without asking
|
||||||
|
@@ -0,0 +1,2 @@
|
|||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "Team" ADD COLUMN "apiRateLimit" INTEGER NOT NULL DEFAULT 2;
|
@@ -104,6 +104,7 @@ model Team {
|
|||||||
plan Plan @default(FREE)
|
plan Plan @default(FREE)
|
||||||
stripeCustomerId String? @unique
|
stripeCustomerId String? @unique
|
||||||
isActive Boolean @default(true)
|
isActive Boolean @default(true)
|
||||||
|
apiRateLimit Int @default(2)
|
||||||
billingEmail String?
|
billingEmail String?
|
||||||
teamUsers TeamUser[]
|
teamUsers TeamUser[]
|
||||||
domains Domain[]
|
domains Domain[]
|
||||||
|
@@ -49,7 +49,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function addContact(app: PublicAPIApp) {
|
function addContact(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
const contactBook = await getContactBook(c, team.id);
|
const contactBook = await getContactBook(c, team.id);
|
||||||
|
|
||||||
|
@@ -39,7 +39,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function deleteContactHandler(app: PublicAPIApp) {
|
function deleteContactHandler(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
await getContactBook(c, team.id);
|
await getContactBook(c, team.id);
|
||||||
const contactId = c.req.param("contactId");
|
const contactId = c.req.param("contactId");
|
||||||
|
@@ -50,7 +50,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function getContact(app: PublicAPIApp) {
|
function getContact(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
await getContactBook(c, team.id);
|
await getContactBook(c, team.id);
|
||||||
|
|
||||||
|
@@ -51,7 +51,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function getContacts(app: PublicAPIApp) {
|
function getContacts(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
const cb = await getContactBook(c, team.id);
|
const cb = await getContactBook(c, team.id);
|
||||||
|
|
||||||
|
@@ -52,7 +52,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function updateContactInfo(app: PublicAPIApp) {
|
function updateContactInfo(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
await getContactBook(c, team.id);
|
await getContactBook(c, team.id);
|
||||||
const contactId = c.req.param("contactId");
|
const contactId = c.req.param("contactId");
|
||||||
|
@@ -49,7 +49,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function upsertContact(app: PublicAPIApp) {
|
function upsertContact(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
const contactBook = await getContactBook(c, team.id);
|
const contactBook = await getContactBook(c, team.id);
|
||||||
|
|
||||||
|
@@ -34,7 +34,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function createDomain(app: PublicAPIApp) {
|
function createDomain(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
const body = c.req.valid("json");
|
const body = c.req.valid("json");
|
||||||
const response = await createDomainService(team.id, body.name, body.region);
|
const response = await createDomainService(team.id, body.name, body.region);
|
||||||
|
|
||||||
|
@@ -21,7 +21,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function getDomains(app: PublicAPIApp) {
|
function getDomains(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
const domains = await db.domain.findMany({ where: { teamId: team.id } });
|
const domains = await db.domain.findMany({ where: { teamId: team.id } });
|
||||||
|
|
||||||
|
@@ -33,8 +33,6 @@ const route = createRoute({
|
|||||||
|
|
||||||
function verifyDomain(app: PublicAPIApp) {
|
function verifyDomain(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
|
||||||
|
|
||||||
await db.domain.update({
|
await db.domain.update({
|
||||||
where: { id: c.req.valid("param").id },
|
where: { id: c.req.valid("param").id },
|
||||||
data: { isVerifying: true },
|
data: { isVerifying: true },
|
||||||
|
@@ -44,7 +44,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function sendBatch(app: PublicAPIApp) {
|
function sendBatch(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
const emailPayloads = c.req.valid("json");
|
const emailPayloads = c.req.valid("json");
|
||||||
|
|
||||||
// Add teamId and apiKeyId to each email payload
|
// Add teamId and apiKeyId to each email payload
|
||||||
|
@@ -35,7 +35,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function cancelScheduledEmail(app: PublicAPIApp) {
|
function cancelScheduledEmail(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
const emailId = c.req.param("emailId");
|
const emailId = c.req.param("emailId");
|
||||||
await checkIsValidEmailId(emailId, team.id);
|
await checkIsValidEmailId(emailId, team.id);
|
||||||
|
|
||||||
|
@@ -57,7 +57,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function send(app: PublicAPIApp) {
|
function send(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
const emailId = c.req.param("emailId");
|
const emailId = c.req.param("emailId");
|
||||||
|
|
||||||
|
@@ -31,7 +31,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function send(app: PublicAPIApp) {
|
function send(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
|
|
||||||
let html = undefined;
|
let html = undefined;
|
||||||
|
|
||||||
|
@@ -45,7 +45,7 @@ const route = createRoute({
|
|||||||
|
|
||||||
function updateEmailScheduledAt(app: PublicAPIApp) {
|
function updateEmailScheduledAt(app: PublicAPIApp) {
|
||||||
app.openapi(route, async (c) => {
|
app.openapi(route, async (c) => {
|
||||||
const team = await getTeamFromToken(c);
|
const team = c.var.team;
|
||||||
const emailId = c.req.param("emailId");
|
const emailId = c.req.param("emailId");
|
||||||
|
|
||||||
await checkIsValidEmailId(emailId, team.id);
|
await checkIsValidEmailId(emailId, team.id);
|
||||||
|
@@ -1,14 +1,8 @@
|
|||||||
import TTLCache from "@isaacs/ttlcache";
|
|
||||||
import { Context } from "hono";
|
import { Context } from "hono";
|
||||||
import { db } from "../db";
|
import { db } from "../db";
|
||||||
import { UnsendApiError } from "./api-error";
|
import { UnsendApiError } from "./api-error";
|
||||||
import { env } from "~/env";
|
|
||||||
import { getTeamAndApiKey } from "../service/api-service";
|
import { getTeamAndApiKey } from "../service/api-service";
|
||||||
|
import { isSelfHosted } from "~/utils/common";
|
||||||
const rateLimitCache = new TTLCache({
|
|
||||||
ttl: 1000, // 1 second
|
|
||||||
max: 10000,
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the team from the token. Also will check if the token is valid.
|
* Gets the team from the token. Also will check if the token is valid.
|
||||||
@@ -32,8 +26,6 @@ export const getTeamFromToken = async (c: Context) => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
checkRateLimit(token);
|
|
||||||
|
|
||||||
const teamAndApiKey = await getTeamAndApiKey(token);
|
const teamAndApiKey = await getTeamAndApiKey(token);
|
||||||
|
|
||||||
if (!teamAndApiKey) {
|
if (!teamAndApiKey) {
|
||||||
@@ -66,18 +58,3 @@ export const getTeamFromToken = async (c: Context) => {
|
|||||||
|
|
||||||
return { ...team, apiKeyId: apiKey.id };
|
return { ...team, apiKeyId: apiKey.id };
|
||||||
};
|
};
|
||||||
|
|
||||||
const checkRateLimit = (token: string) => {
|
|
||||||
let rateLimit = rateLimitCache.get<number>(token);
|
|
||||||
|
|
||||||
rateLimit = rateLimit ?? 0;
|
|
||||||
|
|
||||||
if (rateLimit >= env.API_RATE_LIMIT) {
|
|
||||||
throw new UnsendApiError({
|
|
||||||
code: "RATE_LIMITED",
|
|
||||||
message: `Rate limit exceeded, ${env.API_RATE_LIMIT} requests per second`,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
rateLimitCache.set(token, rateLimit + 1);
|
|
||||||
};
|
|
||||||
|
@@ -1,13 +1,118 @@
|
|||||||
import { OpenAPIHono } from "@hono/zod-openapi";
|
import { OpenAPIHono } from "@hono/zod-openapi";
|
||||||
import { swaggerUI } from "@hono/swagger-ui";
|
import { swaggerUI } from "@hono/swagger-ui";
|
||||||
|
import { Context, Next } from "hono";
|
||||||
import { handleError } from "./api-error";
|
import { handleError } from "./api-error";
|
||||||
import { env } from "~/env";
|
import { env } from "~/env";
|
||||||
|
import { getRedis } from "~/server/redis";
|
||||||
|
import { getTeamFromToken } from "~/server/public-api/auth";
|
||||||
|
import { isSelfHosted } from "~/utils/common";
|
||||||
|
import { UnsendApiError } from "./api-error";
|
||||||
|
import { Team } from "@prisma/client";
|
||||||
|
|
||||||
|
// Define AppEnv for Hono context
|
||||||
|
export type AppEnv = {
|
||||||
|
Variables: {
|
||||||
|
team: Team & { apiKeyId: number };
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
export function getApp() {
|
export function getApp() {
|
||||||
const app = new OpenAPIHono().basePath("/api");
|
const app = new OpenAPIHono<AppEnv>().basePath("/api");
|
||||||
|
|
||||||
app.onError(handleError);
|
app.onError(handleError);
|
||||||
|
|
||||||
|
// Auth and Team Middleware (runs before rate limiter)
|
||||||
|
app.use("*", async (c: Context<AppEnv>, next: Next) => {
|
||||||
|
if (
|
||||||
|
c.req.path.startsWith("/api/v1/doc") ||
|
||||||
|
c.req.path.startsWith("/api/v1/ui") ||
|
||||||
|
c.req.path === "/api/health"
|
||||||
|
) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const team = await getTeamFromToken(c as any);
|
||||||
|
c.set("team", team);
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof UnsendApiError) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
console.error("Error in getTeamFromToken middleware:", error);
|
||||||
|
throw new UnsendApiError({
|
||||||
|
code: "INTERNAL_SERVER_ERROR",
|
||||||
|
message: "Authentication failed",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
await next();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Custom Rate Limiter Middleware
|
||||||
|
const RATE_LIMIT_WINDOW_SECONDS = 1;
|
||||||
|
|
||||||
|
app.use("*", async (c: Context<AppEnv>, next: Next) => {
|
||||||
|
// Skip for self-hosted, or if team is not set (e.g. for public/doc paths not caught earlier)
|
||||||
|
// or if the path is one of the explicitly skipped paths for auth.
|
||||||
|
if (
|
||||||
|
isSelfHosted() ||
|
||||||
|
!c.var.team || // Team should be set by auth middleware for protected routes
|
||||||
|
c.req.path.startsWith("/api/v1/doc") ||
|
||||||
|
c.req.path.startsWith("/api/v1/ui") ||
|
||||||
|
c.req.path === "/api/health"
|
||||||
|
) {
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
const team = c.var.team;
|
||||||
|
const limit = team.apiRateLimit ?? 2; // Default limit from your previous setup
|
||||||
|
const key = `rl:${team.id}`; // Rate limit key for Redis
|
||||||
|
const redis = getRedis();
|
||||||
|
|
||||||
|
let currentRequests: number;
|
||||||
|
let ttl: number;
|
||||||
|
let isNewKey = false;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Increment the key. If the key does not exist, it is created and set to 1.
|
||||||
|
currentRequests = await redis.incr(key);
|
||||||
|
|
||||||
|
if (currentRequests === 1) {
|
||||||
|
// This is the first request in the window, set the expiry.
|
||||||
|
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
|
||||||
|
}
|
||||||
|
// Get the TTL (time to live) of the key to know when it resets.
|
||||||
|
// If the key does not exist or has no expiry, TTL returns -1 or -2.
|
||||||
|
// We rely on expire being set for new keys.
|
||||||
|
ttl = await redis.ttl(key);
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Redis error during rate limiting:", error);
|
||||||
|
// Alternatively, you could fail closed by throwing an error here.
|
||||||
|
return next();
|
||||||
|
}
|
||||||
|
|
||||||
|
const resetTime =
|
||||||
|
Math.floor(Date.now() / 1000) +
|
||||||
|
(ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS);
|
||||||
|
const remainingRequests = Math.max(0, limit - currentRequests);
|
||||||
|
|
||||||
|
c.res.headers.set("X-RateLimit-Limit", String(limit));
|
||||||
|
c.res.headers.set("X-RateLimit-Remaining", String(remainingRequests));
|
||||||
|
c.res.headers.set("X-RateLimit-Reset", String(resetTime));
|
||||||
|
|
||||||
|
if (currentRequests > limit) {
|
||||||
|
c.res.headers.set(
|
||||||
|
"Retry-After",
|
||||||
|
String(ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS)
|
||||||
|
);
|
||||||
|
throw new UnsendApiError({
|
||||||
|
code: "RATE_LIMITED",
|
||||||
|
message: `Rate limit exceeded. Try again in ${ttl > 0 ? ttl : RATE_LIMIT_WINDOW_SECONDS} seconds.`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
await next();
|
||||||
|
});
|
||||||
|
|
||||||
// The OpenAPI documentation will be available at /doc
|
// The OpenAPI documentation will be available at /doc
|
||||||
app.doc("/v1/doc", (c) => ({
|
app.doc("/v1/doc", (c) => ({
|
||||||
openapi: "3.0.0",
|
openapi: "3.0.0",
|
||||||
@@ -28,4 +133,4 @@ export function getApp() {
|
|||||||
return app;
|
return app;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type PublicAPIApp = ReturnType<typeof getApp>;
|
export type PublicAPIApp = OpenAPIHono<AppEnv>;
|
||||||
|
Reference in New Issue
Block a user