diff --git a/packages/graphql-armor/src/apollo/armor.ts b/packages/graphql-armor/src/apollo/armor.ts index afe696e07..22bfdb467 100644 --- a/packages/graphql-armor/src/apollo/armor.ts +++ b/packages/graphql-armor/src/apollo/armor.ts @@ -2,6 +2,7 @@ import type { ApolloServerOptions, ApolloServerPlugin, BaseContext } from '@apol import type { GraphQLArmorConfig } from '@escape.tech/graphql-armor-types'; import type { ValidationRule } from 'graphql'; +import { contextInjectionPlugin } from './context-helper'; import { ApolloProtection } from './protections/base-protection'; import { ApolloBlockFieldSuggestionProtection } from './protections/block-field-suggestion'; import { ApolloCostLimitProtection } from './protections/cost-limit'; @@ -32,6 +33,7 @@ export class ApolloArmor { } { let plugins: ApolloServerOptions['plugins'] = []; let validationRules: ApolloServerOptions['validationRules'] = []; + plugins.push(contextInjectionPlugin); for (const protection of this.protections) { if (protection.isEnabled) { diff --git a/packages/graphql-armor/src/apollo/context-helper.ts b/packages/graphql-armor/src/apollo/context-helper.ts new file mode 100644 index 000000000..afb45edac --- /dev/null +++ b/packages/graphql-armor/src/apollo/context-helper.ts @@ -0,0 +1,70 @@ +import type { ApolloServerPlugin, BaseContext, GraphQLRequestContext } from '@apollo/server'; +import { isEnhancedValidationContext } from '@escape.tech/graphql-armor-types'; +import type { ASTVisitor, ValidationRule as GraphQLValidationRule, ValidationContext } from 'graphql'; + +// Types +interface ExtendedValidationContext extends ValidationContext { + graphqlRequest?: GraphQLRequestContext; + _rules?: ValidationRule[]; +} + +interface ValidationRule { + attachRequestContext?: (requestContext: GraphQLRequestContext) => void; + [key: string]: unknown; +} + +type RuleFunction = (context: ValidationContext) => ASTVisitor; + +// Helper function to create a rule wrapper that injects the request context +export const injectRequestContextRule = (rule: RuleFunction): GraphQLValidationRule => { + return (context: ValidationContext): ASTVisitor => { + const visitor = rule(context); + + if (!isEnhancedValidationContext(context)) { + return visitor; + } + + // We need to cast here because ASTVisitor doesn't know about our custom property + const enhancedVisitor = { + ...visitor, + attachRequestContext: (requestContext: GraphQLRequestContext) => { + (context as ExtendedValidationContext).graphqlRequest = requestContext; + }, + }; + + return enhancedVisitor as unknown as ASTVisitor; + }; +}; + +// Helper function to find validation context in different Apollo versions +const findValidationContext = (validationCtx: any): ValidationContext | undefined => { + return validationCtx.validationContext || validationCtx.context?.document?.validationContext; +}; + +// Helper function to process validation rules +const processValidationRules = ( + validationContext: ExtendedValidationContext, + requestContext: GraphQLRequestContext, +): void => { + const rules = validationContext._rules || []; + const contextRules = rules.filter((rule): rule is ValidationRule => + Boolean(rule?.attachRequestContext && typeof rule.attachRequestContext === 'function'), + ); + + contextRules.forEach((rule) => rule.attachRequestContext?.(requestContext)); +}; + +// Create a single shared plugin instance for context injection +export const contextInjectionPlugin: ApolloServerPlugin = { + async requestDidStart(requestContext: GraphQLRequestContext) { + return { + async validationDidStart(validationCtx: any) { + const validationContext = findValidationContext(validationCtx); + + if (validationContext && isEnhancedValidationContext(validationContext)) { + processValidationRules(validationContext as ExtendedValidationContext, requestContext); + } + }, + }; + }, +}; diff --git a/packages/graphql-armor/src/apollo/protections/max-aliases.ts b/packages/graphql-armor/src/apollo/protections/max-aliases.ts index 8f1a4a73d..5625dd7bd 100644 --- a/packages/graphql-armor/src/apollo/protections/max-aliases.ts +++ b/packages/graphql-armor/src/apollo/protections/max-aliases.ts @@ -1,5 +1,6 @@ import { maxAliasesRule } from '@escape.tech/graphql-armor-max-aliases'; +import { injectRequestContextRule } from '../context-helper'; import { inferApolloPropagator } from '../errors'; import { ApolloProtection, ApolloServerConfigurationEnhancement } from './base-protection'; @@ -15,7 +16,7 @@ export class ApolloMaxAliasesProtection extends ApolloProtection { this.config.maxAliases = inferApolloPropagator(this.config.maxAliases); return { - validationRules: [maxAliasesRule(this.config.maxAliases)], + validationRules: [injectRequestContextRule(maxAliasesRule(this.config.maxAliases))], }; } } diff --git a/packages/graphql-armor/src/apollo/protections/max-depth.ts b/packages/graphql-armor/src/apollo/protections/max-depth.ts index 528b05050..9d7cf347c 100644 --- a/packages/graphql-armor/src/apollo/protections/max-depth.ts +++ b/packages/graphql-armor/src/apollo/protections/max-depth.ts @@ -1,5 +1,6 @@ import { maxDepthRule } from '@escape.tech/graphql-armor-max-depth'; +import { injectRequestContextRule } from '../context-helper'; import { inferApolloPropagator } from '../errors'; import { ApolloProtection, ApolloServerConfigurationEnhancement } from './base-protection'; @@ -15,7 +16,7 @@ export class ApolloMaxDepthProtection extends ApolloProtection { this.config.maxDepth = inferApolloPropagator(this.config.maxDepth); return { - validationRules: [maxDepthRule(this.config.maxDepth)], + validationRules: [injectRequestContextRule(maxDepthRule(this.config.maxDepth))], }; } } diff --git a/packages/graphql-armor/src/apollo/protections/max-directives.ts b/packages/graphql-armor/src/apollo/protections/max-directives.ts index d4aba7880..461a82b6f 100644 --- a/packages/graphql-armor/src/apollo/protections/max-directives.ts +++ b/packages/graphql-armor/src/apollo/protections/max-directives.ts @@ -1,5 +1,6 @@ import { maxDirectivesRule } from '@escape.tech/graphql-armor-max-directives'; +import { injectRequestContextRule } from '../context-helper'; import { inferApolloPropagator } from '../errors'; import { ApolloProtection, ApolloServerConfigurationEnhancement } from './base-protection'; @@ -15,7 +16,7 @@ export class ApolloMaxDirectivesProtection extends ApolloProtection { this.config.maxDirectives = inferApolloPropagator(this.config.maxDirectives); return { - validationRules: [maxDirectivesRule(this.config.maxDirectives)], + validationRules: [injectRequestContextRule(maxDirectivesRule(this.config.maxDirectives))], }; } } diff --git a/packages/graphql-armor/test/apollo/armor.spec.ts b/packages/graphql-armor/test/apollo/armor.spec.ts index 13fae4065..56576b670 100644 --- a/packages/graphql-armor/test/apollo/armor.spec.ts +++ b/packages/graphql-armor/test/apollo/armor.spec.ts @@ -21,7 +21,7 @@ describe('apolloArmor', () => { }); it('should have property that equals', () => { - expect(enhancements.plugins.length).toEqual(2); + expect(enhancements.plugins.length).toEqual(3); expect(enhancements.validationRules.length).toEqual(4); expect(enhancements.allowBatchedHttpRequests).toEqual(false); expect(enhancements.includeStacktraceInErrorResponses).toEqual(false); @@ -50,6 +50,6 @@ describe('apolloArmor', () => { }); const enhancementsDisabled = apolloDisabled.protect(); - expect(enhancementsDisabled.plugins.length).toEqual(0); + expect(enhancementsDisabled.plugins.length).toEqual(1); }); }); diff --git a/packages/types/src/index.ts b/packages/types/src/index.ts index 091de1e86..3e24121d2 100644 --- a/packages/types/src/index.ts +++ b/packages/types/src/index.ts @@ -1,3 +1,4 @@ +import type { BaseContext, GraphQLRequestContext } from '@apollo/server'; import type { BlockFieldSuggestionsOptions } from '@escape.tech/graphql-armor-block-field-suggestions'; import type { CostLimitOptions } from '@escape.tech/graphql-armor-cost-limit'; import type { MaxAliasesOptions } from '@escape.tech/graphql-armor-max-aliases'; @@ -6,6 +7,7 @@ import type { MaxDirectivesOptions } from '@escape.tech/graphql-armor-max-direct import type { MaxTokensOptions } from '@escape.tech/graphql-armor-max-tokens'; import type { GraphQLError, ValidationContext } from 'graphql'; +// Core configuration types export type ProtectionConfiguration = { enabled?: boolean; }; @@ -19,10 +21,51 @@ export type GraphQLArmorConfig = { maxTokens?: ProtectionConfiguration & MaxTokensOptions; }; -export type GraphQLArmorAcceptCallback = (ctx: ValidationContext | null, details: any) => void; -export type GraphQLArmorRejectCallback = (ctx: ValidationContext | null, error: GraphQLError) => void; +// Context and validation types +export interface EnhancedValidationContext extends ValidationContext { + graphqlRequest?: GraphQLRequestContext; +} + +// User and authentication types +export interface User { + id?: string; + trustLevel?: string; + [key: string]: unknown; +} + +// Callback types +export interface AcceptCallbackDetails { + [key: string]: unknown; +} + +export type GraphQLArmorAcceptCallback = ( + ctx: EnhancedValidationContext | null, + details: AcceptCallbackDetails, +) => void; + +export type GraphQLArmorRejectCallback = (ctx: EnhancedValidationContext | null, error: GraphQLError) => void; + export type GraphQLArmorCallbackConfiguration = { onAccept?: GraphQLArmorAcceptCallback[]; onReject?: GraphQLArmorRejectCallback[]; propagateOnRejection?: boolean; }; + +// Type guards +export const isEnhancedValidationContext = (context: ValidationContext | null): context is EnhancedValidationContext => + context !== null && 'graphqlRequest' in context; + +export const isUser = (value: unknown): value is User => + typeof value === 'object' && value !== null && 'trustLevel' in value; + +// Utility functions +export const createRejectionError = (message: string, user?: User): Error => { + const userInfo = user?.id ? ` for user ${user.id}` : ''; + return new Error(`${message}${userInfo}`); +}; + +export const validateRequestContext = (context: EnhancedValidationContext): void => { + if (!context.graphqlRequest) { + throw new Error('Request context not available'); + } +};