feat: add domain-based access control for API keys (#198)
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
KM Koushik
parent
dbc6996d9a
commit
0817b0c7a5
@@ -0,0 +1,5 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "ApiKey" ADD COLUMN "domainId" INTEGER;
|
||||
|
||||
-- AddForeignKey
|
||||
ALTER TABLE "ApiKey" ADD CONSTRAINT "ApiKey_domainId_fkey" FOREIGN KEY ("domainId") REFERENCES "Domain"("id") ON DELETE SET NULL ON UPDATE CASCADE;
|
@@ -195,6 +195,7 @@ model Domain {
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
team Team @relation(fields: [teamId], references: [id], onDelete: Cascade)
|
||||
apiKeys ApiKey[]
|
||||
}
|
||||
|
||||
enum ApiPermission {
|
||||
@@ -209,11 +210,13 @@ model ApiKey {
|
||||
partialToken String
|
||||
name String
|
||||
permission ApiPermission @default(SENDING)
|
||||
domainId Int?
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
lastUsed DateTime?
|
||||
teamId Int
|
||||
team Team @relation(fields: [teamId], references: [id], onDelete: Cascade)
|
||||
domain Domain? @relation(fields: [domainId], references: [id], onDelete: Cascade)
|
||||
}
|
||||
|
||||
enum EmailStatus {
|
||||
|
@@ -27,11 +27,20 @@ import {
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@usesend/ui/src/form";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@unsend/ui/src/select";
|
||||
|
||||
|
||||
const apiKeySchema = z.object({
|
||||
name: z.string({ required_error: "Name is required" }).min(1, {
|
||||
message: "Name is required",
|
||||
}),
|
||||
domainId: z.string().optional(),
|
||||
});
|
||||
|
||||
export default function AddApiKey() {
|
||||
@@ -40,6 +49,8 @@ export default function AddApiKey() {
|
||||
const createApiKeyMutation = api.apiKey.createToken.useMutation();
|
||||
const [isCopied, setIsCopied] = useState(false);
|
||||
const [showApiKey, setShowApiKey] = useState(false);
|
||||
|
||||
const domainsQuery = api.domain.domains.useQuery();
|
||||
|
||||
const utils = api.useUtils();
|
||||
|
||||
@@ -47,6 +58,7 @@ export default function AddApiKey() {
|
||||
resolver: zodResolver(apiKeySchema),
|
||||
defaultValues: {
|
||||
name: "",
|
||||
domainId: "all",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -55,6 +67,7 @@ export default function AddApiKey() {
|
||||
{
|
||||
name: values.name,
|
||||
permission: "FULL",
|
||||
domainId: values.domainId === "all" ? undefined : Number(values.domainId),
|
||||
},
|
||||
{
|
||||
onSuccess: (data) => {
|
||||
@@ -180,6 +193,33 @@ export default function AddApiKey() {
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={apiKeyForm.control}
|
||||
name="domainId"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Domain access</FormLabel>
|
||||
<Select onValueChange={field.onChange} defaultValue={field.value}>
|
||||
<FormControl>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select domain access" />
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">All Domains</SelectItem>
|
||||
{domainsQuery.data?.map((domain: { id: number; name: string }) => (
|
||||
<SelectItem key={domain.id} value={domain.id.toString()}>
|
||||
{domain.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<FormDescription>
|
||||
Choose which domain this API key can send emails from.
|
||||
</FormDescription>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<div className="flex justify-end">
|
||||
<Button
|
||||
className=" w-[100px] hover:bg-gray-100 focus:bg-gray-100"
|
||||
|
@@ -25,6 +25,7 @@ export default function ApiList() {
|
||||
<TableHead className="rounded-tl-xl">Name</TableHead>
|
||||
<TableHead>Token</TableHead>
|
||||
<TableHead>Permission</TableHead>
|
||||
<TableHead>Domain Access</TableHead>
|
||||
<TableHead>Last used</TableHead>
|
||||
<TableHead>Created at</TableHead>
|
||||
<TableHead className="rounded-tr-xl">Action</TableHead>
|
||||
@@ -33,7 +34,7 @@ export default function ApiList() {
|
||||
<TableBody>
|
||||
{apiKeysQuery.isLoading ? (
|
||||
<TableRow className="h-32">
|
||||
<TableCell colSpan={6} className="text-center py-4">
|
||||
<TableCell colSpan={7} className="text-center py-4">
|
||||
<Spinner
|
||||
className="w-6 h-6 mx-auto"
|
||||
innerSvgClass="stroke-primary"
|
||||
@@ -42,7 +43,7 @@ export default function ApiList() {
|
||||
</TableRow>
|
||||
) : apiKeysQuery.data?.length === 0 ? (
|
||||
<TableRow className="h-32">
|
||||
<TableCell colSpan={6} className="text-center py-4">
|
||||
<TableCell colSpan={7} className="text-center py-4">
|
||||
<p>No API keys added</p>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
@@ -52,9 +53,14 @@ export default function ApiList() {
|
||||
<TableCell>{apiKey.name}</TableCell>
|
||||
<TableCell>{apiKey.partialToken}</TableCell>
|
||||
<TableCell>{apiKey.permission}</TableCell>
|
||||
<TableCell>
|
||||
{apiKey.domainId
|
||||
? apiKey.domain?.name ?? "Domain removed"
|
||||
: "All domains"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{apiKey.lastUsed
|
||||
? formatDistanceToNow(apiKey.lastUsed)
|
||||
? formatDistanceToNow(apiKey.lastUsed, { addSuffix: true })
|
||||
: "Never"}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
|
@@ -1,4 +1,6 @@
|
||||
import { z } from "zod";
|
||||
import { ApiPermission } from "@prisma/client";
|
||||
import { TRPCError } from "@trpc/server";
|
||||
|
||||
import {
|
||||
apiKeyProcedure,
|
||||
@@ -10,13 +12,18 @@ import { addApiKey, deleteApiKey } from "~/server/service/api-service";
|
||||
export const apiRouter = createTRPCRouter({
|
||||
createToken: teamProcedure
|
||||
.input(
|
||||
z.object({ name: z.string(), permission: z.enum(["FULL", "SENDING"]) })
|
||||
z.object({
|
||||
name: z.string(),
|
||||
permission: z.nativeEnum(ApiPermission),
|
||||
domainId: z.number().int().positive().optional(),
|
||||
})
|
||||
)
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
return addApiKey({
|
||||
return await addApiKey({
|
||||
name: input.name,
|
||||
permission: input.permission,
|
||||
teamId: ctx.team.id,
|
||||
domainId: input.domainId,
|
||||
});
|
||||
}),
|
||||
|
||||
@@ -32,6 +39,12 @@ export const apiRouter = createTRPCRouter({
|
||||
partialToken: true,
|
||||
lastUsed: true,
|
||||
createdAt: true,
|
||||
domainId: true,
|
||||
domain: {
|
||||
select: {
|
||||
name: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
|
@@ -33,3 +33,26 @@ export const checkIsValidEmailId = async (emailId: string, teamId: number) => {
|
||||
throw new UnsendApiError({ code: "NOT_FOUND", message: "Email not found" });
|
||||
}
|
||||
};
|
||||
|
||||
export const checkIsValidEmailIdWithDomainRestriction = async (
|
||||
emailId: string,
|
||||
teamId: number,
|
||||
apiKeyDomainId?: number
|
||||
) => {
|
||||
const whereClause: { id: string; teamId: number; domainId?: number } = {
|
||||
id: emailId,
|
||||
teamId,
|
||||
};
|
||||
|
||||
if (apiKeyDomainId !== undefined) {
|
||||
whereClause.domainId = apiKeyDomainId;
|
||||
}
|
||||
|
||||
const email = await db.email.findUnique({ where: whereClause });
|
||||
|
||||
if (!email) {
|
||||
throw new UnsendApiError({ code: "NOT_FOUND", message: "Email not found" });
|
||||
}
|
||||
|
||||
return email;
|
||||
};
|
||||
|
@@ -2,7 +2,6 @@ import { createRoute, z } from "@hono/zod-openapi";
|
||||
import { DomainSchema } from "~/lib/zod/domain-schema";
|
||||
import { PublicAPIApp } from "~/server/public-api/hono";
|
||||
import { db } from "~/server/db";
|
||||
import { getTeamFromToken } from "~/server/public-api/auth";
|
||||
|
||||
const route = createRoute({
|
||||
method: "get",
|
||||
@@ -14,7 +13,7 @@ const route = createRoute({
|
||||
schema: z.array(DomainSchema),
|
||||
},
|
||||
},
|
||||
description: "Retrieve the user",
|
||||
description: "Retrieve domains accessible by the API key",
|
||||
},
|
||||
},
|
||||
});
|
||||
@@ -23,7 +22,12 @@ function getDomains(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = c.var.team;
|
||||
|
||||
const domains = await db.domain.findMany({ where: { teamId: team.id } });
|
||||
// If API key is restricted to a specific domain, only return that domain; else return all team domains
|
||||
const domains = team.apiKey.domainId
|
||||
? await db.domain.findMany({
|
||||
where: { teamId: team.id, id: team.apiKey.domainId },
|
||||
})
|
||||
: await db.domain.findMany({ where: { teamId: team.id } });
|
||||
|
||||
return c.json(domains);
|
||||
});
|
||||
|
@@ -1,6 +1,5 @@
|
||||
import { createRoute, z } from "@hono/zod-openapi";
|
||||
import { PublicAPIApp } from "~/server/public-api/hono";
|
||||
import { getTeamFromToken } from "~/server/public-api/auth";
|
||||
import { db } from "~/server/db";
|
||||
|
||||
const route = createRoute({
|
||||
@@ -26,15 +25,70 @@ const route = createRoute({
|
||||
}),
|
||||
},
|
||||
},
|
||||
description: "Create a new domain",
|
||||
description: "Verify domain",
|
||||
},
|
||||
403: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: z.object({
|
||||
error: z.string(),
|
||||
}),
|
||||
},
|
||||
},
|
||||
description: "Forbidden - API key doesn't have access to this domain",
|
||||
},
|
||||
404: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: z.object({
|
||||
error: z.string(),
|
||||
}),
|
||||
},
|
||||
},
|
||||
description: "Domain not found",
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
function verifyDomain(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = c.var.team;
|
||||
const domainId = c.req.valid("param").id;
|
||||
|
||||
// Check if API key has access to this domain
|
||||
let domain = null;
|
||||
|
||||
if (team.apiKey.domainId) {
|
||||
// If API key is restricted to a specific domain, verify the requested domain matches
|
||||
if (domainId === team.apiKey.domainId) {
|
||||
domain = await db.domain.findFirst({
|
||||
where: {
|
||||
teamId: team.id,
|
||||
id: domainId
|
||||
},
|
||||
});
|
||||
}
|
||||
// If domainId doesn't match the API key's restriction, domain remains null
|
||||
} else {
|
||||
// API key has access to all team domains
|
||||
domain = await db.domain.findFirst({
|
||||
where: {
|
||||
teamId: team.id,
|
||||
id: domainId
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (!domain) {
|
||||
return c.json({
|
||||
error: team.apiKey.domainId
|
||||
? "API key doesn't have access to this domain"
|
||||
: "Domain not found"
|
||||
}, 404);
|
||||
}
|
||||
|
||||
await db.domain.update({
|
||||
where: { id: c.req.valid("param").id },
|
||||
where: { id: domainId },
|
||||
data: { isVerifying: true },
|
||||
});
|
||||
|
||||
|
@@ -2,7 +2,7 @@ import { createRoute, z } from "@hono/zod-openapi";
|
||||
import { PublicAPIApp } from "~/server/public-api/hono";
|
||||
import { getTeamFromToken } from "~/server/public-api/auth";
|
||||
import { cancelEmail } from "~/server/service/email-service";
|
||||
import { checkIsValidEmailId } from "../../api-utils";
|
||||
import { checkIsValidEmailIdWithDomainRestriction } from "../../api-utils";
|
||||
|
||||
const route = createRoute({
|
||||
method: "post",
|
||||
@@ -37,7 +37,7 @@ function cancelScheduledEmail(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = c.var.team;
|
||||
const emailId = c.req.param("emailId");
|
||||
await checkIsValidEmailId(emailId, team.id);
|
||||
await checkIsValidEmailIdWithDomainRestriction(emailId, team.id, team.apiKey.domainId);
|
||||
|
||||
await cancelEmail(emailId);
|
||||
|
||||
|
@@ -58,14 +58,19 @@ const route = createRoute({
|
||||
function send(app: PublicAPIApp) {
|
||||
app.openapi(route, async (c) => {
|
||||
const team = c.var.team;
|
||||
|
||||
const emailId = c.req.param("emailId");
|
||||
|
||||
const whereClause: { id: string; teamId: number; domainId?: number } = {
|
||||
id: emailId,
|
||||
teamId: team.id,
|
||||
};
|
||||
|
||||
if (team.apiKey.domainId !== null) {
|
||||
whereClause.domainId = team.apiKey.domainId;
|
||||
}
|
||||
|
||||
const email = await db.email.findUnique({
|
||||
where: {
|
||||
id: emailId,
|
||||
teamId: team.id,
|
||||
},
|
||||
where: whereClause,
|
||||
select: {
|
||||
id: true,
|
||||
teamId: true,
|
||||
|
@@ -123,7 +123,9 @@ function listEmails(app: PublicAPIApp) {
|
||||
};
|
||||
}
|
||||
|
||||
if (domainId && domainId.length > 0) {
|
||||
if (team.apiKey.domainId !== null) {
|
||||
whereClause.domainId = team.apiKey.domainId;
|
||||
} else if (domainId && domainId.length > 0) {
|
||||
whereClause.domainId = { in: domainId };
|
||||
}
|
||||
|
||||
|
@@ -2,7 +2,7 @@ import { createRoute, z } from "@hono/zod-openapi";
|
||||
import { PublicAPIApp } from "~/server/public-api/hono";
|
||||
import { getTeamFromToken } from "~/server/public-api/auth";
|
||||
import { updateEmail } from "~/server/service/email-service";
|
||||
import { checkIsValidEmailId } from "../../api-utils";
|
||||
import { checkIsValidEmailIdWithDomainRestriction } from "../../api-utils";
|
||||
|
||||
const route = createRoute({
|
||||
method: "patch",
|
||||
@@ -48,7 +48,7 @@ function updateEmailScheduledAt(app: PublicAPIApp) {
|
||||
const team = c.var.team;
|
||||
const emailId = c.req.param("emailId");
|
||||
|
||||
await checkIsValidEmailId(emailId, team.id);
|
||||
await checkIsValidEmailIdWithDomainRestriction(emailId, team.id, team.apiKey.domainId);
|
||||
|
||||
await updateEmail(emailId, {
|
||||
scheduledAt: c.req.valid("json").scheduledAt,
|
||||
|
@@ -59,5 +59,5 @@ export const getTeamFromToken = async (c: Context) => {
|
||||
logger.error({ err }, "Failed to update lastUsed on API key")
|
||||
);
|
||||
|
||||
return { ...team, apiKeyId: apiKey.id };
|
||||
return { ...team, apiKeyId: apiKey.id, apiKey: { domainId: apiKey.domainId } };
|
||||
};
|
||||
|
@@ -7,13 +7,13 @@ import { getRedis } from "~/server/redis";
|
||||
import { getTeamFromToken } from "~/server/public-api/auth";
|
||||
import { isSelfHosted } from "~/utils/common";
|
||||
import { UnsendApiError } from "./api-error";
|
||||
import { Team } from "@prisma/client";
|
||||
import { Team, ApiKey } from "@prisma/client";
|
||||
import { logger } from "../logger/log";
|
||||
|
||||
// Define AppEnv for Hono context
|
||||
export type AppEnv = {
|
||||
Variables: {
|
||||
team: Team & { apiKeyId: number };
|
||||
team: Team & { apiKeyId: number; apiKey: { domainId: number | null } };
|
||||
};
|
||||
};
|
||||
|
||||
|
@@ -9,12 +9,29 @@ export async function addApiKey({
|
||||
name,
|
||||
permission,
|
||||
teamId,
|
||||
domainId,
|
||||
}: {
|
||||
name: string;
|
||||
permission: ApiPermission;
|
||||
teamId: number;
|
||||
domainId?: number;
|
||||
}) {
|
||||
try {
|
||||
// Validate domain ownership if domainId is provided
|
||||
if (domainId !== undefined) {
|
||||
const domain = await db.domain.findUnique({
|
||||
where: {
|
||||
id: domainId,
|
||||
teamId: teamId
|
||||
},
|
||||
select: { id: true },
|
||||
});
|
||||
|
||||
if (!domain) {
|
||||
throw new Error("DOMAIN_NOT_FOUND");
|
||||
}
|
||||
}
|
||||
|
||||
const clientId = smallNanoid(10);
|
||||
const token = randomBytes(16).toString("hex");
|
||||
const hashedToken = await createSecureHash(token);
|
||||
@@ -26,6 +43,7 @@ export async function addApiKey({
|
||||
name,
|
||||
permission: permission,
|
||||
teamId,
|
||||
domainId,
|
||||
tokenHash: hashedToken,
|
||||
partialToken: `${apiKey.slice(0, 6)}...${apiKey.slice(-3)}`,
|
||||
clientId,
|
||||
@@ -45,6 +63,11 @@ export async function getTeamAndApiKey(apiKey: string) {
|
||||
where: {
|
||||
clientId,
|
||||
},
|
||||
include: {
|
||||
domain: {
|
||||
select: { id: true, name: true },
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
if (!apiKeyRow) {
|
||||
|
@@ -6,6 +6,7 @@ import { db } from "~/server/db";
|
||||
import { SesSettingsService } from "./ses-settings-service";
|
||||
import { UnsendApiError } from "../public-api/api-error";
|
||||
import { logger } from "../logger/log";
|
||||
import { ApiKey } from "@prisma/client";
|
||||
import { LimitService } from "./limit-service";
|
||||
|
||||
const dnsResolveTxt = util.promisify(dns.resolveTxt);
|
||||
@@ -34,7 +35,7 @@ export async function validateDomainFromEmail(email: string, teamId: number) {
|
||||
});
|
||||
}
|
||||
|
||||
const domain = await db.domain.findUnique({
|
||||
const domain = await db.domain.findFirst({
|
||||
where: { name: fromDomain, teamId },
|
||||
});
|
||||
|
||||
@@ -55,6 +56,30 @@ export async function validateDomainFromEmail(email: string, teamId: number) {
|
||||
return domain;
|
||||
}
|
||||
|
||||
export async function validateApiKeyDomainAccess(
|
||||
email: string,
|
||||
teamId: number,
|
||||
apiKey: ApiKey & { domain?: { name: string } | null }
|
||||
) {
|
||||
// First validate the domain exists and is verified
|
||||
const domain = await validateDomainFromEmail(email, teamId);
|
||||
|
||||
// If API key has no domain restriction (domainId is null), allow all domains
|
||||
if (!apiKey.domainId) {
|
||||
return domain;
|
||||
}
|
||||
|
||||
// If API key is restricted to a specific domain, check if it matches
|
||||
if (apiKey.domainId !== domain.id) {
|
||||
throw new UnsendApiError({
|
||||
code: "FORBIDDEN",
|
||||
message: `API key does not have access to domain: ${domain.name}`,
|
||||
});
|
||||
}
|
||||
|
||||
return domain;
|
||||
}
|
||||
|
||||
export async function createDomain(
|
||||
teamId: number,
|
||||
name: string,
|
||||
|
@@ -2,7 +2,7 @@ import { EmailContent } from "~/types";
|
||||
import { db } from "../db";
|
||||
import { UnsendApiError } from "~/server/public-api/api-error";
|
||||
import { EmailQueueService } from "./email-queue-service";
|
||||
import { validateDomainFromEmail } from "./domain-service";
|
||||
import { validateDomainFromEmail, validateApiKeyDomainAccess } from "./domain-service";
|
||||
import { EmailRenderer } from "@usesend/email-editor/src/renderer";
|
||||
import { logger } from "../logger/log";
|
||||
import { SuppressionService } from "./suppression-service";
|
||||
@@ -70,7 +70,27 @@ export async function sendEmail(
|
||||
let subject = subjectFromApiCall;
|
||||
let html = htmlFromApiCall;
|
||||
|
||||
const domain = await validateDomainFromEmail(from, teamId);
|
||||
let domain: Awaited<ReturnType<typeof validateDomainFromEmail>>;
|
||||
|
||||
// If this is an API call with an API key, validate domain access
|
||||
if (apiKeyId) {
|
||||
const apiKey = await db.apiKey.findUnique({
|
||||
where: { id: apiKeyId },
|
||||
include: { domain: true },
|
||||
});
|
||||
|
||||
if (!apiKey) {
|
||||
throw new UnsendApiError({
|
||||
code: "BAD_REQUEST",
|
||||
message: "Invalid API key",
|
||||
});
|
||||
}
|
||||
|
||||
domain = await validateApiKeyDomainAccess(from, teamId, apiKey);
|
||||
} else {
|
||||
// For non-API calls (dashboard, etc.), use regular domain validation
|
||||
domain = await validateDomainFromEmail(from, teamId);
|
||||
}
|
||||
|
||||
// Check for suppressed emails before sending
|
||||
const toEmails = Array.isArray(to) ? to : [to];
|
||||
|
Reference in New Issue
Block a user