Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/graphql-armor/src/apollo/armor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { ApolloServerOptions, ApolloServerPlugin, BaseContext } from '@apol
import type { GraphQLArmorConfig } from '@escape.tech/graphql-armor-types';
import type { ValidationRule } from 'graphql';

import { contextInjectionPlugin } from './context-helper';
import { ApolloProtection } from './protections/base-protection';
import { ApolloBlockFieldSuggestionProtection } from './protections/block-field-suggestion';
import { ApolloCostLimitProtection } from './protections/cost-limit';
Expand Down Expand Up @@ -32,6 +33,7 @@ export class ApolloArmor {
} {
let plugins: ApolloServerOptions<BaseContext>['plugins'] = [];
let validationRules: ApolloServerOptions<BaseContext>['validationRules'] = [];
plugins.push(contextInjectionPlugin);

for (const protection of this.protections) {
if (protection.isEnabled) {
Expand Down
70 changes: 70 additions & 0 deletions packages/graphql-armor/src/apollo/context-helper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import type { ApolloServerPlugin, BaseContext, GraphQLRequestContext } from '@apollo/server';
import { isEnhancedValidationContext } from '@escape.tech/graphql-armor-types';
import type { ASTVisitor, ValidationRule as GraphQLValidationRule, ValidationContext } from 'graphql';

// Types
interface ExtendedValidationContext extends ValidationContext {
graphqlRequest?: GraphQLRequestContext<BaseContext>;
_rules?: ValidationRule[];
}

interface ValidationRule {
attachRequestContext?: (requestContext: GraphQLRequestContext<BaseContext>) => void;
[key: string]: unknown;
}

type RuleFunction = (context: ValidationContext) => ASTVisitor;

// Helper function to create a rule wrapper that injects the request context
export const injectRequestContextRule = (rule: RuleFunction): GraphQLValidationRule => {
return (context: ValidationContext): ASTVisitor => {
const visitor = rule(context);

if (!isEnhancedValidationContext(context)) {
return visitor;
}

// We need to cast here because ASTVisitor doesn't know about our custom property
const enhancedVisitor = {
...visitor,
attachRequestContext: (requestContext: GraphQLRequestContext<BaseContext>) => {
(context as ExtendedValidationContext).graphqlRequest = requestContext;
},
};

return enhancedVisitor as unknown as ASTVisitor;
};
};

// Helper function to find validation context in different Apollo versions
const findValidationContext = (validationCtx: any): ValidationContext | undefined => {
return validationCtx.validationContext || validationCtx.context?.document?.validationContext;
};

// Helper function to process validation rules
const processValidationRules = (
validationContext: ExtendedValidationContext,
requestContext: GraphQLRequestContext<BaseContext>,
): void => {
const rules = validationContext._rules || [];
const contextRules = rules.filter((rule): rule is ValidationRule =>
Boolean(rule?.attachRequestContext && typeof rule.attachRequestContext === 'function'),
);

contextRules.forEach((rule) => rule.attachRequestContext?.(requestContext));
};

// Create a single shared plugin instance for context injection
export const contextInjectionPlugin: ApolloServerPlugin<BaseContext> = {
async requestDidStart(requestContext: GraphQLRequestContext<BaseContext>) {
return {
async validationDidStart(validationCtx: any) {
const validationContext = findValidationContext(validationCtx);

if (validationContext && isEnhancedValidationContext(validationContext)) {
processValidationRules(validationContext as ExtendedValidationContext, requestContext);
}
},
};
},
};
3 changes: 2 additions & 1 deletion packages/graphql-armor/src/apollo/protections/max-aliases.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { maxAliasesRule } from '@escape.tech/graphql-armor-max-aliases';

import { injectRequestContextRule } from '../context-helper';
import { inferApolloPropagator } from '../errors';
import { ApolloProtection, ApolloServerConfigurationEnhancement } from './base-protection';

Expand All @@ -15,7 +16,7 @@ export class ApolloMaxAliasesProtection extends ApolloProtection {
this.config.maxAliases = inferApolloPropagator<typeof this.config.maxAliases>(this.config.maxAliases);

return {
validationRules: [maxAliasesRule(this.config.maxAliases)],
validationRules: [injectRequestContextRule(maxAliasesRule(this.config.maxAliases))],
};
}
}
3 changes: 2 additions & 1 deletion packages/graphql-armor/src/apollo/protections/max-depth.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { maxDepthRule } from '@escape.tech/graphql-armor-max-depth';

import { injectRequestContextRule } from '../context-helper';
import { inferApolloPropagator } from '../errors';
import { ApolloProtection, ApolloServerConfigurationEnhancement } from './base-protection';

Expand All @@ -15,7 +16,7 @@ export class ApolloMaxDepthProtection extends ApolloProtection {
this.config.maxDepth = inferApolloPropagator<typeof this.config.maxDepth>(this.config.maxDepth);

return {
validationRules: [maxDepthRule(this.config.maxDepth)],
validationRules: [injectRequestContextRule(maxDepthRule(this.config.maxDepth))],
};
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { maxDirectivesRule } from '@escape.tech/graphql-armor-max-directives';

import { injectRequestContextRule } from '../context-helper';
import { inferApolloPropagator } from '../errors';
import { ApolloProtection, ApolloServerConfigurationEnhancement } from './base-protection';

Expand All @@ -15,7 +16,7 @@ export class ApolloMaxDirectivesProtection extends ApolloProtection {
this.config.maxDirectives = inferApolloPropagator<typeof this.config.maxDirectives>(this.config.maxDirectives);

return {
validationRules: [maxDirectivesRule(this.config.maxDirectives)],
validationRules: [injectRequestContextRule(maxDirectivesRule(this.config.maxDirectives))],
};
}
}
4 changes: 2 additions & 2 deletions packages/graphql-armor/test/apollo/armor.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ describe('apolloArmor', () => {
});

it('should have property that equals', () => {
expect(enhancements.plugins.length).toEqual(2);
expect(enhancements.plugins.length).toEqual(3);
expect(enhancements.validationRules.length).toEqual(4);
expect(enhancements.allowBatchedHttpRequests).toEqual(false);
expect(enhancements.includeStacktraceInErrorResponses).toEqual(false);
Expand Down Expand Up @@ -50,6 +50,6 @@ describe('apolloArmor', () => {
});

const enhancementsDisabled = apolloDisabled.protect();
expect(enhancementsDisabled.plugins.length).toEqual(0);
expect(enhancementsDisabled.plugins.length).toEqual(1);
});
});
47 changes: 45 additions & 2 deletions packages/types/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { BaseContext, GraphQLRequestContext } from '@apollo/server';
import type { BlockFieldSuggestionsOptions } from '@escape.tech/graphql-armor-block-field-suggestions';
import type { CostLimitOptions } from '@escape.tech/graphql-armor-cost-limit';
import type { MaxAliasesOptions } from '@escape.tech/graphql-armor-max-aliases';
Expand All @@ -6,6 +7,7 @@ import type { MaxDirectivesOptions } from '@escape.tech/graphql-armor-max-direct
import type { MaxTokensOptions } from '@escape.tech/graphql-armor-max-tokens';
import type { GraphQLError, ValidationContext } from 'graphql';

