diff options
| author | Fuwn <[email protected]> | 2026-01-24 13:09:50 +0000 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2026-01-24 13:09:50 +0000 |
| commit | 396acf3bbbe00a192cb0ea0a9ccf91b1d8d2850b (patch) | |
| tree | b9df4ca6a70db45cfffbae6fdd7252e20fb8e93c /src/lib/prisma.ts | |
| download | umami-396acf3bbbe00a192cb0ea0a9ccf91b1d8d2850b.tar.xz umami-396acf3bbbe00a192cb0ea0a9ccf91b1d8d2850b.zip | |
Created from https://vercel.com/new
Diffstat (limited to 'src/lib/prisma.ts')
| -rw-r--r-- | src/lib/prisma.ts | 368 |
1 files changed, 368 insertions, 0 deletions
diff --git a/src/lib/prisma.ts b/src/lib/prisma.ts new file mode 100644 index 0000000..64cb870 --- /dev/null +++ b/src/lib/prisma.ts @@ -0,0 +1,368 @@ +import { PrismaPg } from '@prisma/adapter-pg'; +import { readReplicas } from '@prisma/extension-read-replicas'; +import debug from 'debug'; +import { PrismaClient } from '@/generated/prisma/client'; +import { DEFAULT_PAGE_SIZE, FILTER_COLUMNS, OPERATORS, SESSION_COLUMNS } from './constants'; +import { filtersObjectToArray } from './params'; +import type { Operator, QueryFilters, QueryOptions } from './types'; + +const log = debug('umami:prisma'); + +const PRISMA = 'prisma'; + +const PRISMA_LOG_OPTIONS = { + log: [ + { + emit: 'event' as const, + level: 'query' as const, + }, + ], +}; + +const DATE_FORMATS = { + minute: 'YYYY-MM-DD HH24:MI:00', + hour: 'YYYY-MM-DD HH24:00:00', + day: 'YYYY-MM-DD HH24:00:00', + month: 'YYYY-MM-01 HH24:00:00', + year: 'YYYY-01-01 HH24:00:00', +}; + +const DATE_FORMATS_UTC = { + minute: 'YYYY-MM-DD"T"HH24:MI:00"Z"', + hour: 'YYYY-MM-DD"T"HH24:00:00"Z"', + day: 'YYYY-MM-DD"T"HH24:00:00"Z"', + month: 'YYYY-MM-01"T"HH24:00:00"Z"', + year: 'YYYY-01-01"T"HH24:00:00"Z"', +}; + +function getAddIntervalQuery(field: string, interval: string): string { + return `${field} + interval '${interval}'`; +} + +function getDayDiffQuery(field1: string, field2: string): string { + return `${field1}::date - ${field2}::date`; +} + +function getCastColumnQuery(field: string, type: string): string { + return `${field}::${type}`; +} + +function getDateSQL(field: string, unit: string, timezone?: string): string { + if (timezone && timezone !== 'utc') { + return `to_char(date_trunc('${unit}', ${field} at time zone '${timezone}'), '${DATE_FORMATS[unit]}')`; + } + + return `to_char(date_trunc('${unit}', ${field}), '${DATE_FORMATS_UTC[unit]}')`; +} + +function getDateWeeklySQL(field: string, timezone?: string) { + return `concat(extract(dow from (${field} at time zone '${timezone}')), ':', to_char((${field} at time zone '${timezone}'), 'HH24'))`; +} + +export function getTimestampSQL(field: string) { + return `floor(extract(epoch from ${field}))`; +} + +function getTimestampDiffSQL(field1: string, field2: string): string { + return `floor(extract(epoch from (${field2} - ${field1})))`; +} + +function getSearchSQL(column: string, param: string = 'search'): string { + return `and ${column} ilike {{${param}}}`; +} + +function mapFilter(column: string, operator: string, name: string, type: string = '') { + const value = `{{${name}${type ? `::${type}` : ''}}}`; + + switch (operator) { + case OPERATORS.equals: + return `${column} = ${value}`; + case OPERATORS.notEquals: + return `${column} != ${value}`; + case OPERATORS.contains: + return `${column} ilike ${value}`; + case OPERATORS.doesNotContain: + return `${column} not ilike ${value}`; + default: + return ''; + } +} + +function getFilterQuery(filters: Record<string, any>, options: QueryOptions = {}): string { + const query = filtersObjectToArray(filters, options).reduce( + (arr, { name, column, operator, prefix = '' }) => { + const isCohort = options?.isCohort; + + if (isCohort) { + column = FILTER_COLUMNS[name.slice('cohort_'.length)]; + } + + if (column) { + arr.push(`and ${mapFilter(`${prefix}${column}`, operator, name)}`); + + if (name === 'referrer') { + arr.push( + `and (website_event.referrer_domain != website_event.hostname or website_event.referrer_domain is null)`, + ); + } + } + + return arr; + }, + [], + ); + + return query.join('\n'); +} + +function getCohortQuery(filters: QueryFilters = {}) { + if (!filters || Object.keys(filters).length === 0) { + return ''; + } + + const filterQuery = getFilterQuery(filters, { isCohort: true }); + + return `join + (select distinct website_event.session_id + from website_event + join session on session.session_id = website_event.session_id + and session.website_id = website_event.website_id + where website_event.website_id = {{websiteId}} + and website_event.created_at between {{cohort_startDate}} and {{cohort_endDate}} + ${filterQuery} + ) cohort + on cohort.session_id = website_event.session_id + `; +} + +function getDateQuery(filters: Record<string, any>) { + const { startDate, endDate } = filters; + + if (startDate) { + if (endDate) { + return `and website_event.created_at between {{startDate}} and {{endDate}}`; + } else { + return `and website_event.created_at >= {{startDate}}`; + } + } + + return ''; +} + +function getQueryParams(filters: Record<string, any>) { + return { + ...filters, + ...filtersObjectToArray(filters).reduce((obj, { name, operator, value }) => { + obj[name] = ([OPERATORS.contains, OPERATORS.doesNotContain] as Operator[]).includes(operator) + ? `%${value}%` + : value; + + return obj; + }, {}), + }; +} + +function parseFilters(filters: Record<string, any>, options?: QueryOptions) { + const joinSession = Object.keys(filters).find(key => + ['referrer', ...SESSION_COLUMNS].includes(key), + ); + + const cohortFilters = Object.fromEntries( + Object.entries(filters).filter(([key]) => key.startsWith('cohort_')), + ); + + return { + joinSessionQuery: + options?.joinSession || joinSession + ? `inner join session on website_event.session_id = session.session_id and website_event.website_id = session.website_id` + : '', + dateQuery: getDateQuery(filters), + filterQuery: getFilterQuery(filters, options), + queryParams: getQueryParams(filters), + cohortQuery: getCohortQuery(cohortFilters), + }; +} + +async function rawQuery(sql: string, data: Record<string, any>, name?: string): Promise<any> { + if (process.env.LOG_QUERY) { + log('QUERY:\n', sql); + log('PARAMETERS:\n', data); + log('NAME:\n', name); + } + const params = []; + const schema = getSchema(); + + if (schema) { + await client.$executeRawUnsafe(`SET search_path TO "${schema}";`); + } + + const query = sql?.replaceAll(/\{\{\s*(\w+)(::\w+)?\s*}}/g, (...args) => { + const [, name, type] = args; + + const value = data[name]; + + params.push(value); + + return `$${params.length}${type ?? ''}`; + }); + + if (process.env.DATABASE_REPLICA_URL && '$replica' in client) { + return client.$replica().$queryRawUnsafe(query, ...params); + } + + return client.$queryRawUnsafe(query, ...params); +} + +async function pagedQuery<T>(model: string, criteria: T, filters?: QueryFilters) { + const { page = 1, pageSize, orderBy, sortDescending = false, search } = filters || {}; + const size = +pageSize || DEFAULT_PAGE_SIZE; + + const data = await client[model].findMany({ + ...criteria, + ...{ + ...(size > 0 && { take: +size, skip: +size * (+page - 1) }), + ...(orderBy && { + orderBy: [ + { + [orderBy]: sortDescending ? 'desc' : 'asc', + }, + ], + }), + }, + }); + + const count = await client[model].count({ where: (criteria as any).where }); + + return { data, count, page: +page, pageSize: size, orderBy, search }; +} + +async function pagedRawQuery( + query: string, + queryParams: Record<string, any>, + filters: QueryFilters, + name?: string, +) { + const { page = 1, pageSize, orderBy, sortDescending = false } = filters; + const size = +pageSize || DEFAULT_PAGE_SIZE; + const offset = +size * (+page - 1); + const direction = sortDescending ? 'desc' : 'asc'; + + const statements = [ + orderBy && `order by ${orderBy} ${direction}`, + +size > 0 && `limit ${+size} offset ${offset}`, + ] + .filter(n => n) + .join('\n'); + + const count = await rawQuery(`select count(*) as num from (${query}) t`, queryParams).then( + res => res[0].num, + ); + + const data = await rawQuery(`${query}${statements}`, queryParams, name); + + return { data, count, page: +page, pageSize: size, orderBy }; +} + +function getSearchParameters(query: string, filters: Record<string, any>[]) { + if (!query) return; + + const parseFilter = (filter: Record<string, any>) => { + const [[key, value]] = Object.entries(filter); + + return { + [key]: + typeof value === 'string' + ? { + [value]: query, + mode: 'insensitive', + } + : parseFilter(value), + }; + }; + + const params = filters.map(filter => parseFilter(filter)); + + return { + AND: { + OR: params, + }, + }; +} + +function transaction(input: any, options?: any) { + return client.$transaction(input, options); +} + +function getSchema() { + const connectionUrl = new URL(process.env.DATABASE_URL); + + return connectionUrl.searchParams.get('schema'); +} + +function getClient() { + const url = process.env.DATABASE_URL; + const replicaUrl = process.env.DATABASE_REPLICA_URL; + const logQuery = process.env.LOG_QUERY; + const schema = getSchema(); + + const baseAdapter = new PrismaPg({ connectionString: url }, { schema }); + + const baseClient = new PrismaClient({ + adapter: baseAdapter, + errorFormat: 'pretty', + ...(logQuery ? PRISMA_LOG_OPTIONS : {}), + }); + + if (logQuery) { + baseClient.$on('query', log); + } + + if (!replicaUrl) { + log('Prisma initialized'); + globalThis[PRISMA] ??= baseClient; + return baseClient; + } + + const replicaAdapter = new PrismaPg({ connectionString: replicaUrl }, { schema }); + + const replicaClient = new PrismaClient({ + adapter: replicaAdapter, + errorFormat: 'pretty', + ...(logQuery ? PRISMA_LOG_OPTIONS : {}), + }); + + if (logQuery) { + replicaClient.$on('query', log); + } + + const extended = baseClient.$extends( + readReplicas({ + replicas: [replicaClient], + }), + ); + + log('Prisma initialized (with replica)'); + globalThis[PRISMA] ??= extended; + + return extended; +} + +const client = (globalThis[PRISMA] || getClient()) as ReturnType<typeof getClient>; + +export default { + client, + transaction, + getAddIntervalQuery, + getCastColumnQuery, + getDayDiffQuery, + getDateSQL, + getDateWeeklySQL, + getFilterQuery, + getSearchParameters, + getTimestampDiffSQL, + getSearchSQL, + pagedQuery, + pagedRawQuery, + parseFilters, + rawQuery, +}; |