diff --git a/schema.go b/schema.go index 2d914b8..1733c10 100644 --- a/schema.go +++ b/schema.go @@ -2,6 +2,8 @@ package jsonschema import ( "encoding/json" + "maps" + "slices" orderedmap "github.com/wk8/go-ordered-map/v2" ) @@ -92,3 +94,101 @@ var ( // http://json-schema.org/latest/json-schema-validation.html#rfc.section.5.26 // RFC draft-wright-json-schema-validation-00, section 5.26 type Definitions map[string]*Schema + +// MergeSchemas merges multiple schemas into one. The following fields are merged: +// - Definitions +// - AllOf +// - AnyOf +// - OneOf +// - PrefixItems +// - Properties +// - PatternProperties +// - Enum +// - Required +// - DependentRequired +// - Examples +// - Extras +// +// This is most useful for combining multiple top-level schemas generated through reflection. +func (t *Schema) MergeSchemas(source ...*Schema) (*Schema, error) { + if t.Definitions == nil { + t.Definitions = make(map[string]*Schema) + } + + if t.Properties == nil { + t.Properties = NewProperties() + } + + if t.Extras == nil { + t.Extras = make(map[string]any) + } + + if t.DependentSchemas == nil { + t.DependentSchemas = make(map[string]*Schema) + } + + if t.PatternProperties == nil { + t.PatternProperties = make(map[string]*Schema) + } + + if t.DependentRequired == nil { + t.DependentRequired = make(map[string][]string) + } + + for _, src := range source { + if src.Properties != nil { + for pair := src.Properties.Oldest(); pair != nil; pair = pair.Next() { + t.Properties.Set(pair.Key, pair.Value) + } + } + + maps.Copy(t.Definitions, src.Definitions) + maps.Copy(t.Extras, src.Extras) + maps.Copy(t.DependentSchemas, src.DependentSchemas) + maps.Copy(t.PatternProperties, src.PatternProperties) + + t.AllOf = append(t.AllOf, src.AllOf...) + t.AnyOf = append(t.AnyOf, src.AnyOf...) + t.OneOf = append(t.OneOf, src.OneOf...) + t.Enum = append(t.Enum, src.Enum...) + t.PrefixItems = append(t.PrefixItems, src.PrefixItems...) + t.Examples = append(t.Examples, src.Examples...) + + for _, r := range src.Required { + if !slices.Contains(t.Required, r) { + t.Required = append(t.Required, r) + } + } + + for k, v := range src.DependentRequired { + if _, ok := t.DependentRequired[k]; !ok { + t.DependentRequired[k] = v + + continue + } + + for _, req := range v { + if !slices.Contains(t.DependentRequired[k], req) { + t.DependentRequired[k] = append(t.DependentRequired[k], req) + } + } + + } + } + + return t, nil +} + +// AddDefinition adds a schema to the definitions of the current schema. +// The definition schema is consumed and should not be used afterwards. +func (t *Schema) AddDefinition(name string, definition *Schema) { + if t.Definitions == nil { + t.Definitions = make(map[string]*Schema) + } + + // clear ID and Version as they are for top-level schema only + definition.ID = ID("") + definition.Version = "" + + t.Definitions[name] = definition +} diff --git a/schema_test.go b/schema_test.go new file mode 100644 index 0000000..a64224e --- /dev/null +++ b/schema_test.go @@ -0,0 +1,42 @@ +package jsonschema + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSchemaMerge(t *testing.T) { + t.Parallel() + + r := Reflector{ + ExpandedStruct: true, + } + + combined, err := r.Reflect(&TestUser{}). + MergeSchemas(r.Reflect(&Inner{}), r.Reflect(&RecursiveExample{})) + + require.NoError(t, err) + require.NotNil(t, combined) + + // recursive definition must be added manually when using ExpandedStruct + combined.AddDefinition("RecursiveExample", r.Reflect(&RecursiveExample{})) + + type combinedStruct struct { + TestUser `json:",inline"` + Inner `json:",inline"` + RecursiveExample `json:",inline"` + } + + expected := r.Reflect(&combinedStruct{}) + + expected.ID = combined.ID // IDs are expected to differ, everything else must be identical + + expectedJSON, err := expected.MarshalJSON() + require.NoError(t, err) + + combinedJSON, err := combined.MarshalJSON() + require.NoError(t, err) + + require.Equal(t, string(expectedJSON), string(combinedJSON)) +}