// Core configuration types
export type ProtectionConfiguration = {
enabled?: boolean;
};
Expand All @@ -19,10 +21,51 @@ export type GraphQLArmorConfig = {
maxTokens?: ProtectionConfiguration & MaxTokensOptions;
};

export type GraphQLArmorAcceptCallback = (ctx: ValidationContext | null, details: any) => void;
export type GraphQLArmorRejectCallback = (ctx: ValidationContext | null, error: GraphQLError) => void;
// Context and validation types
export interface EnhancedValidationContext extends ValidationContext {
graphqlRequest?: GraphQLRequestContext<BaseContext>;
}

// User and authentication types
export interface User {
id?: string;
trustLevel?: string;
[key: string]: unknown;
}

// Callback types
export interface AcceptCallbackDetails {
[key: string]: unknown;
}

export type GraphQLArmorAcceptCallback = (
ctx: EnhancedValidationContext | null,
details: AcceptCallbackDetails,
) => void;

export type GraphQLArmorRejectCallback = (ctx: EnhancedValidationContext | null, error: GraphQLError) => void;

export type GraphQLArmorCallbackConfiguration = {
onAccept?: GraphQLArmorAcceptCallback[];
onReject?: GraphQLArmorRejectCallback[];
propagateOnRejection?: boolean;
};

// Type guards
export const isEnhancedValidationContext = (context: ValidationContext | null): context is EnhancedValidationContext =>
context !== null && 'graphqlRequest' in context;

export const isUser = (value: unknown): value is User =>
typeof value === 'object' && value !== null && 'trustLevel' in value;

// Utility functions
export const createRejectionError = (message: string, user?: User): Error => {
const userInfo = user?.id ? ` for user ${user.id}` : '';
return new Error(`${message}${userInfo}`);
};

export const validateRequestContext = (context: EnhancedValidationContext): void => {
if (!context.graphqlRequest) {
throw new Error('Request context not available');
}
};