diff --git a/Project.toml b/Project.toml index 41c5441cac..fe8ca86d72 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -33,7 +32,6 @@ FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" ImplicitDiscreteSolve = "3263718b-31ed-49cf-8a0f-35a466e8af96" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -68,18 +66,22 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac" InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" Pyomo = "0e8e1daf-01b5-4eba-a626-3897743a3816" [extensions] MTKBifurcationKitExt = "BifurcationKit" MTKCasADiDynamicOptExt = "CasADi" +MTKChainRulesCoreExt = "ChainRulesCore" MTKDeepDiffsExt = "DeepDiffs" MTKFMIExt = "FMI" MTKInfiniteOptExt = "InfiniteOpt" +MTKJuliaFormatterExt = "JuliaFormatter" MTKLabelledArraysExt = "LabelledArrays" MTKPyomoDynamicOptExt = "Pyomo" diff --git a/docs/src/basics/FAQ.md b/docs/src/basics/FAQ.md index 3f09ab8b13..7b712395b3 100644 --- a/docs/src/basics/FAQ.md +++ b/docs/src/basics/FAQ.md @@ -192,7 +192,7 @@ p, replace, alias = SciMLStructures.canonicalize(Tunable(), prob.p) # changes to the array will be reflected in parameter values ``` -# ERROR: ArgumentError: SymbolicUtils.BasicSymbolic{Real}[xˍt(t)] are missing from the variable map. +# ERROR: ArgumentError: `[xˍt(t)]` are missing from the variable map. This error can come up after running `mtkcompile` on a system that generates dummy derivatives (i.e. variables with `ˍt`). For example, here even though all the variables are defined with initial values, the `ODEProblem` generation will throw an error that defaults are missing from the variable map. diff --git a/ext/MTKCasADiDynamicOptExt.jl b/ext/MTKCasADiDynamicOptExt.jl index addc478d98..b1762e7f7f 100644 --- a/ext/MTKCasADiDynamicOptExt.jl +++ b/ext/MTKCasADiDynamicOptExt.jl @@ -122,7 +122,7 @@ end function MTK.lowered_var(m::CasADiModel, uv, i, t) X = getfield(m, uv) - t isa Union{Num, Symbolics.Symbolic} ? X.u[i, :] : X(t)[i] + t isa Union{Num, SymbolicT} ? X.u[i, :] : X(t)[i] end function MTK.lowered_integral(model::CasADiModel, expr, lo, hi) diff --git a/src/adjoints.jl b/ext/MTKChainRulesCoreExt.jl similarity index 85% rename from src/adjoints.jl rename to ext/MTKChainRulesCoreExt.jl index 98266de938..c213a164a3 100644 --- a/src/adjoints.jl +++ b/ext/MTKChainRulesCoreExt.jl @@ -1,3 +1,13 @@ +module MTKChainRulesCoreExt + +import ChainRulesCore +import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk +using ModelingToolkit: MTKParameters, NONNUMERIC_PORTION, AbstractSystem +import ModelingToolkit as MTK +import SciMLStructures +import SymbolicIndexingInterface: remake_buffer +import SciMLBase: AbstractNonlinearProblem, remake + function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...) function mtp_pullback(dt) dt = unthunk(dt) @@ -104,3 +114,11 @@ function ChainRulesCore.rrule( end ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol) + +function ModelingToolkit.update_initializeprob!(initprob::AbstractNonlinearProblem, prob) + pgetter = ChainRulesCore.@ignore_derivatives MTK.get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + p = pgetter(prob, initprob) + return remake(initprob; p) +end + +end diff --git a/ext/MTKInfiniteOptExt.jl b/ext/MTKInfiniteOptExt.jl index e0f02c0436..acb80c1041 100644 --- a/ext/MTKInfiniteOptExt.jl +++ b/ext/MTKInfiniteOptExt.jl @@ -122,7 +122,7 @@ end function MTK.lowered_var(m::InfiniteOptModel, uv, i, t) X = getfield(m, uv) - t isa Union{Num, Symbolics.Symbolic} ? X[i] : X[i](t) + t isa Union{Num, SymbolicT} ? X[i] : X[i](t) end function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau) @@ -256,13 +256,13 @@ for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqr end # JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both. -function Base.isequal(::SymbolicUtils.Symbolic, +function Base.isequal(::SymbolicT, ::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr}) false end function Base.isequal( ::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr}, - ::SymbolicUtils.Symbolic) + ::SymbolicT) false end end diff --git a/ext/MTKJuliaFormatterExt.jl b/ext/MTKJuliaFormatterExt.jl new file mode 100644 index 0000000000..a4efeec931 --- /dev/null +++ b/ext/MTKJuliaFormatterExt.jl @@ -0,0 +1,12 @@ +module MTKJuliaFormatterExt + +import ModelingToolkit: readable_code, _readable_code, rec_remove_macro_linenums! +import JuliaFormatter + +function readable_code(expr::Expr) + expr = Base.remove_linenums!(_readable_code(expr)) + rec_remove_macro_linenums!(expr) + JuliaFormatter.format_text(string(expr), JuliaFormatter.SciMLStyle()) +end + +end diff --git a/ext/MTKPyomoDynamicOptExt.jl b/ext/MTKPyomoDynamicOptExt.jl index 5b4e9e7a1c..053a2e34c3 100644 --- a/ext/MTKPyomoDynamicOptExt.jl +++ b/ext/MTKPyomoDynamicOptExt.jl @@ -53,7 +53,7 @@ struct PyomoDynamicOptProblem{uType, tType, isinplace, P, F, K} <: end end -function pysym_getproperty(s::Union{Num, Symbolics.Symbolic}, name::Symbol) +function pysym_getproperty(s::Union{Num, SymbolicT}, name::Symbol) Symbolics.wrap(SymbolicUtils.term( _getproperty, Symbolics.unwrap(s), Val{name}(), type = Symbolics.Struct{PyomoVar})) end @@ -112,7 +112,7 @@ function MTK.add_constraint!(pmodel::PyomoDynamicOptModel, cons; n_idxs = 1) Symbolics.unwrap(expr), SPECIAL_FUNCTIONS_DICT, fold = false) cons_sym = Symbol("cons", hash(cons)) - if occursin(Symbolics.unwrap(t_sym), expr) + if SU.query!(isequal(Symbolics.unwrap(t_sym)), expr) f = eval(Symbolics.build_function(expr, model_sym, t_sym)) setproperty!(model, cons_sym, pyomo.Constraint(model.t, rule = Pyomo.pyfunc(f))) else @@ -124,7 +124,7 @@ end function MTK.set_objective!(pmodel::PyomoDynamicOptModel, expr) @unpack model, model_sym, t_sym, dummy_sym = pmodel expr = Symbolics.substitute(expr, SPECIAL_FUNCTIONS_DICT, fold = false) - if occursin(Symbolics.unwrap(t_sym), expr) + if SU.query!(isequal(Symbolics.unwrap(t_sym)), expr) f = eval(Symbolics.build_function(expr, model_sym, t_sym)) model.obj = pyomo.Objective(model.t, rule = Pyomo.pyfunc(f)) else @@ -165,7 +165,7 @@ end function MTK.lowered_var(m::PyomoDynamicOptModel, uv, i, t) X = Symbolics.value(pysym_getproperty(m.model_sym, uv)) - var = t isa Union{Num, Symbolics.Symbolic} ? X[i, m.t_sym] : X[i, t] + var = t isa Union{Num, SymbolicT} ? X[i, m.t_sym] : X[i, t] Symbolics.unwrap(var) end diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 4f29c5f428..0475390408 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -9,16 +9,18 @@ using PrecompileTools, Reexport end import SymbolicUtils +import SymbolicUtils as SU import SymbolicUtils: iscall, arguments, operation, maketerm, promote_symtype, - Symbolic, isadd, ismul, ispow, issym, FnType, - @rule, Rewriters, substitute, metadata, BasicSymbolic, - Sym, Term + isadd, ismul, ispow, issym, FnType, isconst, BSImpl, + @rule, Rewriters, substitute, metadata, BasicSymbolic using SymbolicUtils.Code import SymbolicUtils.Code: toexpr import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint using DocStringExtensions using SpecialFunctions, NaNMath -using DiffEqCallbacks +@recompile_invalidations begin + using DiffEqCallbacks +end using Graphs import ExprTools: splitdef, combinedef import OrderedCollections @@ -48,7 +50,6 @@ using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, Ti PeriodicClock, Clock, SolverStepClock, ContinuousClock, OverrideInit, NoInit using Distributed -import JuliaFormatter using MLStyle import Moshi using Moshi.Data: @data @@ -62,19 +63,16 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc using OffsetArrays: Origin import CommonSolve import EnumX -import ChainRulesCore -import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr -using Symbolics: degree -using Symbolics: _parse_vars, value, @derivatives, get_variables, - exprs_occur_in, symbolic_linear_solve, build_expr, unwrap, wrap, +using Symbolics: degree, VartypeT, SymbolicT +using Symbolics: parse_vars, value, @derivatives, get_variables, + exprs_occur_in, symbolic_linear_solve, unwrap, wrap, VariableSource, getname, variable, - NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval, - hasnode, fixpoint_sub, fast_substitute, - CallWithMetadata, CallWithParent + NAMESPACE_SEPARATOR, setdefaultval, + hasnode, fixpoint_sub, CallAndWrap, SArgsT, SSym, STerm const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR) import Symbolics: rename, get_variables!, _solve, hessian_sparsity, jacobian_sparsity, isaffine, islinear, _iszero, _isone, @@ -83,7 +81,7 @@ import Symbolics: rename, get_variables!, _solve, hessian_sparsity, ParallelForm, SerialForm, MultithreadedForm, build_function, rhss, lhss, prettify_expr, gradient, jacobian, hessian, derivative, sparsejacobian, sparsehessian, - substituter, scalarize, getparent, hasderiv, hasdiff + scalarize, hasderiv import DiffEqBase: @add_kwonly export independent_variables, unknowns, observables, parameters, full_parameters, @@ -157,6 +155,9 @@ include("parameters.jl") include("independent_variables.jl") include("constants.jl") +const SymmapT = Dict{SymbolicT, SymbolicT} +const COMMON_NOTHING = SU.Const{VartypeT}(nothing) + include("utils.jl") include("systems/index_cache.jl") @@ -225,7 +226,6 @@ include("structural_transformation/StructuralTransformations.jl") @reexport using .StructuralTransformations include("inputoutput.jl") -include("adjoints.jl") include("deprecations.jl") const t_nounits = let @@ -334,6 +334,8 @@ export AbstractCollocation, JuMPCollocation, InfiniteOptCollocation, CasADiCollocation, PyomoCollocation export DynamicOptSolution +const set_scalar_metadata = setmetadata + @public apply_to_variables, equations_toplevel, unknowns_toplevel, parameters_toplevel @public continuous_events_toplevel, discrete_events_toplevel, assertions, is_alg_equation @public is_diff_equation, Equality, linearize_symbolic, reorder_unknowns @@ -356,10 +358,100 @@ for prop in [SYS_PROPS; [:continuous_events, :discrete_events]] end PrecompileTools.@compile_workload begin - using ModelingToolkit + fold1 = Val{false}() + using SymbolicUtils + using SymbolicUtils: shape + using Symbolics + @syms x y f(t) q[1:5] + SymbolicUtils.Sym{SymReal}(:a; type = Real, shape = SymbolicUtils.ShapeVecT()) + x + y + x * y + x / y + x ^ y + x ^ 5 + 6 ^ x + x - y + -y + 2y + z = 2 + dict = SymbolicUtils.ACDict{VartypeT}() + dict[x] = 1 + dict[y] = 1 + type::typeof(DataType) = rand() < 0.5 ? Real : Float64 + nt = (; type, shape, unsafe = true) + Base.pairs(nt) + BSImpl.AddMul{VartypeT}(1, dict, SymbolicUtils.AddMulVariant.MUL; type, shape = SymbolicUtils.ShapeVecT(), unsafe = true) + *(y, z) + *(z, y) + SymbolicUtils.symtype(y) + f(x) + (5x / 5) + expand((x + y) ^ 2) + simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y) + ex = x + 2y + sin(x) + rules1 = Dict(x => y) + rules2 = Dict(x => 1) + Dx = Differential(x) + Differential(y)(ex) + uex = unwrap(ex) + Symbolics.executediff(Dx, uex) + # Running `fold = Val(true)` invalidates the precompiled statements + # for `fold = Val(false)` and itself doesn't precompile anyway. + # substitute(ex, rules1) + substitute(ex, rules1; fold = fold1) + substitute(ex, rules2; fold = fold1) + @variables foo + f(foo) + @variables x y f(::Real) q[1:5] + x + y + x * y + x / y + x ^ y + x ^ 5 + # 6 ^ x + x - y + -y + 2y + symtype(y) + z = 2 + *(y, z) + *(z, y) + f(x) + (5x / 5) + [x, y] + [x, f, f] + promote_type(Int, Num) + promote_type(Real, Num) + promote_type(Float64, Num) + # expand((x + y) ^ 2) + # simplify(x ^ (1//2) + (sin(x) ^ 2 + cos(x) ^ 2) + 2(x + y) - x - y) + ex = x + 2y + sin(x) + rules1 = Dict(x => y) + # rules2 = Dict(x => 1) + # Running `fold = Val(true)` invalidates the precompiled statements + # for `fold = Val(false)` and itself doesn't precompile anyway. + # substitute(ex, rules1) + substitute(ex, rules1; fold = fold1) + Symbolics.linear_expansion(ex, y) + # substitute(ex, rules2; fold = fold1) + # substitute(ex, rules2) + # substitute(ex, rules1; fold = fold2) + # substitute(ex, rules2; fold = fold2) + q[1] + q'q + using ModelingToolkit @variables x(ModelingToolkit.t_nounits) - @named sys = System([ModelingToolkit.D_nounits(x) ~ -x], ModelingToolkit.t_nounits) - prob = ODEProblem(mtkcompile(sys), [x => 30.0], (0, 100), jac = true) + isequal(ModelingToolkit.D_nounits.x, ModelingToolkit.t_nounits) + sys = System([ModelingToolkit.D_nounits(x) ~ x], ModelingToolkit.t_nounits, [x], Num[]; name = :sys) + complete(sys) + @syms p[1:2] + ndims(p) + size(p) + axes(p) + length(p) + v = [p] + isempty(v) + # mtkcompile(sys) @mtkmodel __testmod__ begin @constants begin c = 1.0 @@ -390,4 +482,17 @@ PrecompileTools.@compile_workload begin end end +precompile(Tuple{typeof(Base.merge), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Nothing, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:metadata,), Tuple{Base.ImmutableDict{DataType, Any}}}}) +precompile(Tuple{typeof(Base.merge), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Base.ImmutableDict{DataType, Any}, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:id, :hash, :hash2), Tuple{Nothing, Int64, Int64}}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:f, :args, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, SymbolicUtils.SmallVec{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, Array{SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}, 1}}, Base.ImmutableDict{DataType, Any}, Int64, Int64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, Nothing}}, Type{SymbolicUtils.BasicSymbolicImpl.Term{SymbolicUtils.SymReal}}}) +precompile(Tuple{typeof(Symbolics.parse_vars), Symbol, Type, Tuple{Symbol, Symbol}, Function}) +precompile(Tuple{typeof(Base.merge), NamedTuple{(:name, :metadata, :hash, :hash2, :shape, :type, :id), Tuple{Symbol, Base.ImmutableDict{DataType, Any}, UInt64, UInt64, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}, DataType, SymbolicUtils.IDType}}, NamedTuple{(:metadata,), Tuple{Base.ImmutableDict{DataType, Any}}}}) +precompile(Tuple{typeof(Base.vect), Symbolics.Equation, Vararg{Symbolics.Equation}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:name, :defaults), Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}}}, Type{ModelingToolkit.System}, Array{Symbolics.Equation, 1}, Symbolics.Num, Array{Symbolics.Num, 1}, Array{Symbolics.Num, 1}}) +precompile(Tuple{Type{NamedTuple{(:name, :defaults), T} where T<:Tuple}, Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}}}) +precompile(Tuple{typeof(SymbolicUtils.isequal_somescalar), Float64, Float64}) +precompile(Tuple{Type{NamedTuple{(:name, :defaults, :guesses), T} where T<:Tuple}, Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:name, :defaults, :guesses), Tuple{Symbol, Base.Dict{Symbolics.Num, Float64}, Base.Dict{Symbolics.Num, Float64}}}, Type{ModelingToolkit.System}, Array{Symbolics.Equation, 1}, Symbolics.Num, Array{Symbolics.Num, 1}, Array{Symbolics.Num, 1}}) +precompile(Tuple{typeof(Core.kwcall), NamedTuple{(:type, :shape), Tuple{DataType, SymbolicUtils.SmallVec{Base.UnitRange{Int64}, Array{Base.UnitRange{Int64}, 1}}}}, typeof(SymbolicUtils.term), Any, SymbolicUtils.BasicSymbolicImpl.var"typeof(BasicSymbolicImpl)"{SymbolicUtils.SymReal}}) + end # module diff --git a/src/clock.jl b/src/clock.jl index df3b6f4b47..6537334645 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -45,7 +45,7 @@ has_time_domain(_, x) = has_time_domain(x) Determine if variable `x` has a time-domain attributed to it. """ -function has_time_domain(x::Symbolic) +function has_time_domain(x::SymbolicT) # getmetadata(x, ContinuousClock, nothing) !== nothing || # getmetadata(x, Discrete, nothing) !== nothing getmetadata(x, VariableTimeDomain, nothing) !== nothing @@ -77,7 +77,7 @@ See also [`is_continuous_domain`](@ref) """ function has_continuous_domain(x) issym(x) && return is_continuous_domain(x) - hasderiv(x) || hasdiff(x) || hassample(x) || hashold(x) + hasderiv(x) || hassample(x) || hashold(x) end """ diff --git a/src/constants.jl b/src/constants.jl index 4113287ad4..fed010a2ee 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -3,7 +3,7 @@ Test whether `x` is a constant-type Sym. """ function isconstant(x) x = unwrap(x) - x isa Symbolic && !getmetadata(x, VariableTunable, true) + x isa SymbolicT && !getmetadata(x, VariableTunable, true) end """ @@ -26,8 +26,8 @@ Define one or more constants. See also [`@independent_variables`](@ref), [`@parameters`](@ref) and [`@variables`](@ref). """ macro constants(xs...) - Symbolics._parse_vars(:constants, + Symbolics.parse_vars(:constants, Real, xs, - toconstant) |> esc + toconstant) end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 9e57296d9f..72f14408c1 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -33,7 +33,8 @@ at the inferred clock for that equation. struct SampleTime <: Operator SampleTime() = SymbolicUtils.term(SampleTime, type = Real) end -SymbolicUtils.promote_symtype(::Type{<:SampleTime}, t...) = Real +SymbolicUtils.promote_symtype(::Type{<:SampleTime}, ::Type{T}) where {T} = Real +SymbolicUtils.promote_shape(::SampleTime, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::SampleTime) = :SampleTime SymbolicUtils.isbinop(::SampleTime) = false @@ -60,7 +61,7 @@ julia> Δ = Shift(t) """ struct Shift <: Operator """Fixed Shift""" - t::Union{Nothing, Symbolic} + t::Union{Nothing, SymbolicT} steps::Int Shift(t, steps = 1) = new(value(t), steps) end @@ -71,11 +72,7 @@ SymbolicUtils.isbinop(::Shift) = false function (D::Shift)(x, allow_zero = false) !allow_zero && D.steps == 0 && return x - if Symbolics.isarraysymbolic(x) - Symbolics.array_term(D, x) - else - term(D, x) - end + term(D, x; type = symtype(x), shape = SU.shape(x)) end function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false) !allow_zero && D.steps == 0 && return x @@ -94,7 +91,8 @@ function (D::Shift)(x::Union{Num, Symbolics.Arr}, allow_zero = false) end wrap(D(vt, allow_zero)) end -SymbolicUtils.promote_symtype(::Shift, t) = t +SymbolicUtils.promote_symtype(::Shift, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Shift, @nospecialize(x::SU.ShapeT)) = x Base.show(io::IO, D::Shift) = print(io, "Shift(", D.t, ", ", D.steps, ")") @@ -162,9 +160,10 @@ function Sample(arg::Real) Sample()(arg) end end -(D::Sample)(x) = Term{symtype(x)}(D, Any[x]) +(D::Sample)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Sample)(x::Num) = Num(D(value(x))) -SymbolicUtils.promote_symtype(::Sample, x) = x +SymbolicUtils.promote_symtype(::Sample, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Sample, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::Sample) = :Sample SymbolicUtils.isbinop(::Sample) = false @@ -208,9 +207,10 @@ end is_transparent_operator(::Type{Hold}) = true -(D::Hold)(x) = Term{symtype(x)}(D, Any[x]) +(D::Hold)(x) = STerm(D, SArgsT((x,)); type = symtype(x), shape = SU.shape(x)) (D::Hold)(x::Num) = Num(D(value(x))) -SymbolicUtils.promote_symtype(::Hold, x) = x +SymbolicUtils.promote_symtype(::Hold, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Hold, @nospecialize(x::SU.ShapeT)) = x Base.nameof(::Hold) = :Hold SymbolicUtils.isbinop(::Hold) = false diff --git a/src/independent_variables.jl b/src/independent_variables.jl index d1f2ab4210..fce2d93873 100644 --- a/src/independent_variables.jl +++ b/src/independent_variables.jl @@ -7,12 +7,12 @@ Define one or more independent variables. For example: @variables x(t) """ macro independent_variables(ts...) - Symbolics._parse_vars(:independent_variables, + Symbolics.parse_vars(:independent_variables, Real, ts, - toiv) |> esc + toiv) end -toiv(s::Symbolic) = GlobalScope(setmetadata(s, MTKVariableTypeCtx, PARAMETER)) +toiv(s::SymbolicT) = GlobalScope(setmetadata(s, MTKVariableTypeCtx, PARAMETER)) toiv(s::Symbolics.Arr) = wrap(toiv(value(s))) toiv(s::Num) = Num(toiv(value(s))) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index c113c4e753..5afc6b2000 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -49,6 +49,13 @@ See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref), [`bound_outputs`](@re """ unbound_outputs(sys) = filter(x -> !is_bound(sys, x), outputs(sys)) +function _is_atomic_inside_operator(ex::SymbolicT) + SU.default_is_atomic(ex) && Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Operator end => false + _ => true + end +end + """ is_bound(sys, u) @@ -75,8 +82,11 @@ function is_bound(sys, u, stack = []) eqs = equations(sys) eqs = filter(eq -> has_var(eq, u), eqs) # Only look at equations that contain u # isout = isoutput(u) + vars = Set{SymbolicT}() for eq in eqs - vars = [get_variables(eq.rhs); get_variables(eq.lhs)] + empty!(vars) + get_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + get_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) for var in vars var === u && continue if !same_or_inner_namespace(u, var) @@ -88,7 +98,9 @@ function is_bound(sys, u, stack = []) oeqs = observed(sys) oeqs = filter(eq -> has_var(eq, u), oeqs) # Only look at equations that contain u for eq in oeqs - vars = [get_variables(eq.rhs); get_variables(eq.lhs)] + empty!(vars) + get_variables!(vars, eq.rhs; is_atomic = _is_atomic_inside_operator) + get_variables!(vars, eq.lhs; is_atomic = _is_atomic_inside_operator) for var in vars var === u && continue if !same_or_inner_namespace(u, var) @@ -312,7 +324,7 @@ function inputs_to_parameters!(state::TransformationState, inputsyms) @set! structure.graph = complete(new_graph) @set! sys.eqs = isempty(input_to_parameters) ? equations(sys) : - fast_substitute(equations(sys), input_to_parameters) + substitute(equations(sys), input_to_parameters) @set! sys.unknowns = setdiff(unknowns(sys), keys(input_to_parameters)) ps = parameters(sys) diff --git a/src/parameters.jl b/src/parameters.jl index d8ff1bf1be..7bb76d7bf0 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -25,19 +25,19 @@ Check if the variable contains the metadata identifying it as a parameter. function isparameter(x) x = unwrap(x) - if x isa Symbolic && (varT = getvariabletype(x, nothing)) !== nothing + if x isa SymbolicT && (varT = getvariabletype(x, nothing)) !== nothing return varT === PARAMETER #TODO: Delete this branch - elseif x isa Symbolic && Symbolics.getparent(x, false) !== false - p = Symbolics.getparent(x) + elseif x isa SymbolicT && iscall(x) && operation(x) === getindex + p = arguments(x)[1] isparameter(p) || (hasmetadata(p, Symbolics.VariableSource) && getmetadata(p, Symbolics.VariableSource)[1] == :parameters) - elseif iscall(x) && operation(x) isa Symbolic + elseif iscall(x) && operation(x) isa SymbolicT varT === PARAMETER || isparameter(operation(x)) elseif iscall(x) && operation(x) == (getindex) isparameter(arguments(x)[1]) - elseif x isa Symbolic + elseif x isa SymbolicT varT === PARAMETER else false @@ -46,17 +46,13 @@ end function iscalledparameter(x) x = unwrap(x) - return isparameter(getmetadata(x, CallWithParent, nothing)) + return SymbolicUtils.is_called_function_symbolic(x) && isparameter(operation(x)) end function getcalledparameter(x) x = unwrap(x) - # `parent` is a `CallWithMetadata` with the correct metadata, - # but no namespacing. `operation(x)` has the correct namespacing, - # but is not a `CallWithMetadata` and doesn't have any metadata. - # This approach combines both. - parent = getmetadata(x, CallWithParent) - return CallWithMetadata(operation(x), metadata(parent)) + @assert iscalledparameter(x) + return operation(x) end """ @@ -80,7 +76,7 @@ toparam(s::Num) = wrap(toparam(value(s))) Maps the variable to an unknown. """ -tovar(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, VARIABLE) +tovar(s::SymbolicT) = setmetadata(s, MTKVariableTypeCtx, VARIABLE) tovar(s::Union{Num, Symbolics.Arr}) = wrap(tovar(unwrap(s))) """ @@ -91,10 +87,10 @@ Define one or more known parameters. See also [`@independent_variables`](@ref), [`@variables`](@ref) and [`@constants`](@ref). """ macro parameters(xs...) - Symbolics._parse_vars(:parameters, + Symbolics.parse_vars(:parameters, Real, xs, - toparam) |> esc + toparam) end function find_types(array) diff --git a/src/problems/initializationproblem.jl b/src/problems/initializationproblem.jl index 6960811bbd..5267ad4abc 100644 --- a/src/problems/initializationproblem.jl +++ b/src/problems/initializationproblem.jl @@ -39,7 +39,7 @@ All other keyword arguments are forwarded to the wrapped nonlinear problem const for k in keys(op) has_u0_ics |= is_variable(sys, k) || isdifferential(k) || symbolic_type(k) == ArraySymbolic() && - is_sized_array_symbolic(k) && is_variable(sys, unwrap(first(wrap(k)))) + symbolic_has_known_size(k) && is_variable(sys, unwrap(first(wrap(k)))) end if !has_u0_ics && get_initializesystem(sys) !== nothing isys = get_initializesystem(sys; initialization_eqs, check_units) diff --git a/src/problems/jumpproblem.jl b/src/problems/jumpproblem.jl index 32aa25182f..113f5fc2f2 100644 --- a/src/problems/jumpproblem.jl +++ b/src/problems/jumpproblem.jl @@ -196,13 +196,13 @@ end ### Functions to determine which unknowns a jump depends on function get_variables!(dep, jump::Union{ConstantRateJump, VariableRateJump}, variables) jr = value(jump.rate) - (jr isa Symbolic) && get_variables!(dep, jr, variables) + (jr isa SymbolicT) && get_variables!(dep, jr, variables) dep end function get_variables!(dep, jump::MassActionJump, variables) sr = value(jump.scaled_rates) - (sr isa Symbolic) && get_variables!(dep, sr, variables) + (sr isa SymbolicT) && get_variables!(dep, sr, variables) for varasop in jump.reactant_stoch any(isequal(varasop[1]), variables) && push!(dep, varasop[1]) end diff --git a/src/problems/sccnonlinearproblem.jl b/src/problems/sccnonlinearproblem.jl index d71124adde..8d131a1f31 100644 --- a/src/problems/sccnonlinearproblem.jl +++ b/src/problems/sccnonlinearproblem.jl @@ -1,5 +1,3 @@ -const TypeT = Union{DataType, UnionAll} - struct CacheWriter{F} fn::F end diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 681025cb81..4c28708d67 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -3,18 +3,21 @@ module StructuralTransformations using Setfield: @set!, @set using UnPack: @unpack -using Symbolics: unwrap, linear_expansion, fast_substitute +using Symbolics: unwrap, linear_expansion, VartypeT, SymbolicT import Symbolics using SymbolicUtils +using SymbolicUtils: BSImpl using SymbolicUtils.Code using SymbolicUtils.Rewriters using SymbolicUtils: maketerm, iscall +import SymbolicUtils as SU +import Moshi using ModelingToolkit using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Differential, - unknowns, equations, vars, Symbolic, diff2term_with_unit, + unknowns, equations, vars, diff2term_with_unit, shift2term_with_unit, value, - operation, arguments, Sym, Term, simplify, symbolic_linear_solve, + operation, arguments, simplify, symbolic_linear_solve, isdiffeq, isdifferential, isirreducible, empty_substitutions, get_substitutions, get_tearing_state, get_iv, independent_variables, @@ -27,7 +30,8 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe filter_kwargs, lower_varname_with_unit, lower_shift_varname_with_unit, setio, SparseMatrixCLIL, get_fullvars, has_equations, observed, - Schedule, schedule, iscomplete, get_schedule + Schedule, schedule, iscomplete, get_schedule, VariableUnshifted, + VariableShift using ModelingToolkit.BipartiteGraphs import .BipartiteGraphs: invview, complete @@ -40,7 +44,7 @@ using ModelingToolkit: algeqs, EquationsView, dervars_range, diffvars_range, algvars_range, DiffGraph, complete!, get_fullvars, system_subset -using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic +using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic, getname using ModelingToolkit.DiffEqBase using ModelingToolkit.StaticArrays diff --git a/src/structural_transformation/pantelides.jl b/src/structural_transformation/pantelides.jl index 871bd99ef4..47fa5aa762 100644 --- a/src/structural_transformation/pantelides.jl +++ b/src/structural_transformation/pantelides.jl @@ -37,7 +37,7 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) # LHS variable is looked up from var_to_diff # the var_to_diff[i]-th variable is the differentiated version of var at i eq = out_eqs[eqidx] - lhs = if !(eq.lhs isa Symbolic) + lhs = if !(eq.lhs isa SymbolicT) 0 elseif isdiffeq(eq) # look up the variable that represents D(lhs) @@ -54,9 +54,9 @@ function pantelides_reassemble(state::TearingState, var_eq_matching) D(eq.lhs) end rhs = ModelingToolkit.expand_derivatives(D(eq.rhs)) - rhs = fast_substitute(rhs, state.param_derivative_map) + rhs = substitute(rhs, state.param_derivative_map) substitution_dict = Dict(x.lhs => x.rhs - for x in out_eqs if x !== nothing && x.lhs isa Symbolic) + for x in out_eqs if x !== nothing && x.lhs isa SymbolicT) sub_rhs = substitute(rhs, substitution_dict) out_eqs[diff] = lhs ~ sub_rhs end diff --git a/src/structural_transformation/symbolics_tearing.jl b/src/structural_transformation/symbolics_tearing.jl index 39b959c5a6..9760e5fef5 100644 --- a/src/structural_transformation/symbolics_tearing.jl +++ b/src/structural_transformation/symbolics_tearing.jl @@ -65,7 +65,7 @@ function eq_derivative!(ts::TearingState, ieq::Int; kwargs...) sys = ts.sys eq = equations(ts)[ieq] - eq = 0 ~ fast_substitute( + eq = 0 ~ substitute( ModelingToolkit.derivative( eq.rhs - eq.lhs, get_iv(sys); throw_no_derivative = true), ts.param_derivative_map) @@ -108,7 +108,7 @@ end function solve_equation(eq, var, simplify) rhs = value(symbolic_linear_solve(eq, var; simplify = simplify, check = false)) - occursin(var, rhs) && throw(EquationSolveErrors(eq, var, rhs)) + SU.query!(in(var), rhs) && throw(EquationSolveErrors(eq, var, rhs)) var ~ rhs end @@ -144,7 +144,7 @@ function to_mass_matrix_form(neweqs, ieq, graph, fullvars, isdervar::F, eq = 0 ~ eq.rhs - eq.lhs end rhs = eq.rhs - if rhs isa Symbolic + if rhs isa SymbolicT # Check if the RHS is solvable in all unknown variable derivatives and if those # the linear terms for them are all zero. If so, move them to the # LHS. @@ -217,7 +217,7 @@ function substitute_derivatives_algevars!( v_t = setio(diff2term_with_unit(unwrap(dd), unwrap(iv)), false, false) for eq in 𝑑neighbors(graph, dv) dummy_sub[dd] = v_t - neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t) + neweqs[eq] = substitute(neweqs[eq], dd => v_t) end fullvars[dv] = v_t # If we have: @@ -230,7 +230,7 @@ function substitute_derivatives_algevars!( while (ddx = var_to_diff[dx]) !== nothing dx_t = D(x_t) for eq in 𝑑neighbors(graph, ddx) - neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t) + neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t) end fullvars[ddx] = dx_t dx = ddx @@ -961,8 +961,8 @@ function update_simplified_system!( obs_sub[eq.lhs] = eq.rhs end # TODO: compute the dependency correctly so that we don't have to do this - obs = [fast_substitute(observed(sys), obs_sub); solved_eqs; - fast_substitute(state.additional_observed, obs_sub)] + obs = [substitute(observed(sys), obs_sub); solved_eqs; + substitute(state.additional_observed, obs_sub)] unknown_idxs = filter( i -> diff_to_var[i] === nothing && ispresent(i) && !(fullvars[i] in solved_vars), eachindex(state.fullvars)) @@ -1189,7 +1189,7 @@ end Backshift the given expression `ex`. """ function backshift_expr(ex, iv) - ex isa Symbolic || return ex + ex isa SymbolicT || return ex return descend_lower_shift_varname_with_unit( simplify_shifts(distribute_shift(Shift(iv, -1)(ex))), iv) end @@ -1251,7 +1251,7 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true) array || continue iscall(lhs) || continue operation(lhs) === getindex || continue - Symbolics.shape(lhs) != Symbolics.Unknown() || continue + SU.shape(lhs) isa SU.Unknown && continue arg1 = arguments(lhs)[1] cnt = get(arr_obs_occurrences, arg1, 0) arr_obs_occurrences[arg1] = cnt + 1 @@ -1264,7 +1264,7 @@ function tearing_hacks(sys, obs, unknowns, neweqs; array = true) for sym in unknowns iscall(sym) || continue operation(sym) === getindex || continue - Symbolics.shape(sym) != Symbolics.Unknown() || continue + SU.shape(sym) isa SU.Unknown && continue arg1 = arguments(sym)[1] cnt = get(arr_obs_occurrences, arg1, 0) cnt == 0 && continue diff --git a/src/structural_transformation/utils.jl b/src/structural_transformation/utils.jl index 3fa4f28aa9..f6ff669c0e 100644 --- a/src/structural_transformation/utils.jl +++ b/src/structural_transformation/utils.jl @@ -249,7 +249,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no a, b, islinear = linear_expansion(term, var) a, b = unwrap(a), unwrap(b) islinear || (all_int_vars = false; continue) - if a isa Symbolic + if a isa SymbolicT all_int_vars = false if !allow_symbolic if allow_parameter @@ -257,7 +257,7 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no if any( v -> any(isequal(v), fullvars) || symbolic_type(v) == ArraySymbolic() && - Symbolics.shape(v) != Symbolics.Unknown() && + SU.shape(v) isa SU.Unknown || any(x -> any(isequal(x), fullvars), collect(v)), vars( a; op = Union{Differential, Shift, Pre, Sample, Hold, Initial})) @@ -503,43 +503,55 @@ end """ Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t). """ -function shift2term(var) - iscall(var) || return var - op = operation(var) - op isa Shift || return var - iv = op.t - arg = only(arguments(var)) - if operation(arg) === getindex - idxs = arguments(arg)[2:end] - newvar = shift2term(op(first(arguments(arg))))[idxs...] - unshifted = ModelingToolkit.getunshifted(newvar)[idxs...] - newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, unshifted) - return newvar +function shift2term(var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(f, args) && if f isa Shift end => begin + op = f + arg = args[1] + Moshi.Match.@match arg begin + BSImpl.Term(; f, args, type, shape, metadata) && if f === getindex end => begin + newargs = copy(parent(args)) + newargs[1] = shift2term(op(newargs[1])) + unshifted_args = copy(newargs) + unshifted_args[1] = ModelingToolkit.getunshifted(newargs[1]) + unshifted = BSImpl.Term{VartypeT}(getindex, unshifted_args; type, shape, metadata) + if metadata === nothing + metadata = Base.ImmutableDict{DataType, Any}(VariableUnshifted, unshifted) + elseif metadata isa Base.ImmutableDict{DataType, Any} + metadata = Base.ImmutableDict(metadata, VariableUnshifted, unshifted) + end + return BSImpl.Term{VartypeT}(getindex, newargs; type, shape, metadata) + end + _ => nothing + end + unshifted = ModelingToolkit.getunshifted(arg) + is_lowered = unshifted !== nothing + backshift = op.steps + ModelingToolkit.getshift(arg) + io = IOBuffer() + O = (is_lowered ? unshifted : arg)::SymbolicT + write(io, getname(O)) + # Char(0x209c) = ₜ + write(io, Char(0x209c)) + # Char(0x208b) = ₋ (subscripted minus) + # Char(0x208a) = ₊ (subscripted plus) + pm = backshift > 0 ? Char(0x208a) : Char(0x208b) + write(io, pm) + backshift = abs(backshift) + N = ndigits(backshift) + den = 10 ^ (N - 1) + for _ in 1:N + # subscripted number, e.g. ₁ + write(io, Char(0x2080 + div(backshift, den) % 10)) + den = div(den, 10) + end + newname = Symbol(take!(io)) + newvar = Symbolics.rename(var, newname) + newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) + newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) + return newvar + end + _ => return var end - is_lowered = !isnothing(ModelingToolkit.getunshifted(arg)) - - backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps - - # Char(0x208b) = ₋ (subscripted minus) - # Char(0x208a) = ₊ (subscripted plus) - pm = backshift > 0 ? Char(0x208a) : Char(0x208b) - # subscripted number, e.g. ₁ - num = join(Char(0x2080 + d) for d in reverse!(digits(abs(backshift)))) - # Char(0x209c) = ₜ - # ds = ₜ₋₁ - ds = join([Char(0x209c), pm, num]) - - O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg - oldop = operation(O) - newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) : - Symbol(string(nameof(oldop))) - - newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname), - Symbolics.children(O), Symbolics.metadata(O)) - newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname)) - newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O) - newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift) - return newvar end function isdoubleshift(var) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 0bd05bb4b9..a30b8f71df 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -69,7 +69,7 @@ function wrap_assignments(isscalar, assignments; let_block = false) end end -const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___) +const MTKPARAMETERS_ARG = SSym(:___mtkparameters___; type = Vector{Vector{Any}}, shape = SymbolicUtils.Unknown(1)) """ $(TYPEDSIGNATURES) @@ -94,11 +94,11 @@ See also [`@independent_variables`](@ref) and [`ModelingToolkit.get_iv`](@ref). """ function independent_variables(sys::AbstractSystem) if isdefined(sys, :iv) && getfield(sys, :iv) !== nothing - return [getfield(sys, :iv)] + return SymbolicT[getfield(sys, :iv)] elseif isdefined(sys, :ivs) - return getfield(sys, :ivs) + return getfield(sys, :ivs)::Vector{SymbolicT} else - return [] + return SymbolicT[] end end @@ -170,17 +170,20 @@ function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem) return unknowns(sys) end -function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) - sym = unwrap(sym) +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + is_parameter(sys, unwrap(sym)) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Int) + sym in 1:length(parameter_symbols(sys)) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::SymbolicT) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing return sym isa ParameterIndex || is_parameter(ic, sym) || - iscall(sym) && - operation(sym) === getindex && + iscall(sym) && operation(sym) === getindex && is_parameter(ic, first(arguments(sym))) end - if unwrap(sym) isa Int - return unwrap(sym) in 1:length(parameter_symbols(sys)) - end return any(isequal(sym), parameter_symbols(sys)) || hasname(sym) && !(iscall(sym) && operation(sym) == getindex) && is_parameter(sys, getname(sym)) @@ -191,7 +194,7 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol return is_parameter(ic, sym) end - named_parameters = [getname(x) + named_parameters = Symbol[getname(x) for x in parameter_symbols(sys) if hasname(x) && !(iscall(x) && operation(x) == getindex)] return any(isequal(sym), named_parameters) || @@ -488,7 +491,8 @@ of a system. See the documentation section on initialization for more informatio struct Initial <: Symbolics.Operator end is_timevarying_operator(::Type{Initial}) = false Initial(x) = Initial()(x) -SymbolicUtils.promote_symtype(::Type{Initial}, T) = T +SymbolicUtils.promote_symtype(::Initial, ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::Initial, @nospecialize(x::SU.ShapeT)) = x SymbolicUtils.isbinop(::Initial) = false Base.nameof(::Initial) = :Initial Base.show(io::IO, x::Initial) = print(io, "Initial") @@ -507,16 +511,16 @@ function (f::Initial)(x) end # don't double wrap iscall(x) && operation(x) isa Initial && return x - result = if symbolic_type(x) == ArraySymbolic() - # create an array for `Initial(array)` - Symbolics.array_term(f, x) - elseif iscall(x) && operation(x) == getindex + sh = SU.shape(x) + result = if SU.is_array_shape(sh) + term(f, x; type = symtype(x), shape = sh) + elseif iscall(x) && operation(x) === getindex # instead of `Initial(x[1])` create `Initial(x)[1]` # which allows parameter indexing to handle this case automatically. arr = arguments(x)[1] - term(getindex, f(arr), arguments(x)[2:end]...) + f(arr)[arguments(x)[2:end]...] else - term(f, x) + term(f, x; type = symtype(x), shape = sh) end # the result should be a parameter result = toparam(result) @@ -526,15 +530,6 @@ function (f::Initial)(x) return result end -# This is required so `fast_substitute` works -function SymbolicUtils.maketerm(::Type{<:BasicSymbolic}, ::Initial, args, meta) - val = Initial()(args...) - if symbolic_type(val) == NotSymbolic() - return val - end - return metadata(val, meta) -end - supports_initialization(sys::AbstractSystem) = true function add_initialization_parameters(sys::AbstractSystem; split = true) @@ -542,16 +537,20 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) supports_initialization(sys) || return sys is_initializesystem(sys) && return sys - all_initialvars = Set{BasicSymbolic}() + all_initialvars = Set{SymbolicT}() # time-independent systems don't initialize unknowns # but may initialize parameters using guesses for unknowns eqs = equations(sys) - if !(eqs isa Vector{Equation}) - eqs = Equation[x for x in eqs if x isa Equation] - end obs, eqs = unhack_observed(observed(sys), eqs) - for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs))) - x = unwrap(x) + for x in unknowns(sys) + if iscall(x) && operation(x) == getindex && split + push!(all_initialvars, arguments(x)[1]) + else + push!(all_initialvars, x) + end + end + for eq in obs + x = eq.lhs if iscall(x) && operation(x) == getindex && split push!(all_initialvars, arguments(x)[1]) else @@ -561,15 +560,19 @@ function add_initialization_parameters(sys::AbstractSystem; split = true) # add derivatives of all variables for steady-state initial conditions if is_time_dependent(sys) && !is_discrete_system(sys) - D = Differential(get_iv(sys)) - union!(all_initialvars, [D(v) for v in all_initialvars if iscall(v)]) + D = Differential(get_iv(sys)::SymbolicT) + for v in all_initialvars + iscall(v) && push!(all_initialvars, D(v)) + end end for eq in get_parameter_dependencies(sys) is_variable_floatingpoint(eq.lhs) || continue push!(all_initialvars, eq.lhs) end - all_initialvars = collect(all_initialvars) - initials = map(Initial(), all_initialvars) + initials = collect(all_initialvars) + for (i, v) in enumerate(initials) + initials[i] = Initial()(v) + end @set! sys.ps = unique!([get_ps(sys); initials]) defs = copy(get_defaults(sys)) for ivar in initials @@ -601,9 +604,9 @@ end Find [`GlobalScope`](@ref)d variables in `sys` and add them to the unknowns/parameters. """ function discover_globalscoped(sys::AbstractSystem) - newunknowns = OrderedSet() - newparams = OrderedSet() - iv = has_iv(sys) ? get_iv(sys) : nothing + newunknowns = OrderedSet{SymbolicT}() + newparams = OrderedSet{SymbolicT}() + iv::Union{SymbolicT, Nothing} = has_iv(sys) ? get_iv(sys) : nothing collect_scoped_vars!(newunknowns, newparams, sys, iv; depth = -1) setdiff!(newunknowns, observables(sys)) @set! sys.ps = unique!(vcat(get_ps(sys), collect(newparams))) @@ -626,34 +629,30 @@ This namespacing functionality can also be toggled independently of `complete` using [`toggle_namespacing`](@ref). """ function complete( - sys::AbstractSystem; split = true, flatten = true, add_initial_parameters = true) + sys::T; split = true, flatten = true, add_initial_parameters = true) where {T <: AbstractSystem} sys = discover_globalscoped(sys) if flatten - eqs = equations(sys) - if eqs isa AbstractArray && eltype(eqs) <: Equation - newsys = expand_connections(sys) - else - newsys = sys - end + newsys = expand_connections(sys) newsys = ModelingToolkit.flatten(newsys) if has_parent(newsys) && get_parent(sys) === nothing - @set! newsys.parent = complete(sys; split = false, flatten = false) + @set! newsys.parent = complete(sys; split = false, flatten = false)::T end sys = newsys sys = process_parameter_equations(sys) if add_initial_parameters sys = add_initialization_parameters(sys; split) end + cb_alg_eqs = Equation[alg_equations(sys); observed(sys)] if has_continuous_events(sys) && is_time_dependent(sys) @set! sys.continuous_events = complete.( get_continuous_events(sys); iv = get_iv(sys), - alg_eqs = [alg_equations(sys); observed(sys)]) + alg_eqs = cb_alg_eqs) end if has_discrete_events(sys) && is_time_dependent(sys) @set! sys.discrete_events = complete.( get_discrete_events(sys); iv = get_iv(sys), - alg_eqs = [alg_equations(sys); observed(sys)]) + alg_eqs = cb_alg_eqs) end end if split && has_index_cache(sys) @@ -661,6 +660,7 @@ function complete( # Ideally we'd do `get_ps` but if `flatten = false` # we don't get all of them. So we call `parameters`. all_ps = parameters(sys; initial_parameters = true) + all_ps_set = Set{SymbolicT}(all_ps) # inputs have to be maintained in a specific order input_vars = inputs(sys) if !isempty(all_ps) @@ -668,24 +668,29 @@ function complete( ps_split = reorder_parameters(sys, all_ps) # if there are tunables, they will all be in `ps_split[1]` # and the arrays will have been scalarized - ordered_ps = eltype(all_ps)[] + ordered_ps = SymbolicT[] + offset = 0 # if there are no tunables, vcat them if !isempty(get_index_cache(sys).tunable_idx) - unflatten_parameters!(ordered_ps, ps_split[1], all_ps) - ps_split = Base.tail(ps_split) + unflatten_parameters!(ordered_ps, ps_split[1], all_ps_set) + offset += 1 end # unflatten initial parameters if !isempty(get_index_cache(sys).initials_idx) - unflatten_parameters!(ordered_ps, ps_split[1], all_ps) - ps_split = Base.tail(ps_split) + unflatten_parameters!(ordered_ps, ps_split[1], all_ps_set) + offset += 1 + end + for i in (offset+1):length(ps_split) + append!(ordered_ps, ps_split[i]::Vector{SymbolicT}) end - ordered_ps = vcat( - ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[])) if isscheduled(sys) # ensure inputs are sorted - input_idxs = findfirst.(isequal.(input_vars), (ordered_ps,)) - @assert all(!isnothing, input_idxs) - @assert issorted(input_idxs) + last_idx = 0 + for p in input_vars + idx = findfirst(isequal(p), ordered_ps)::Int + @assert last_idx < idx + last_idx = idx + end end @set! sys.ps = ordered_ps end @@ -722,26 +727,28 @@ parameters in the system `all_ps`, unscalarize the elements in `params` and appe to `buffer` in the same order as they are present in `params`. Effectively, if `params = [p[1], p[2], p[3], q]` then this is equivalent to `push!(buffer, p, q)`. """ -function unflatten_parameters!(buffer, params, all_ps) +function unflatten_parameters!(buffer::Vector{SymbolicT}, params::Vector{SymbolicT}, all_ps::Set{SymbolicT}) i = 1 # go through all the tunables while i <= length(params) sym = params[i] # if the sym is not a scalarized array symbolic OR it was already scalarized, # just push it as-is - if !iscall(sym) || operation(sym) != getindex || - any(isequal(sym), all_ps) + if !iscall(sym) || operation(sym) !== getindex || sym in all_ps push!(buffer, sym) i += 1 continue end + + arrsym = first(arguments(sym)) # the next `length(sym)` symbols should be scalarized versions of the same # array symbolic - if !allequal(first(arguments(x)) - for x in view(params, i:(i + length(sym) - 1))) - error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.") + for j in (i+1):(i+length(sym)-1) + p = params[j] + if !(iscall(p) && operation(p) === getindex && isequal(arguments(p)[1], arrsym)) + error("This should not be possible. Please open an issue in ModelingToolkit.jl with an MWE and stacktrace.") + end end - arrsym = first(arguments(sym)) push!(buffer, arrsym) i += length(arrsym) end @@ -1019,7 +1026,7 @@ struct LocalScope <: SymScope end Apply `LocalScope` to `sym`. """ -function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function LocalScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) === getindex args = arguments(sym) @@ -1051,7 +1058,7 @@ end Apply `ParentScope` to `sym`, with `parent` being `LocalScope`. """ -function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function ParentScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) === getindex args = arguments(sym) @@ -1081,7 +1088,7 @@ struct GlobalScope <: SymScope end Apply `GlobalScope` to `sym`. """ -function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) +function GlobalScope(sym::Union{Num, SymbolicT, Symbolics.Arr{Num}}) apply_to_variables(sym) do sym if iscall(sym) && operation(sym) == getindex args = arguments(sym) @@ -1094,44 +1101,58 @@ function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) end end +const AllScopes = Union{LocalScope, ParentScope, GlobalScope} + renamespace(sys, eq::Equation) = namespace_equation(eq, sys) renamespace(names::AbstractVector, x) = foldr(renamespace, names, init = x) +renamespace(sys, tgt::AbstractSystem) = rename(tgt, renamespace(sys, nameof(tgt))) +renamespace(sys, tgt::Symbol) = Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, tgt) + """ $(TYPEDSIGNATURES) Namespace `x` with the name of `sys`. """ -function renamespace(sys, x) - sys === nothing && return x - x = unwrap(x) - if x isa Symbolic - T = typeof(x) - if iscall(x) && operation(x) isa Operator - return maketerm(typeof(x), operation(x), - Any[renamespace(sys, only(arguments(x)))], - metadata(x))::T - end - if iscall(x) && operation(x) === getindex - args = arguments(x) - return maketerm( - typeof(x), operation(x), vcat(renamespace(sys, args[1]), args[2:end]), - metadata(x))::T - end - let scope = getmetadata(x, SymScope, LocalScope()) +function renamespace(sys, x::SymbolicT) + isequal(x, SU.idxs_for_arrayop(VartypeT)) && return x + Moshi.Match.@match x begin + BSImpl.Sym(; name) => let scope = getmetadata(x, SymScope, LocalScope())::AllScopes if scope isa LocalScope - rename(x, renamespace(getname(sys), getname(x)))::T + return rename(x, renamespace(getname(sys), name))::SymbolicT elseif scope isa ParentScope - setmetadata(x, SymScope, scope.parent)::T - else # GlobalScope - x::T + return setmetadata(x, SymScope, scope.parent)::SymbolicT + elseif scope isa GlobalScope + return x + end + error() + end + BSImpl.Term(; f, args, shape, type, metadata) => begin + if f === getindex + newargs = copy(parent(args)) + newargs[1] = renamespace(sys, args[1]) + return BSImpl.Term{VartypeT}(getindex, newargs; type, shape, metadata) + elseif f isa SymbolicT + let scope = getmetadata(x, SymScope, LocalScope())::Union{LocalScope, ParentScope, GlobalScope} + if scope isa LocalScope + return rename(x, renamespace(getname(sys), getname(x)))::SymbolicT + elseif scope isa ParentScope + return setmetadata(x, SymScope, scope.parent)::SymbolicT + elseif scope isa GlobalScope + return x + end + error() + end + elseif f isa Operator + newargs = copy(parent(args)) + for (i, arg) in enumerate(args) + newargs[i] = renamespace(sys, arg) + end + return BSImpl.Term{VartypeT}(f, newargs; type, shape, metadata) end + error() end - elseif x isa AbstractSystem - rename(x, renamespace(sys, nameof(x))) - else - Symbol(getname(sys), NAMESPACE_SEPARATOR_SYMBOL, x) end end @@ -1156,8 +1177,14 @@ Return `equations(sys)`, namespaced by the name of `sys`. """ function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys)) eqs = equations(sys) - isempty(eqs) && return Equation[] - map(eq -> namespace_equation(eq, sys; ivs), eqs) + isempty(eqs) && return eqs + if eqs === get_eqs(sys) + eqs = copy(eqs) + end + for i in eachindex(eqs) + eqs[i] = namespace_equation(eqs[i], sys; ivs) + end + return eqs end function namespace_initialization_equations( @@ -1204,7 +1231,15 @@ function namespace_jump(j::MassActionJump, sys) end function namespace_jumps(sys::AbstractSystem) - return [namespace_jump(j, sys) for j in get_jumps(sys)] + js = jumps(sys) + isempty(js) && return js + if js === get_jumps(sys) + js = copy(js) + end + for i in eachindex(js) + js[i] = namespace_jump(js[i], sys) + end + return js end function namespace_brownians(sys::AbstractSystem) @@ -1224,48 +1259,63 @@ function is_array_of_symbolics(x) any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x) end -function namespace_expr( - O, sys, n = (sys === nothing ? nothing : nameof(sys)); - ivs = sys === nothing ? nothing : independent_variables(sys)) - sys === nothing && return O - O = unwrap(O) - # Exceptions for arrays of symbolic and Ref of a symbolic, the latter - # of which shows up in broadcasts - if symbolic_type(O) == NotSymbolic() && !(O isa AbstractArray) && !(O isa Ref) - return O - end - if any(isequal(O), ivs) - return O - elseif iscall(O) - T = typeof(O) - renamed = let sys = sys, n = n, T = T - map(a -> namespace_expr(a, sys, n; ivs)::Any, arguments(O)) - end - if isvariable(O) - # Use renamespace so the scope is correct, and make sure to use the - # metadata from the rescoped variable - rescoped = renamespace(n, O) - maketerm(typeof(rescoped), operation(rescoped), renamed, - metadata(rescoped)) - elseif Symbolics.isarraysymbolic(O) - # promote_symtype doesn't work for array symbolics - maketerm(typeof(O), operation(O), renamed, metadata(O)) - else - maketerm(typeof(O), operation(O), renamed, metadata(O)) +function namespace_expr(O, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) + return O +end +function namespace_expr(O::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}, sys::AbstractSystem, n::Symbol = nameof(sys); kw...) + namespace_expr(O, args...; kw...) +end +function namespace_expr(O::AbstractArray, sys::AbstractSystem, n::Symbol = nameof(sys); ivs = independent_variables(sys)) + is_array_of_symbolics(O) || return O + O = copy(O) + for i in eachindex(O) + O[i] = namespace_expr(O[i], sys, n; ivs) + end + return O +end +function namespace_expr(O::SymbolicT, sys::AbstractSystem, n::Symbol = nameof(sys); ivs = independent_variables(sys)) + any(isequal(O), ivs) && return O + isvar = isvariable(O) + Moshi.Match.@match O begin + BSImpl.Const(;) => return O + BSImpl.Sym(;) => return isvar ? renamespace(n, O) : O + BSImpl.Term(; f, args, metadata, type, shape) => begin + newargs = copy(parent(args)) + for i in eachindex(args) + newargs[i] = namespace_expr(newargs[i], sys, n; ivs) + end + if isvar + rescoped = renamespace(n, O) + f = Moshi.Data.variant_getfield(rescoped, BSImpl.Term{VartypeT}, :f) + meta = Moshi.Data.variant_getfield(rescoped, BSImpl.Term{VartypeT}, :metadata) + elseif f isa SymbolicT + f = renamespace(n, f) + meta = metadata + end + return BSImpl.Term{VartypeT}(f, newargs; type, shape, metadata = meta) end - elseif isvariable(O) - renamespace(n, O) - elseif O isa AbstractArray && is_array_of_symbolics(O) - let sys = sys, n = n - map(o -> namespace_expr(o, sys, n; ivs), O) + BSImpl.AddMul(; coeff, dict, variant, type, shape, metadata) => begin + newdict = copy(dict) + for (k, v) in newdict + newdict[namespace_expr(k, sys, n; ivs)] = v + end + return BSImpl.AddMul{VartypeT}(coeff, newdict, variant; type, shape, metadata) + end + BSImpl.Div(; num, den, type, shape, metadata) => begin + num = namespace_expr(num, sys, n; ivs) + den = namespace_expr(den, sys, n; ivs) + return BSImpl.Div{VartypeT}(num, den, false; type, shape, metadata) + end + BSImpl.ArrayOp(; output_idx, expr, term, ranges, reduce, type, shape, metadata) => begin + if term isa SymbolicT + term = namespace_expr(term, sys, n; ivs) + end + expr = namespace_expr(expr, sys, n; ivs) + return BSImpl.ArrayOp{VartypeT}(output_idx, expr, reduce, term, ranges; type, shape, metadata) end - else - O end end -_nonum(@nospecialize x) = x isa Num ? x.val : x - """ $(TYPEDSIGNATURES) @@ -1277,21 +1327,14 @@ See also [`ModelingToolkit.get_unknowns`](@ref). function unknowns(sys::AbstractSystem) sts = get_unknowns(sys) systems = get_systems(sys) - nonunique_unknowns = if isempty(systems) - sts - else - system_unknowns = reduce(vcat, namespace_variables.(systems)) - isempty(sts) ? system_unknowns : [sts; system_unknowns] + if isempty(systems) + return sts end - isempty(nonunique_unknowns) && return nonunique_unknowns - # `Vector{Any}` is incompatible with the `SymbolicIndexingInterface`, which uses - # `elsymtype = symbolic_type(eltype(_arg))` - # which inappropriately returns `NotSymbolic()` - if nonunique_unknowns isa Vector{Any} - nonunique_unknowns = _nonum.(nonunique_unknowns) + result = copy(sts) + for subsys in systems + append!(result, namespace_variables(subsys)) end - @assert typeof(nonunique_unknowns) !== Vector{Any} - unique(nonunique_unknowns) + return result end """ @@ -1315,19 +1358,24 @@ See also [`@parameters`](@ref) and [`ModelingToolkit.get_ps`](@ref). """ function parameters(sys::AbstractSystem; initial_parameters = false) ps = get_ps(sys) - if ps == SciMLBase.NullParameters() + if ps === SciMLBase.NullParameters() return [] end if eltype(ps) <: Pair ps = first.(ps) end systems = get_systems(sys) - result = unique(isempty(systems) ? ps : - [ps; reduce(vcat, namespace_parameters.(systems))]) + if isempty(systems) + return ps + end + result = copy(ps) + for subsys in systems + append!(result, namespace_parameters(subsys)) + end if !initial_parameters && !is_initializesystem(sys) filter!(result) do sym return !(isoperator(sym, Initial) || - iscall(sym) && operation(sym) == getindex && + iscall(sym) && operation(sym) === getindex && isoperator(arguments(sym)[1], Initial)) end end @@ -1500,10 +1548,8 @@ function defaults_and_guesses(sys::AbstractSystem) end unknowns(sys::Union{AbstractSystem, Nothing}, v) = namespace_expr(v, sys) -for vType in [Symbolics.Arr, Symbolics.Symbolic{<:AbstractArray}] - @eval unknowns(sys::AbstractSystem, v::$vType) = namespace_expr(v, sys) - @eval parameters(sys::AbstractSystem, v::$vType) = toparam(unknowns(sys, v)) -end +unknowns(sys::AbstractSystem, v::Symbolics.Arr) = namespace_expr(v, sys) +parameters(sys::AbstractSystem, v::Symbolics.Arr) = toparam(unknowns(sys, v)) parameters(sys::Union{AbstractSystem, Nothing}, v) = toparam(unknowns(sys, v)) for f in [:unknowns, :parameters] @eval function $f(sys::AbstractSystem, vs::AbstractArray) @@ -1525,15 +1571,12 @@ See also [`full_equations`](@ref) and [`ModelingToolkit.get_eqs`](@ref). function equations(sys::AbstractSystem) eqs = get_eqs(sys) systems = get_systems(sys) - if isempty(systems) - return eqs - else - eqs = Equation[eqs; - reduce(vcat, - namespace_equations.(get_systems(sys)); - init = Equation[])] - return eqs + isempty(systems) && return eqs + eqs = copy(eqs) + for subsys in systems + append!(eqs, namespace_equations(subsys)) end + return eqs end """ @@ -1613,10 +1656,12 @@ all the subsystems of `sys` (appropriately namespaced). function jumps(sys::AbstractSystem) js = get_jumps(sys) systems = get_systems(sys) - if isempty(systems) - return js + isempty(systems) && return js + js = copy(js) + for subsys in systems + append!(js, namespace_jumps(subsys)) end - return [js; reduce(vcat, namespace_jumps.(systems); init = [])] + return js end """ @@ -1665,8 +1710,14 @@ end function namespace_constraints(sys) cstrs = constraints(sys) - isempty(cstrs) && return Vector{Union{Equation, Inequality}}(undef, 0) - map(cstr -> namespace_constraint(cstr, sys), cstrs) + isempty(cstrs) && return cstrs + if cstrs === get_constraints(sys) + cstrs = copy(cstrs) + end + for i in eachindex(cstrs) + cstrs[i] = namespace_constraint(cstrs[i], sys) + end + return cstrs end """ @@ -1677,7 +1728,12 @@ Get all constraints in the system `sys` and all of its subsystems, appropriately function constraints(sys::AbstractSystem) cs = get_constraints(sys) systems = get_systems(sys) - isempty(systems) ? cs : [cs; reduce(vcat, namespace_constraints.(systems))] + isempty(systems) && return cs + cs = copy(sys) + for subsys in systems + append!(cs, namespace_constraints(subsys)) + end + return cs end """ @@ -2257,7 +2313,7 @@ end function default_to_parentscope(v) uv = unwrap(v) - uv isa Symbolic || return v + uv isa SymbolicT || return v apply_to_variables(v) do sym ParentScope(sym) end @@ -2735,26 +2791,26 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair}, elseif sys isa System rules = todict(map(r -> Symbolics.unwrap(r[1]) => Symbolics.unwrap(r[2]), collect(rules))) - newsys = @set sys.eqs = fast_substitute(get_eqs(sys), rules) + newsys = @set sys.eqs = substitute(get_eqs(sys), rules) @set! newsys.unknowns = map(get_unknowns(sys)) do var get(rules, var, var) end @set! newsys.ps = map(get_ps(sys)) do var get(rules, var, var) end - @set! newsys.parameter_dependencies = fast_substitute( + @set! newsys.parameter_dependencies = substitute( get_parameter_dependencies(sys), rules) - @set! newsys.defaults = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + @set! newsys.defaults = Dict(substitute(k, rules) => substitute(v, rules) for (k, v) in get_defaults(sys)) - @set! newsys.guesses = Dict(fast_substitute(k, rules) => fast_substitute(v, rules) + @set! newsys.guesses = Dict(substitute(k, rules) => substitute(v, rules) for (k, v) in get_guesses(sys)) - @set! newsys.noise_eqs = fast_substitute(get_noise_eqs(sys), rules) - @set! newsys.costs = Vector{Union{Real, BasicSymbolic}}(fast_substitute( + @set! newsys.noise_eqs = substitute(get_noise_eqs(sys), rules) + @set! newsys.costs = Vector{Union{Real, BasicSymbolic}}(substitute( get_costs(sys), rules)) - @set! newsys.observed = fast_substitute(get_observed(sys), rules) - @set! newsys.initialization_eqs = fast_substitute( + @set! newsys.observed = substitute(get_observed(sys), rules) + @set! newsys.initialization_eqs = substitute( get_initialization_eqs(sys), rules) - @set! newsys.constraints = fast_substitute(get_constraints(sys), rules) + @set! newsys.constraints = substitute(get_constraints(sys), rules) @set! newsys.systems = map(s -> substitute(s, rules), get_systems(sys)) else error("substituting symbols is not supported for $(typeof(sys))") @@ -2774,18 +2830,18 @@ function process_parameter_equations(sys::AbstractSystem) if !isempty(get_systems(sys)) throw(ArgumentError("Expected flattened system")) end - varsbuf = Set() + varsbuf = Set{SymbolicT}() pareq_idxs = Int[] eqs = equations(sys) for (i, eq) in enumerate(eqs) empty!(varsbuf) - vars!(varsbuf, eq; op = Union{Differential, Initial, Pre}) + SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{Union{Differential, Initial, Pre}}()) # singular equations isempty(varsbuf) && continue if all(varsbuf) do sym is_parameter(sys, sym) || symbolic_type(sym) == ArraySymbolic() && - is_sized_array_symbolic(sym) && + symbolic_has_known_size(sym) && all(Base.Fix1(is_parameter, sys), collect(sym)) || iscall(sym) && operation(sym) === getindex && is_parameter(sys, arguments(sym)[1]) @@ -2802,8 +2858,11 @@ function process_parameter_equations(sys::AbstractSystem) end end - pareqs = [get_parameter_dependencies(sys); eqs[pareq_idxs]] - explicitpars = [eq.lhs for eq in pareqs] + pareqs = Equation[get_parameter_dependencies(sys); eqs[pareq_idxs]] + explicitpars = SymbolicT[] + for eq in pareqs + push!(explicitpars, eq.lhs) + end pareqs = topsort_equations(pareqs, explicitpars) eqs = eqs[setdiff(eachindex(eqs), pareq_idxs)] diff --git a/src/systems/alias_elimination.jl b/src/systems/alias_elimination.jl index f24a2562fe..da6a11dd02 100644 --- a/src/systems/alias_elimination.jl +++ b/src/systems/alias_elimination.jl @@ -97,7 +97,7 @@ function alias_elimination!(state::TearingState; kwargs...) nvs_orig = ndsts(graph_orig) for ieq in eqs_to_update eq = eqs[ieq] - eqs[ieq] = fast_substitute(eq, subs) + eqs[ieq] = substitute(eq, subs) end @set! mm.nparentrows = nsrcs(graph) @set! mm.row_cols = eltype(mm.row_cols)[mm.row_cols[i] @@ -411,7 +411,7 @@ julia> ModelingToolkit.topsort_equations(eqs, [x, y, z, k]) Equation(x(t), y(t) + z(t)) ``` """ -function topsort_equations(eqs, unknowns; check = true) +function topsort_equations(eqs::Vector{Equation}, unknowns::Vector{SymbolicT}; check = true) graph, assigns = observed2graph(eqs, unknowns) neqs = length(eqs) degrees = zeros(Int, neqs) @@ -426,20 +426,20 @@ function topsort_equations(eqs, unknowns; check = true) q = Queue{Int}(neqs) for (i, d) in enumerate(degrees) - d == 0 && enqueue!(q, i) + d == 0 && push!(q, i) end idx = 0 ordered_eqs = similar(eqs, 0) sizehint!(ordered_eqs, neqs) while !isempty(q) - 𝑠eq = dequeue!(q) + 𝑠eq = popfirst!(q) idx += 1 push!(ordered_eqs, eqs[𝑠eq]) var = assigns[𝑠eq] for 𝑑eq in 𝑑neighbors(graph, var) degree = degrees[𝑑eq] = degrees[𝑑eq] - 1 - degree == 0 && enqueue!(q, 𝑑eq) + degree == 0 && push!(q, 𝑑eq) end end @@ -448,22 +448,25 @@ function topsort_equations(eqs, unknowns; check = true) return ordered_eqs end -function observed2graph(eqs, unknowns) +function observed2graph(eqs::Vector{Equation}, unknowns::Vector{SymbolicT})::Tuple{BipartiteGraph{Int, Nothing}, Vector{Int}} graph = BipartiteGraph(length(eqs), length(unknowns)) - v2j = Dict(unknowns .=> 1:length(unknowns)) + v2j = Dict{SymbolicT, Int}(unknowns .=> 1:length(unknowns)) # `assigns: eq -> var`, `eq` defines `var` assigns = similar(eqs, Int) - + vars = Set{SymbolicT}() for (i, eq) in enumerate(eqs) lhs_j = get(v2j, eq.lhs, nothing) lhs_j === nothing && throw(ArgumentError("The lhs $(eq.lhs) of $eq, doesn't appear in unknowns.")) assigns[i] = lhs_j - vs = vars(eq.rhs; op = Symbolics.Operator) - for v in vs + empty!(vars) + SU.search_variables!(vars, eq.rhs; is_atomic = OperatorIsAtomic{SU.Operator}()) + for v in vars j = get(v2j, v, nothing) - j !== nothing && add_edge!(graph, i, j) + if j isa Int + add_edge!(graph, i, j) + end end end diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index a5a612b9ca..42bf94eb02 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -250,7 +250,7 @@ Remove all `AnalysisPoint`s in `sys` and any of its subsystems, replacing them b """ function remove_analysis_points(sys::AbstractSystem) eqs = map(get_eqs(sys)) do eq - eq.lhs isa AnalysisPoint ? to_connection(eq.rhs) : eq + value(eq.lhs) isa AnalysisPoint ? to_connection(value(eq.rhs)) : eq end @set! sys.eqs = eqs @set! sys.systems = map(remove_analysis_points, get_systems(sys)) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c94166103b..086dd5fa9d 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -7,15 +7,19 @@ end struct SymbolicAffect affect::Vector{Equation} alg_eqs::Vector{Equation} - discrete_parameters::Vector{Any} + discrete_parameters::Vector{SymbolicT} end function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[], - discrete_parameters = Any[], kwargs...) - if !(discrete_parameters isa AbstractVector) - discrete_parameters = Any[discrete_parameters] - elseif !(discrete_parameters isa Vector{Any}) - discrete_parameters = Vector{Any}(discrete_parameters) + discrete_parameters = SymbolicT[], kwargs...) + if symbolic_type(discrete_parameters) !== NotSymbolic() + discrete_parameters = SymbolicT[unwrap(discrete_parameters)] + elseif !(discrete_parameters isa Vector{SymbolicT}) + _discs = SymbolicT[] + for p in discrete_parameters + push!(_discs, unwrap(p)) + end + discrete_parameters = _discs end SymbolicAffect(affect, alg_eqs, discrete_parameters) end @@ -25,34 +29,31 @@ function SymbolicAffect(affect::SymbolicAffect; kwargs...) end SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...) -function Symbolics.fast_substitute(aff::SymbolicAffect, rules) - substituter = Base.Fix2(fast_substitute, rules) - SymbolicAffect(map(substituter, aff.affect), map(substituter, aff.alg_eqs), - map(substituter, aff.discrete_parameters)) +function (s::SymbolicUtils.Substituter)(aff::SymbolicAffect) + SymbolicAffect(s(aff.affect), s(aff.alg_eqs), s(aff.discrete_parameters)) end struct AffectSystem """The internal implicit discrete system whose equations are solved to obtain values after the affect.""" system::AbstractSystem """Unknowns of the parent ODESystem whose values are modified or accessed by the affect.""" - unknowns::Vector + unknowns::Vector{SymbolicT} """Parameters of the parent ODESystem whose values are accessed by the affect.""" - parameters::Vector + parameters::Vector{SymbolicT} """Parameters of the parent ODESystem whose values are modified by the affect.""" - discretes::Vector + discretes::Vector{SymbolicT} end -function Symbolics.fast_substitute(aff::AffectSystem, rules) - substituter = Base.Fix2(fast_substitute, rules) +function (s::SymbolicUtils.Substituter)(aff::AffectSystem) sys = aff.system - @set! sys.eqs = map(substituter, get_eqs(sys)) - @set! sys.parameter_dependencies = map(substituter, get_parameter_dependencies(sys)) - @set! sys.defaults = Dict([k => substituter(v) for (k, v) in defaults(sys)]) - @set! sys.guesses = Dict([k => substituter(v) for (k, v) in guesses(sys)]) - @set! sys.unknowns = map(substituter, get_unknowns(sys)) - @set! sys.ps = map(substituter, get_ps(sys)) - AffectSystem(sys, map(substituter, aff.unknowns), - map(substituter, aff.parameters), map(substituter, aff.discretes)) + @set! sys.eqs = s(get_eqs(sys)) + @set! sys.parameter_dependencies = (get_parameter_dependencies(sys)) + @set! sys.defaults = Dict([k => s(v) for (k, v) in defaults(sys)]) + @set! sys.guesses = Dict([k => s(v) for (k, v) in guesses(sys)]) + @set! sys.unknowns = s(get_unknowns(sys)) + @set! sys.ps = s(get_ps(sys)) + AffectSystem(sys, s(aff.unknowns), s(aff.parameters), s(aff.discretes)) + end function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...) @@ -60,7 +61,11 @@ function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], discrete_parameters = spec.discrete_parameters, kwargs...) end -function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], +@noinline function warn_algebraic_equation(eq::Equation) + @warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1." +end + +function AffectSystem(affect::Vector{Equation}; discrete_parameters = SymbolicT[], iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...) isempty(affect) && return nothing if isnothing(iv) @@ -68,26 +73,24 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], @warn "No independent variable specified. Defaulting to t_nounits." end - discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters]) - discrete_parameters = unwrap.(discrete_parameters) + discrete_parameters = SymbolicAffect(affect; alg_eqs, discrete_parameters).discrete_parameters for p in discrete_parameters - occursin(unwrap(iv), unwrap(p)) || + SU.query!(isequal(unwrap(iv)), unwrap(p)) || error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).") end - dvs = OrderedSet() - params = OrderedSet() - _varsbuf = Set() + dvs = OrderedSet{SymbolicT}() + params = OrderedSet{SymbolicT}() + _varsbuf = Set{SymbolicT}() for eq in affect - if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() || - symbolic_type(eq.lhs) === NotSymbolic()) - @warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1." + if !haspre(eq) && !(isconst(eq.lhs) && isconst(eq.rhs)) + @invokelatest warn_algebraic_equation(eq) end collect_vars!(dvs, params, eq, iv; op = Pre) empty!(_varsbuf) - vars!(_varsbuf, eq; op = Pre) - filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf) + SU.search_variables!(_varsbuf, eq; is_atomic = OperatorIsAtomic{Pre}()) + filter!(x -> iscall(x) && operation(x) === Pre(), _varsbuf) union!(params, _varsbuf) diffvs = collect_applied_operators(eq, Differential) union!(dvs, diffvs) @@ -95,30 +98,35 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], for eq in alg_eqs collect_vars!(dvs, params, eq, iv) end - pre_params = filter(haspre ∘ value, params) - sys_params = collect(setdiff(params, union(discrete_parameters, pre_params))) + pre_params = filter(haspre, params) + sys_params = SymbolicT[] + disc_ps_set = Set{SymbolicT}(discrete_parameters) + for p in params + p in disc_ps_set && continue + p in pre_params && continue + push!(sys_params, p) + end discretes = map(tovar, discrete_parameters) dvs = collect(dvs) _dvs = map(default_toterm, dvs) - rev_map = Dict(zip(discrete_parameters, discretes)) - subs = merge(rev_map, Dict(zip(dvs, _dvs))) - affect = Symbolics.fast_substitute(affect, subs) - alg_eqs = Symbolics.fast_substitute(alg_eqs, subs) + rev_map = Dict{SymbolicT, SymbolicT}(zip(discrete_parameters, discretes)) + subs = merge(rev_map, Dict{SymbolicT, SymbolicT}(zip(dvs, _dvs))) + affect = substitute(affect, subs) + alg_eqs = substitute(alg_eqs, subs) @named affectsys = System( vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)), collect(union(pre_params, sys_params)); is_discrete = true) affectsys = mtkcompile(affectsys; fully_determined = nothing) # get accessed parameters p from Pre(p) in the callback parameters - accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params)))) + accessed_params = Vector{SymbolicT}(filter(isparameter, map(unPre, collect(pre_params)))) union!(accessed_params, sys_params) # add scalarized unknowns to the map. - _dvs = reduce(vcat, map(scalarize, _dvs), init = Any[]) + _dvs = reduce(vcat, map(scalarize, _dvs), init = SymbolicT[]) - AffectSystem(affectsys, collect(_dvs), collect(accessed_params), - collect(discrete_parameters)) + AffectSystem(affectsys, _dvs, accessed_params, discrete_parameters) end system(a::AffectSystem) = a.system @@ -169,7 +177,7 @@ Base.nameof(::Pre) = :Pre Base.show(io::IO, x::Pre) = print(io, "Pre") unPre(x::Num) = unPre(unwrap(x)) unPre(x::Symbolics.Arr) = unPre(unwrap(x)) -unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x +unPre(x::SymbolicT) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x function (p::Pre)(x) iw = Symbolics.iswrapped(x) @@ -186,14 +194,15 @@ function (p::Pre)(x) iscall(x) && operation(x) isa Pre && return x result = if symbolic_type(x) == ArraySymbolic() # create an array for `Pre(array)` + term(p, x; type = symtype(x), shape = SU.shape(x)) Symbolics.array_term(p, x) elseif iscall(x) && operation(x) == getindex # instead of `Pre(x[1])` create `Pre(x)[1]` # which allows parameter indexing to handle this case automatically. arr = arguments(x)[1] - term(getindex, p(arr), arguments(x)[2:end]...) + p(arr)[arguments(x)[2:end]...] else - term(p, x) + term(p, x; type = symtype(x), shape = SU.shape(x)) end # the result should be a parameter result = toparam(result) @@ -420,14 +429,14 @@ A callback that triggers at the first timestep that the conditions are satisfied The condition can be one of: - Δt::Real - periodic events with period Δt - ts::Vector{Real} - events trigger at these preset times given by `ts` -- eqs::Vector{Symbolic} - events trigger when the condition evaluates to true +- eqs::Vector{SymbolicT} - events trigger when the condition evaluates to true Arguments: - iv: The independent variable of the system. This must be specified if the independent variable appears in one of the equations explicitly, as in x ~ t + 1. - alg_eqs: Algebraic equations of the system that must be satisfied after the callback occurs. """ struct SymbolicDiscreteCallback <: AbstractCallback - conditions::Union{Number, Vector{<:Number}, Symbolic{Bool}} + conditions::Union{Number, Vector{<:Number}, SymbolicT} affect::Union{Affect, SymbolicAffect, Nothing} initialize::Union{Affect, SymbolicAffect, Nothing} finalize::Union{Affect, SymbolicAffect, Nothing} @@ -435,9 +444,10 @@ struct SymbolicDiscreteCallback <: AbstractCallback end function SymbolicDiscreteCallback( - condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing; + condition::Union{SymbolicT, Number, Vector{<:Number}}, affect = nothing; initialize = nothing, finalize = nothing, reinitializealg = nothing, kwargs...) + @assert !(condition isa SymbolicT && symtype(condition) != Bool) c = is_timed_condition(condition) ? condition : value(scalarize(condition)) if isnothing(reinitializealg) @@ -569,6 +579,9 @@ conditions(cb::AbstractCallback) = cb.conditions function conditions(cbs::Vector{<:AbstractCallback}) reduce(vcat, conditions(cb) for cb in cbs; init = []) end +function conditions(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(conditions, vcat, cbs; init = Equation[]) +end equations(cb::AbstractCallback) = conditions(cb) equations(cb::Vector{<:AbstractCallback}) = conditions(cb) @@ -871,7 +884,7 @@ function default_operating_point(affsys::AffectSystem) T = symtype(p) if T <: Number op[p] = false - elseif T <: Array{<:Real} && is_sized_array_symbolic(p) + elseif T <: Array{<:Real} && symbolic_has_known_size(p) op[p] = zeros(size(p)) end end @@ -897,7 +910,7 @@ function compile_equational_affect( obseqs, eqs = unhack_observed(observed(affsys), equations(affsys)) if isempty(equations(affsys)) - update_eqs = Symbolics.fast_substitute( + update_eqs = substitute( obseqs, Dict([p => unPre(p) for p in parameters(affsys)])) rhss = map(x -> x.rhs, update_eqs) lhss = map(x -> x.lhs, update_eqs) diff --git a/src/systems/codegen.jl b/src/systems/codegen.jl index 7719fbcdaa..02a2dbd462 100644 --- a/src/systems/codegen.jl +++ b/src/systems/codegen.jl @@ -561,7 +561,7 @@ function generate_boundary_conditions(sys::System, u0, u0_idxs, t0; expression = cons = [con.lhs - con.rhs for con in constraints(sys)] # conssubs = Dict() # get_constraint_unknown_subs!(conssubs, cons, stidxmap, iv, sol) - # cons = map(x -> fast_substitute(x, conssubs), cons) + # cons = map(x -> substitute(x, conssubs), cons) init_conds = Any[] for i in u0_idxs @@ -1065,7 +1065,7 @@ function build_explicit_observed_function(sys, ts; Base.throw(ArgumentError("Symbol $var is not present in the system.")) end end - ts = fast_substitute(ts, namespace_subs) + ts = substitute(ts, namespace_subs) obsfilter = if param_only if is_split(sys) diff --git a/src/systems/codegen_utils.jl b/src/systems/codegen_utils.jl index dbbd7f85a8..b77bba98b4 100644 --- a/src/systems/codegen_utils.jl +++ b/src/systems/codegen_utils.jl @@ -135,8 +135,8 @@ end """ The argument of generated functions corresponding to the history function. """ -const DDE_HISTORY_FUN = Sym{Symbolics.FnType{Tuple{Any, <:Real}, Vector{Real}}}(:___history___) -const BVP_SOLUTION = Sym{Symbolics.FnType{Tuple{<:Real}, Vector{Real}}}(:__sol__) +const DDE_HISTORY_FUN = SSym(:___history___; type = SU.FnType{Tuple{Any, <:Real}, Vector{Real}}, shape = SU.Unknown(1)) +const BVP_SOLUTION = SSym(:__sol__; type = Symbolics.FnType{Tuple{<:Real}, Vector{Real}}, shape = SU.Unknown(1)) """ $(TYPEDSIGNATURES) diff --git a/src/systems/connectiongraph.jl b/src/systems/connectiongraph.jl index 99110e37e9..27e542c2ab 100644 --- a/src/systems/connectiongraph.jl +++ b/src/systems/connectiongraph.jl @@ -455,7 +455,7 @@ function connectionsets(graph::HyperGraph{V}) where {V} invmap = graph.invmap # union all of the hyperedges - disjoint_sets = IntDisjointSets(length(invmap)) + disjoint_sets = IntDisjointSet(length(invmap)) for edge_i in 𝑠vertices(bigraph) hyperedge = 𝑠neighbors(bigraph, edge_i) isempty(hyperedge) && continue diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index c0ddf5baee..95ca70e6be 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -72,7 +72,7 @@ end Get the connection type of symbolic variable `s` from the `VariableConnectType` metadata. Defaults to `Equality` if not present. """ -function get_connection_type(s::Symbolic) +function get_connection_type(s::SymbolicT) s = unwrap(s) if iscall(s) && operation(s) === getindex s = arguments(s)[1] @@ -632,8 +632,8 @@ function returned from `generate_isouter`. """ function handle_maybe_connect_equation!(eqs, state::AbstractConnectionState, eq::Equation, namespace::Vector{Symbol}, isouter) - lhs = eq.lhs - rhs = eq.rhs + lhs = value(eq.lhs) + rhs = value(eq.rhs) if !(lhs isa Connection) # split connections and equations @@ -948,8 +948,7 @@ function expand_instream(csets::Vector{Vector{ConnectionVertex}}, sys::AbstractS stream_var = only(arguments(expr)) iscall(stream_var) && operation(stream_var) === getindex || continue args = arguments(stream_var) - new_expr = Symbolics.array_term( - instream, args[1]; size = size(args[1]), ndims = ndims(args[1]))[args[2:end]...] + new_expr = term(instream, args[1]; type = symtype(args[1]), shape = SU.shape(args[1]))[args[2:end]...] instream_subs[expr] = new_expr end @@ -1116,4 +1115,4 @@ function instream_rt(ins::Val{inner_n}, outs::Val{outer_n}, for k in 1:M and ck.m_flow.max > 0 =# end -SymbolicUtils.promote_symtype(::typeof(instream_rt), ::Vararg) = Real +SymbolicUtils.promote_symtype(::typeof(instream_rt), ::Type{T}, ::Type{S}, ::Type{R}) where {T, S, R} = Real diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index f823260e2a..c1cce1cb27 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -142,7 +142,7 @@ function change_of_variables( for (new_var, ex, first, second) in zip(new_vars, dfdt, ∂f∂x, ∂2f∂x2) for (eqs, neq) in zip(old_eqs, neqs) - if occursin(value(eqs.lhs), value(ex)) + if SU.query!(isequal(value(eqs.lhs)), value(ex)) ex = substitute(ex, eqs.lhs => eqs.rhs) if isSDE for (noise, B) in zip(neq, brownvars) @@ -470,7 +470,7 @@ julia> M = change_independent_variable(M, x); julia> M = mtkcompile(M; allow_symbolic = true); julia> unknowns(M) -3-element Vector{SymbolicUtils.BasicSymbolic{Real}}: +3-element Vector{Symbolics.SymbolicsT}: xˍt(x) y(x) yˍx(x) @@ -1039,9 +1039,9 @@ function respecialize(sys::AbstractSystem, mapping; all = false) if iscall(k) op = operation(k) args = arguments(k) - new_p = SymbolicUtils.term(op, args...; type = T) + new_p = SymbolicUtils.term(op, args...; type = T, shape = SU.shape(v)) else - new_p = SymbolicUtils.Sym{T}(getname(k)) + new_p = SSym(getname(k); type = T, shape = SU.shape(v)) end get_ps(sys)[idx] = new_p @@ -1049,7 +1049,7 @@ function respecialize(sys::AbstractSystem, mapping; all = false) subrules[unwrap(k)] = unwrap(new_p) end - substituter = Base.Fix2(fast_substitute, subrules) + substituter = Base.Fix2(substitute, subrules) @set! sys.eqs = map(substituter, get_eqs(sys)) @set! sys.observed = map(substituter, get_observed(sys)) @set! sys.initialization_eqs = map(substituter, get_initialization_eqs(sys)) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 1c43022f4b..7fc0c6abe1 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -67,10 +67,10 @@ function ImperativeAffect(; f, kwargs...) ImperativeAffect(f; kwargs...) end -function Symbolics.fast_substitute(aff::ImperativeAffect, rules) - substituter = Base.Fix2(fast_substitute, rules) - ImperativeAffect(aff.f, map(substituter, aff.obs), aff.obs_syms, - map(substituter, aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks) +function (s::SymbolicUtils.Substituter)(aff::ImperativeAffect) + ImperativeAffect(aff.f, s(aff.obs), aff.obs_syms, + s(aff.modified), aff.mod_syms, aff.ctx, aff.skip_checks) + end function Base.show(io::IO, mfa::ImperativeAffect) @@ -85,10 +85,16 @@ context(a::ImperativeAffect) = a.ctx observed(a::ImperativeAffect) = a.obs observed_syms(a::ImperativeAffect) = a.obs_syms function discretes(a::ImperativeAffect) - Iterators.filter(ModelingToolkit.isparameter, - Iterators.flatten(Iterators.map( - x -> symbolic_type(x) == NotSymbolic() && x isa AbstractArray ? x : [x], - a.modified))) + discs = SymbolicT[] + for val in a.modified + val = unwrap(val) + if val isa SymbolicT + isparameter(a) && push!(discs, val) + elseif val isa AbstractArray + append!(discs, filter(isparameter, map(unwrap, val))) + end + end + return discs end modified(a::ImperativeAffect) = a.modified modified_syms(a::ImperativeAffect) = a.mod_syms diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 19c78413cf..90c8859a98 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -1,13 +1,10 @@ +const TypeT = Union{DataType, UnionAll, Union} + struct BufferTemplate - type::Union{DataType, UnionAll, Union} + type::TypeT length::Int end -function BufferTemplate(s::Type{<:Symbolics.Struct}, length::Int) - T = Symbolics.juliatype(s) - BufferTemplate(T, length) -end - struct Nonnumeric <: SciMLStructures.AbstractPortion end const NONNUMERIC_PORTION = Nonnumeric() @@ -31,16 +28,15 @@ struct DiscreteIndex idx_in_clock::Int end -const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}} -const NonnumericMap = Dict{ - Union{BasicSymbolic, Symbolics.CallWithMetadata}, Tuple{Int, Int}} -const UnknownIndexMap = Dict{ - BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}} -const TunableIndexMap = Dict{BasicSymbolic, - Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}} +const MaybeUnknownArrayIndexT = Union{Int, UnitRange{Int}, AbstractArray{Int}} +const MaybeArrayIndexT = Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} +const ParamIndexMap = Dict{SymbolicT, Tuple{Int, Int}} +const NonnumericMap = Dict{SymbolicT, Tuple{Int, Int}} +const UnknownIndexMap = Dict{SymbolicT, MaybeUnknownArrayIndexT} +const TunableIndexMap = Dict{SymbolicT, MaybeArrayIndexT} const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}} -const SymbolicParam = Union{BasicSymbolic, CallWithMetadata} +const SymbolicParam = SymbolicT struct IndexCache unknown_idx::UnknownIndexMap @@ -52,9 +48,8 @@ struct IndexCache initials_idx::TunableIndexMap constant_idx::ParamIndexMap nonnumeric_idx::NonnumericMap - observed_syms_to_timeseries::Dict{BasicSymbolic, TimeseriesSetType} - dependent_pars_to_timeseries::Dict{ - Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType} + observed_syms_to_timeseries::Dict{SymbolicT, TimeseriesSetType} + dependent_pars_to_timeseries::Dict{SymbolicT, TimeseriesSetType} discrete_buffer_sizes::Vector{Vector{BufferTemplate}} tunable_buffer_size::BufferTemplate initials_buffer_size::BufferTemplate @@ -70,26 +65,36 @@ function IndexCache(sys::AbstractSystem) let idx = 1 for sym in unks - usym = unwrap(sym) - rsym = renamespace(sys, usym) - sym_idx = if Symbolics.isarraysymbolic(sym) + rsym = renamespace(sys, sym) + sym_idx::MaybeUnknownArrayIndexT = if Symbolics.isarraysymbolic(sym) reshape(idx:(idx + length(sym) - 1), size(sym)) else idx end - unk_idxs[usym] = sym_idx + unk_idxs[sym] = sym_idx unk_idxs[rsym] = sym_idx idx += length(sym) end + found_array_syms = Set{SymbolicT}() for sym in unks - usym = unwrap(sym) iscall(sym) && operation(sym) === getindex || continue arrsym = arguments(sym)[1] - all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue - - idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)] + arrsym in found_array_syms && continue + idxs = Int[] + valid_arrsym = true + for i in eachindex(arrsym) + idxsym = arrsym[i] + idx = get(unk_idxs, idxsym, nothing)::Union{Int, Nothing} + valid_arrsym = idx !== nothing + valid_arrsym || break + push!(idxs, idx) + end + push!(found_array_syms, arrsym) + valid_arrsym || break if idxs == idxs[begin]:idxs[end] - idxs = reshape(idxs[begin]:idxs[end], size(idxs)) + idxs = reshape(idxs[begin]:idxs[end], size(idxs))::Array{Int} + else + idxs = reshape(idxs, size(arrsym))::Array{Int} end rsym = renamespace(sys, arrsym) unk_idxs[arrsym] = idxs @@ -97,62 +102,24 @@ function IndexCache(sys::AbstractSystem) end end - tunable_pars = BasicSymbolic[] - initial_pars = BasicSymbolic[] - constant_buffers = Dict{Any, Set{BasicSymbolic}}() - nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}() - - function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S} - sym = unwrap(sym) - buf = get!(buffers, ctype, S()) - push!(buf, sym) - end - function insert_by_type!(buffers::Vector{BasicSymbolic}, sym, ctype) - sym = unwrap(sym) - push!(buffers, sym) - end - - disc_param_callbacks = Dict{SymbolicParam, Set{Int}}() - events = vcat(continuous_events(sys), discrete_events(sys)) - for (i, event) in enumerate(events) - discs = Set{SymbolicParam}() - affs = affects(event) - if !(affs isa AbstractArray) - affs = [affs] - end - for affect in affs - if affect isa AffectSystem || affect isa ImperativeAffect - union!(discs, unwrap.(discretes(affect))) - elseif isnothing(affect) - continue - else - error("Unhandled affect type $(typeof(affect))") - end - end - - for sym in discs - is_parameter(sys, sym) || - error("Expected discrete variable $sym in callback to be a parameter") - - # Only `foo(t)`-esque parameters can be saved - if iscall(sym) && length(arguments(sym)) == 1 && - isequal(only(arguments(sym)), get_iv(sys)) - clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym) - push!(clocks, i) - elseif is_variable_floatingpoint(sym) - insert_by_type!(constant_buffers, sym, symtype(sym)) - else - stype = symtype(sym) - if stype <: FnType - stype = fntype_to_function_type(stype) - end - insert_by_type!(nonnumeric_buffers, sym, stype) - end - end - end - clock_partitions = unique(collect(values(disc_param_callbacks))) - disc_symtypes = unique(symtype.(keys(disc_param_callbacks))) - disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes)) + tunable_pars = SymbolicT[] + initial_pars = SymbolicT[] + constant_buffers = Dict{TypeT, Set{SymbolicT}}() + nonnumeric_buffers = Dict{TypeT, Set{SymbolicT}}() + + disc_param_callbacks = Dict{SymbolicParam, BitSet}() + cevs = continuous_events(sys) + devs = discrete_events(sys) + events = Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}[cevs; devs] + parse_callbacks_for_discretes!(cevs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, 0) + parse_callbacks_for_discretes!(devs, disc_param_callbacks, constant_buffers, nonnumeric_buffers, length(cevs)) + clock_partitions = unique(collect(values(disc_param_callbacks)))::Vector{BitSet} + disc_symtypes = Set{TypeT}() + for x in keys(disc_param_callbacks) + push!(disc_symtypes, symtype(x)) + end + disc_symtypes = collect(disc_symtypes)::Vector{TypeT} + disc_symtype_idx = Dict{TypeT, Int}(zip(disc_symtypes, eachindex(disc_symtypes))) disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes] for sym in keys(disc_param_callbacks) push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym) @@ -160,13 +127,12 @@ function IndexCache(sys::AbstractSystem) disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes] for (i, buffer) in enumerate(disc_syms_by_symtype) for partition in clock_partitions - push!(disc_syms_by_symtype_by_partition[i], - [sym for sym in buffer if disc_param_callbacks[sym] == partition]) + push!(disc_syms_by_symtype_by_partition[i], filter(==(partition) ∘ Base.Fix1(getindex, disc_param_callbacks), buffer)) end end disc_idxs = Dict{SymbolicParam, DiscreteIndex}() callback_to_clocks = Dict{ - Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}() + Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, BitSet}() for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition) symi = 0 for (parti, disc_syms) in enumerate(disc_syms_by_partition) @@ -194,26 +160,24 @@ function IndexCache(sys::AbstractSystem) disc_buffer_templates = Vector{BufferTemplate}[] for (symtype, disc_syms_by_partition) in zip( disc_symtypes, disc_syms_by_symtype_by_partition) - push!(disc_buffer_templates, - [BufferTemplate(symtype, length(buf)) for buf in disc_syms_by_partition]) + push!(disc_buffer_templates, map(Base.Fix1(BufferTemplate, symtype) ∘ length, disc_syms_by_partition)) end for p in parameters(sys; initial_parameters = true) - p = unwrap(p) ctype = symtype(p) if ctype <: FnType - ctype = fntype_to_function_type(ctype) + ctype = fntype_to_function_type(ctype)::TypeT end haskey(disc_idxs, p) && continue haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue insert_by_type!( if ctype <: Real || ctype <: AbstractArray{<:Real} - if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() && + if istunable(p, true) && symbolic_has_known_size(p) && (ctype == Real || ctype <: AbstractFloat || ctype <: AbstractArray{Real} || ctype <: AbstractArray{<:AbstractFloat}) - if iscall(p) && operation(p) isa Initial + if iscall(p) && operation(p) === Initial() initial_pars else tunable_pars @@ -229,33 +193,10 @@ function IndexCache(sys::AbstractSystem) ) end - function get_buffer_sizes_and_idxs(T, buffers::Dict) - idxs = T() - buffer_sizes = BufferTemplate[] - for (i, (T, buf)) in enumerate(buffers) - for (j, p) in enumerate(buf) - ttp = default_toterm(p) - rp = renamespace(sys, p) - rttp = renamespace(sys, ttp) - idxs[p] = (i, j) - idxs[ttp] = (i, j) - idxs[rp] = (i, j) - idxs[rttp] = (i, j) - end - if T <: Symbolics.FnType - T = Any - end - push!(buffer_sizes, BufferTemplate(T, length(buf))) - end - return idxs, buffer_sizes - end - const_idxs, - const_buffer_sizes = get_buffer_sizes_and_idxs( - ParamIndexMap, constant_buffers) + const_buffer_sizes = get_buffer_sizes_and_idxs(ParamIndexMap, constant_buffers) nonnumeric_idxs, - nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs( - NonnumericMap, nonnumeric_buffers) + nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(NonnumericMap, nonnumeric_buffers) tunable_idxs = TunableIndexMap() tunable_buffer_size = 0 @@ -264,7 +205,8 @@ function IndexCache(sys::AbstractSystem) empty!(initial_pars) end for p in tunable_pars - idx = if size(p) == () + sh = SU.shape(p) + idx = if !SU.is_array_shape(sh) tunable_buffer_size + 1 else reshape( @@ -282,7 +224,8 @@ function IndexCache(sys::AbstractSystem) initials_idxs = TunableIndexMap() initials_buffer_size = 0 for p in initial_pars - idx = if size(p) == () + sh = SU.shape(p) + idx = if !SU.is_array_shape(sh) initials_buffer_size + 1 else reshape( @@ -300,24 +243,27 @@ function IndexCache(sys::AbstractSystem) for k in collect(keys(tunable_idxs)) v = tunable_idxs[k] v isa AbstractArray || continue - for (kk, vv) in zip(collect(k), v) + v = v::Union{UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} + iter = vec(collect(k)::Array{SymbolicT})::Vector{SymbolicT} + for (kk::SymbolicT, vv) in zip(iter, v) tunable_idxs[kk] = vv end end for k in collect(keys(initials_idxs)) v = initials_idxs[k] v isa AbstractArray || continue - for (kk, vv) in zip(collect(k), v) + v = v::Union{UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}} + iter = vec(collect(k)::Array{SymbolicT})::Vector{SymbolicT} + for (kk, vv) in zip(iter, v) initials_idxs[kk] = vv end end - dependent_pars_to_timeseries = Dict{ - Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}() - + dependent_pars_to_timeseries = Dict{SymbolicT, TimeseriesSetType}() + vs = Set{SymbolicT}() for eq in get_parameter_dependencies(sys) sym = eq.lhs - vs = vars(eq.rhs) + SU.search_variables!(vs, eq.rhs) timeseries = TimeseriesSetType() if is_time_dependent(sys) for v in vs @@ -331,24 +277,29 @@ function IndexCache(sys::AbstractSystem) rttsym = renamespace(sys, ttsym) for s in (sym, ttsym, rsym, rttsym) dependent_pars_to_timeseries[s] = timeseries - if hasname(s) && (!iscall(s) || operation(s) != getindex) + if hasname(s) && (!iscall(s) || operation(s) !== getindex) symbol_to_variable[getname(s)] = sym end end end - observed_syms_to_timeseries = Dict{BasicSymbolic, TimeseriesSetType}() + observed_syms_to_timeseries = Dict{SymbolicT, TimeseriesSetType}() for eq in observed(sys) if symbolic_type(eq.lhs) != NotSymbolic() sym = eq.lhs - vs = vars(eq.rhs; op = Nothing) + empty!(vs) + SU.search_variables!(vs, eq.rhs) timeseries = TimeseriesSetType() if is_time_dependent(sys) for v in vs if (idx = get(disc_idxs, v, nothing)) !== nothing push!(timeseries, idx.clock_idx) - elseif iscall(v) && operation(v) === getindex && - (idx = get(disc_idxs, arguments(v)[1], nothing)) !== nothing + elseif Moshi.Match.@match v begin + BSImpl.Term(; f, args) => begin + f === getindex && (idx = get(disc_idxs, args[1], nothing)) !== nothing + end + _ => false + end push!(timeseries, idx.clock_idx) elseif haskey(observed_syms_to_timeseries, v) union!(timeseries, observed_syms_to_timeseries[v]) @@ -369,13 +320,12 @@ function IndexCache(sys::AbstractSystem) end end - for sym in Iterators.flatten((keys(unk_idxs), keys(disc_idxs), keys(tunable_idxs), - keys(const_idxs), keys(nonnumeric_idxs), - keys(observed_syms_to_timeseries), independent_variable_symbols(sys))) - if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) - symbol_to_variable[getname(sym)] = sym - end - end + populate_symbol_to_var!(symbol_to_variable, keys(unk_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(disc_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(tunable_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(const_idxs)) + populate_symbol_to_var!(symbol_to_variable, keys(nonnumeric_idxs)) + populate_symbol_to_var!(symbol_to_variable, independent_variable_symbols(sys)) return IndexCache( unk_idxs, @@ -396,6 +346,80 @@ function IndexCache(sys::AbstractSystem) ) end +function populate_symbol_to_var!(symbol_to_variable::Dict{Symbol, SymbolicT}, vars) + for sym::SymbolicT in vars + if hasname(sym) && (!iscall(sym) || operation(sym) !== getindex) + symbol_to_variable[getname(sym)] = sym + end + end +end + +""" + $TYPEDSIGNATURES + +Utility function for the `IndexCache` constructor. +""" +function insert_by_type!(buffers::Dict{TypeT, Set{SymbolicT}}, sym::SymbolicT, ctype::TypeT) + buf = get!(buffers, ctype, S()) + push!(buf, sym) +end +function insert_by_type!(buffers::Vector{SymbolicT}, sym::SymbolicT, ::TypeT) + push!(buffers, sym) +end + +function parse_callbacks_for_discretes!(events::Vector, disc_param_callbacks::Dict{SymbolicT, BitSet}, constant_buffers::Dict{TypeT, Set{SymbolicT}}, nonnumeric_buffers::Dict{TypeT, Set{SymbolicT}}, offset::Int) + for (i, event) in enumerate(events) + discs = Set{SymbolicParam}() + affect = event.affect::Union{AffectSystem, ImperativeAffect, Nothing} + if affect isa AffectSystem || affect isa ImperativeAffect + union!(discs, discretes(affect)) + elseif affect === nothing + continue + end + + for sym in discs + is_parameter(sys, sym) || + error("Expected discrete variable $sym in callback to be a parameter") + + # Only `foo(t)`-esque parameters can be saved + if iscall(sym) && length(arguments(sym)) == 1 && + isequal(only(arguments(sym)), get_iv(sys)) + clocks = get!(BitSet, disc_param_callbacks, sym) + push!(clocks, i + offset) + elseif is_variable_floatingpoint(sym) + insert_by_type!(constant_buffers, sym, symtype(sym)) + else + stype = symtype(sym) + if stype <: FnType + stype = fntype_to_function_type(stype)::TypeT + end + insert_by_type!(nonnumeric_buffers, sym, stype) + end + end + end +end + +function get_buffer_sizes_and_idxs(::Type{BufT}, buffers::Dict) where {BufT} + idxs = BufT() + buffer_sizes = BufferTemplate[] + for (i, (T, buf)) in enumerate(buffers) + for (j, p) in enumerate(buf) + ttp = default_toterm(p) + rp = renamespace(sys, p) + rttp = renamespace(sys, ttp) + idxs[p] = (i, j) + idxs[ttp] = (i, j) + idxs[rp] = (i, j) + idxs[rttp] = (i, j) + end + if T <: Symbolics.FnType + T = Any + end + push!(buffer_sizes, BufferTemplate(T, length(buf))) + end + return idxs, buffer_sizes +end + function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym) variable_index(ic, sym) !== nothing end @@ -418,14 +442,16 @@ function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym) parameter_index(ic, sym) !== nothing end -function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym) - if sym isa Symbol - sym = get(ic.symbol_to_variable, sym, nothing) - sym === nothing && return nothing - end - sym = unwrap(sym) - validate_size = Symbolics.isarraysymbolic(sym) && symtype(sym) <: AbstractArray && - Symbolics.shape(sym) !== Symbolics.Unknown() +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Union{Num, Symbolics.Arr, Symbolics.CallAndWrap}) + parameter_index(ic, unwrap(sym)) +end +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::Symbol) + sym = get(ic.symbol_to_variable, sym, nothing) + sym === nothing && return nothing + parameter_index(ic, sym) +end +function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym::SymbolicT) + validate_size = Symbolics.isarraysymbolic(sym) && symbolic_has_known_size(sym) return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing ParameterIndex(SciMLStructures.Tunable(), idx, validate_size) elseif (idx = check_index_map(ic.initials_idx, sym)) !== nothing @@ -472,29 +498,20 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy idx.timeseries_idx, (idx.parameter_idx..., args[2:end]...)) end -function check_index_map(idxmap, sym) - if (idx = get(idxmap, sym, nothing)) !== nothing - return idx - elseif !isa(sym, Symbol) && (!iscall(sym) || operation(sym) !== getindex) && - hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing - return idx - end +function check_index_map(idxmap::Dict{SymbolicT, V}, sym::SymbolicT)::Union{V, Nothing} where {V} + idx = get(idxmap, sym, nothing) + idx === nothing || return idx dsym = default_toterm(sym) isequal(sym, dsym) && return nothing - if (idx = get(idxmap, dsym, nothing)) !== nothing - idx - elseif !isa(dsym, Symbol) && (!iscall(dsym) || operation(dsym) !== getindex) && - hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing - idx - else - nothing - end + idx = get(idxmap, dsym, nothing) + idx === nothing || return idx + return nothing end function reorder_parameters( sys::AbstractSystem, ps = parameters(sys; initial_parameters = true); kwargs...) if has_index_cache(sys) && get_index_cache(sys) !== nothing - reorder_parameters(get_index_cache(sys), ps; kwargs...) + reorder_parameters(get_index_cache(sys)::IndexCache, ps; kwargs...) elseif ps isa Tuple ps else @@ -502,47 +519,54 @@ function reorder_parameters( end end -function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = true) +const COMMON_DEFAULT_VAR = unwrap(only(@variables __DEF__)) + +function reorder_parameters(ic::IndexCache, ps::Vector{SymbolicT}; drop_missing = false, flatten = true) isempty(ps) && return () - param_buf = if ic.tunable_buffer_size.length == 0 - () - else - (BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(ic.tunable_buffer_size.length)],) + result = Vector{Union{Vector{SymbolicT}, Vector{Vector{SymbolicT}}}}() + param_buf = fill(COMMON_DEFAULT_VAR, ic.tunable_buffer_size.length) + push!(result, param_buf) + initials_buf = fill(COMMON_DEFAULT_VAR, ic.initials_buffer_size.length) + push!(result, initials_buf) + + disc_buf = Vector{SymbolicT}[] + for bufszs in ic.discrete_buffer_sizes + push!(disc_buf, fill(COMMON_DEFAULT_VAR, sum(x -> x.length, bufszs))) end - initials_buf = if ic.initials_buffer_size.length == 0 - () + const_buf = Vector{SymbolicT}[] + for bufsz in ic.constant_buffer_sizes + push!(const_buf, fill(COMMON_DEFAULT_VAR, bufsz.length)) + end + nonnumeric_buf = Vector{SymbolicT}[] + for bufsz in ic.nonnumeric_buffer_sizes + push!(nonnumeric_buf, fill(COMMON_DEFAULT_VAR, bufsz.length)) + end + if flatten + append!(result, disc_buf) + append!(result, const_buf) + append!(result, nonnumeric_buf) else - (BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(ic.initials_buffer_size.length)],) - end - - disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) - for _ in 1:(sum(x -> x.length, temp))] - for temp in ic.discrete_buffer_sizes) - const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)] - for temp in ic.constant_buffer_sizes) - nonnumeric_buf = Tuple(Union{BasicSymbolic, CallWithMetadata}[unwrap(variable(:DEF)) - for _ in 1:(temp.length)] - for temp in ic.nonnumeric_buffer_sizes) + push!(result, disc_buf) + push!(result, const_buf) + push!(result, nonnumeric_buf) + end for p in ps - p = unwrap(p) if haskey(ic.discrete_idx, p) idx = ic.discrete_idx[p] disc_buf[idx.buffer_idx][idx.idx_in_buffer] = p elseif haskey(ic.tunable_idx, p) i = ic.tunable_idx[p] if i isa Int - param_buf[1][i] = unwrap(p) + param_buf[i] = p else - param_buf[1][i] = unwrap.(collect(p)) + param_buf[i] = collect(p) end elseif haskey(ic.initials_idx, p) i = ic.initials_idx[p] if i isa Int - initials_buf[1][i] = unwrap(p) + initials_buf[i] = p else - initials_buf[1][i] = unwrap.(collect(p)) + initials_buf[i] = collect(p) end elseif haskey(ic.constant_idx, p) i, j = ic.constant_idx[p] @@ -555,37 +579,20 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false, flatten = end end - param_buf = broadcast.(unwrap, param_buf) - initials_buf = broadcast.(unwrap, initials_buf) - disc_buf = broadcast.(unwrap, disc_buf) - const_buf = broadcast.(unwrap, const_buf) - nonnumeric_buf = broadcast.(unwrap, nonnumeric_buf) - if drop_missing - filterer = !isequal(unwrap(variable(:DEF))) - param_buf = filter.(filterer, param_buf) - initials_buf = filter.(filterer, initials_buf) - disc_buf = filter.(filterer, disc_buf) - const_buf = filter.(filterer, const_buf) - nonnumeric_buf = filter.(filterer, nonnumeric_buf) - end - - if flatten - result = ( - param_buf..., initials_buf..., disc_buf..., const_buf..., nonnumeric_buf...) - if all(isempty, result) - return () - end - return result - else - if isempty(param_buf) - param_buf = ((),) - end - if isempty(initials_buf) - initials_buf = ((),) + filterer = !isequal(COMMON_DEFAULT_VAR) + for inner in result + if inner isa Vector{SymbolicT} + filter!(filterer, inner) + elseif inner isa Vector{Vector{SymbolicT}} + for buf in inner + filter!(filterer, buf) + end + end end - return (param_buf..., initials_buf..., disc_buf, const_buf, nonnumeric_buf) end + + return result end # Given a parameter index, find the index of the buffer it is in when diff --git a/src/systems/model_parsing.jl b/src/systems/model_parsing.jl index 699cfee8fd..0369b12212 100644 --- a/src/systems/model_parsing.jl +++ b/src/systems/model_parsing.jl @@ -629,7 +629,7 @@ function _set_var_metadata!(metadata_with_exprs, a, m, v::Expr) a end function _set_var_metadata!(metadata_with_exprs, a, m, v) - wrap(set_scalar_metadata(unwrap(a), m, v)) + wrap(setmetadata(unwrap(a), m, v)) end function set_var_metadata(a, ms) diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index 96c00411ad..80d0a4792d 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -1,5 +1,5 @@ function contains_variable(x, wrt) - any(y -> occursin(y, x), wrt) + any(y -> SU.query!(isequal(y), x), wrt) end """ @@ -270,7 +270,7 @@ function PolynomialTransformation(sys::System) transformation_err = nothing for t in all_non_poly_terms # if the term involves multiple unknowns, we can't invert it - dvs_in_term = map(x -> occursin(x, t), dvs) + dvs_in_term = map(x -> SU.query!(isequal(x), t), dvs) if count(dvs_in_term) > 1 transformation_err = MultivarTerm(t, dvs[dvs_in_term]) is_poly = false @@ -369,7 +369,7 @@ function transform_system(sys::System, transformation::PolynomialTransformation; t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs)) # the substituted variable occurs outside the substituted term poly_and_nonpoly = map(dvs) do x - all(!isequal(x), new_dvs) && occursin(x, t) + all(!isequal(x), new_dvs) && SU.query!(isequal(x), t) end if any(poly_and_nonpoly) return NotPolynomialError( diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index f377f0202f..b318b7392a 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -271,8 +271,7 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem; vars!(vs, eq; op = Initial) allpars = full_parameters(sys) for p in allpars - if symbolic_type(p) == ArraySymbolic() && - Symbolics.shape(p) != Symbolics.Unknown() + if symbolic_type(p) == ArraySymbolic() && SU.shape(p) isa SU.Unknown append!(allpars, Symbolics.scalarize(p)) end end @@ -502,7 +501,7 @@ function get_possibly_array_fallback_singletons(varmap, p) return varmap[p] end if symbolic_type(p) == ArraySymbolic() - is_sized_array_symbolic(p) || return nothing + symbolic_has_known_size(p) || return nothing scal = collect(p) if all(x -> haskey(varmap, x), scal) res = [varmap[x] for x in scal] @@ -824,30 +823,19 @@ Counteracts the CSE/array variable hacks in `symbolics_tearing.jl` so it works w initialization. """ function unhack_observed(obseqs::Vector{Equation}, eqs::Vector{Equation}) - subs = Dict() - tempvars = Set() - rm_idxs = Int[] + subs = Dict{SymbolicT, SymbolicT}() + mask = trues(length(obseqs)) for (i, eq) in enumerate(obseqs) - iscall(eq.rhs) || continue - if operation(eq.rhs) == StructuralTransformations.change_origin - push!(rm_idxs, i) - continue - end - end - - for (i, eq) in enumerate(obseqs) - if eq.lhs in tempvars - subs[eq.lhs] = eq.rhs - push!(rm_idxs, i) - end + mask[i] = !iscall(eq.rhs) || operation(eq.rhs) !== StructuralTransformations.change_origin end - obseqs = obseqs[setdiff(eachindex(obseqs), rm_idxs)] - obseqs = map(obseqs) do eq - fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + obseqs = obseqs[mask] + for i in eachindex(obseqs) + obseqs[i] = fixpoint_sub(obseqs[i].lhs, subs) ~ fixpoint_sub(obseqs[i], subs) end - eqs = map(eqs) do eq - fixpoint_sub(eq.lhs, subs) ~ fixpoint_sub(eq.rhs, subs) + eqs = copy(eqs) + for i in eachindex(eqs) + eqs[i] = fixpoint_sub(eqs[i].lhs, subs) ~ fixpoint_sub(eqs[i], subs) end return obseqs, eqs end diff --git a/src/systems/optimal_control_interface.jl b/src/systems/optimal_control_interface.jl index 5a0ddbf8d5..108eb05893 100644 --- a/src/systems/optimal_control_interface.jl +++ b/src/systems/optimal_control_interface.jl @@ -357,16 +357,16 @@ function substitute_model_vars(model, sys, exprs, tspan) t = get_iv(sys) exprs = map( - c -> Symbolics.fast_substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs) + c -> substitute(c, whole_t_map(model, t, x_ops, c_ops)), exprs) (ti, tf) = tspan if symbolic_type(tf) === ScalarSymbolic() _tf = model.tₛ + ti exprs = map( - c -> Symbolics.fast_substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs) - exprs = map(c -> Symbolics.fast_substitute(c, Dict(tf => _tf)), exprs) + c -> substitute(c, free_t_map(model, tf, x_ops, c_ops)), exprs) + exprs = map(c -> substitute(c, Dict(tf => _tf)), exprs) end - exprs = map(c -> Symbolics.fast_substitute(c, fixed_t_map(model, x_ops, c_ops)), exprs) + exprs = map(c -> substitute(c, fixed_t_map(model, x_ops, c_ops)), exprs) exprs end @@ -440,7 +440,7 @@ end function substitute_toterm(vars, exprs) toterm_map = Dict([u => default_toterm(value(u)) for u in vars]) - exprs = map(c -> Symbolics.fast_substitute(c, toterm_map), exprs) + exprs = map(c -> substitute(c, toterm_map), exprs) end function substitute_params(pmap, exprs) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 637ed674ae..23142400b6 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -1,4 +1,3 @@ -symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x) symconvert(::Type{T}, x::V) where {T, V} = convert(promote_type(T, V), x) symconvert(::Type{Real}, x::Integer) = convert(Float16, x) symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) @@ -124,7 +123,7 @@ function MTKParameters( val = symconvert(ctype, val) done = set_value(sym, val) if !done && Symbolics.isarraysymbolic(sym) - if Symbolics.shape(sym) === Symbolics.Unknown() + if !symbolic_has_known_size(sym) for i in eachindex(val) set_value(sym[i], val[i]) end @@ -464,11 +463,11 @@ function validate_parameter_type(ic::IndexCache, p, idx::ParameterIndex, val) end stype = symtype(p) sz = if stype <: AbstractArray - Symbolics.shape(p) == Symbolics.Unknown() ? Symbolics.Unknown() : size(p) + size(p) elseif stype <: Number size(p) else - Symbolics.Unknown() + SU.Unknown(-1) end validate_parameter_type(ic, stype, sz, p, idx, val) end @@ -480,7 +479,7 @@ function validate_parameter_type(ic::IndexCache, idx::ParameterIndex, val) stype = AbstractArray{<:stype} end validate_parameter_type( - ic, stype, Symbolics.Unknown(), nothing, idx, val) + ic, stype, SU.Unknown(-1), nothing, idx, val) end function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) @@ -500,7 +499,7 @@ function validate_parameter_type(ic::IndexCache, stype, sz, sym, index, val) :validate_parameter_type, sym === nothing ? index : sym, stype, val)) end # ... and must match sizes - if stype <: AbstractArray && sz != Symbolics.Unknown() && size(val) != sz + if stype <: AbstractArray && !(sz isa SU.Unknown) && size(val) != sz throw(InvalidParameterSizeException(sym, val)) end # Early exit @@ -719,7 +718,7 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru sym = idx idx = parameter_index(ic, sym) if idx === nothing - Symbolics.shape(sym) == Symbolics.Unknown() && + symbolic_has_known_size(sym) || throw(ParameterNotInSystem(sym)) size(sym) == size(val) || throw(InvalidParameterSizeException(sym, val)) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 2a1208586b..0f822c0abe 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -17,10 +17,10 @@ anydict(x) = AnyDict(x) """ $(TYPEDSIGNATURES) -Check if `x` is a symbolic with known size. Assumes `Symbolics.shape(unwrap(x))` +Check if `x` is a symbolic with known size. Assumes `SymbolicUtils.shape(unwrap(x))` is a valid operation. """ -is_sized_array_symbolic(x) = Symbolics.shape(unwrap(x)) != Symbolics.Unknown() +symbolic_has_known_size(x) = !(SU.shape(unwrap(x)) isa SU.Unknown) """ $(TYPEDSIGNATURES) @@ -128,7 +128,7 @@ function add_fallbacks!( haskey(varmap, ttvar) && continue # array symbolics with a defined size may be present in the scalarized form - if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + if Symbolics.isarraysymbolic(var) && symbolic_has_known_size(var) val = map(eachindex(var)) do idx # @something is lazy and saves from writing a massive if-elseif-else @something(get(varmap, var[idx], nothing), @@ -162,7 +162,7 @@ function add_fallbacks!( fallbacks, arrvar, nothing) get(fallbacks, ttarrvar, nothing) Some(nothing) if val !== nothing val = val[idxs...] - is_sized_array_symbolic(arrvar) && push!(arrvars, arrvar) + symbolic_has_known_size(arrvar) && push!(arrvars, arrvar) end else val = nothing @@ -197,7 +197,7 @@ function missingvars( ttsym = toterm(var) haskey(varmap, ttsym) && continue - if Symbolics.isarraysymbolic(var) && is_sized_array_symbolic(var) + if Symbolics.isarraysymbolic(var) && symbolic_has_known_size(var) mask = map(eachindex(var)) do idx !haskey(varmap, var[idx]) && !haskey(varmap, ttsym[idx]) end @@ -486,7 +486,7 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) v === nothing && continue symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v) && continue haskey(varmap, k) || continue - varmap[k] = fixpoint_sub(v, varmap; maxiters = limit) + varmap[k] = value(fixpoint_sub(v, varmap; maxiters = limit)) end end @@ -535,7 +535,7 @@ If a scalarized entry already exists, it is not overridden. function scalarize_vars_in_varmap!(varmap::AbstractDict, vars) for var in vars symbolic_type(var) == ArraySymbolic() || continue - is_sized_array_symbolic(var) || continue + symbolic_has_known_size(var) || continue haskey(varmap, var) || continue for i in eachindex(var) haskey(varmap, var[i]) && continue @@ -958,9 +958,11 @@ end A function to be used as `update_initializeprob!` in `OverrideInitData`. Requires `is_update_oop = Val(true)` to be passed to `update_initializeprob!`. + +Any changes to this method should also be made to the one in ChainRulesCoreExt. """ function update_initializeprob!(initprob, prob) - pgetter = ChainRulesCore.@ignore_derivatives get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter + pgetter = get_scimlfn(prob).initialization_data.metadata.oop_reconstruct_u0_p.pgetter p = pgetter(prob, initprob) return remake(initprob; p) end diff --git a/src/systems/solver_nlprob.jl b/src/systems/solver_nlprob.jl index badfe21efb..d4772018c1 100644 --- a/src/systems/solver_nlprob.jl +++ b/src/systems/solver_nlprob.jl @@ -55,7 +55,7 @@ function inner_nlsystem(sys::System, mm, nlstep_compile::Bool) subrules = Dict([v => unwrap(gamma2*v + inner_tmp[i]) for (i, v) in enumerate(dvs)]) subrules[t] = unwrap(c) - new_rhss = map(Base.Fix2(fast_substitute, subrules), rhss) + new_rhss = map(Base.Fix2(substitute, subrules), rhss) new_rhss = collect(outer_tmp) .+ gamma1 .* new_rhss .- gamma3 * mm * dvs new_eqs = [0 ~ rhs for rhs in new_rhss] diff --git a/src/systems/state_machines.jl b/src/systems/state_machines.jl index ea65981804..48d9d2f4f6 100644 --- a/src/systems/state_machines.jl +++ b/src/systems/state_machines.jl @@ -75,7 +75,7 @@ for (s, T) in [(:timeInState, :Real), seed = hash(s) @eval begin $s(x) = wrap(term($s, x)) - SymbolicUtils.promote_symtype(::typeof($s), _...) = $T + SymbolicUtils.promote_symtype(::typeof($s), ::Type{S}) where {S} = $T function SymbolicUtils.show_call(io, ::typeof($s), args) if isempty(args) print(io, $s, "()") diff --git a/src/systems/system.jl b/src/systems/system.jl index 6db36ebd36..03dd548203 100644 --- a/src/systems/system.jl +++ b/src/systems/system.jl @@ -46,7 +46,7 @@ struct System <: IntermediateDeprecationSystem this noise matrix is diagonal. Diagonal noise can be specified by providing an `N` length vector. If this field is `nothing`, the system does not have noise. """ - noise_eqs::Union{Nothing, AbstractVector, AbstractMatrix} + noise_eqs::Union{Nothing, Vector{SymbolicT}, Matrix{SymbolicT}} """ Jumps associated with the system. Each jump can be a `VariableRateJump`, `ConstantRateJump` or `MassActionJump`. See `JumpProcesses.jl` for more information. @@ -63,7 +63,7 @@ struct System <: IntermediateDeprecationSystem loss of an optimization problem. Scalar loss values must also be provided as a single- element vector. """ - costs::Vector{<:Union{BasicSymbolic, Real}} + costs::Vector{SymbolicT} """ A function which combines costs into a scalar value. This should take two arguments, the `costs` of this system and the consolidated costs of all subsystems in the order @@ -76,25 +76,25 @@ struct System <: IntermediateDeprecationSystem The variables being solved for by this system. For example, in a differential equation system, this contains the dependent variables. """ - unknowns::Vector + unknowns::Vector{SymbolicT} """ The parameters of the system. Parameters can either be variables that parameterize the problem being solved for (e.g. the spring constant of a mass-spring system) or additional unknowns not part of the main dynamics of the system (e.g. discrete/clocked variables in a hybrid ODE). """ - ps::Vector + ps::Vector{SymbolicT} """ The brownian variables of the system, created via `@brownians`. Each brownian variable represents an independent noise. A system with brownians cannot be simulated directly. It needs to be compiled using `mtkcompile` into `noise_eqs`. """ - brownians::Vector + brownians::Vector{SymbolicT} """ The independent variable for a time-dependent system, or `nothing` for a time-independent system. """ - iv::Union{Nothing, BasicSymbolic{Real}} + iv::Union{Nothing, SymbolicT} """ Equations that compute variables of a system that have been eliminated from the set of unknowns by `mtkcompile`. More generally, this contains all variables that can be @@ -117,7 +117,7 @@ struct System <: IntermediateDeprecationSystem A mapping from the name of a variable to the actual symbolic variable in the system. This is used to enable `getproperty` syntax to access variables of a system. """ - var_to_name::Dict{Symbol, Any} + var_to_name::Dict{Symbol, SymbolicT} """ The name of the system. """ @@ -132,11 +132,11 @@ struct System <: IntermediateDeprecationSystem by initial values provided to the problem constructor. Defaults of parent systems take priority over those in child systems. """ - defaults::Dict + defaults::SymmapT """ Guess values for variables of a system that are solved for during initialization. """ - guesses::Dict + guesses::SymmapT """ A list of subsystems of this system. Used for hierarchically building models. """ @@ -167,7 +167,7 @@ struct System <: IntermediateDeprecationSystem associated error message. By default these assertions cause the generated code to output `NaN`s if violated, but can be made to error using `debug_system`. """ - assertions::Dict{BasicSymbolic, String} + assertions::Dict{SymbolicT, String} """ The metadata associated with this system, as a `Base.ImmutableDict`. This follows the same interface as SymbolicUtils.jl. Metadata can be queried and updated using @@ -193,12 +193,12 @@ struct System <: IntermediateDeprecationSystem $INTERNAL_FIELD_WARNING The list of input variables of the system. """ - inputs::OrderedSet{BasicSymbolic} + inputs::OrderedSet{SymbolicT} """ $INTERNAL_FIELD_WARNING The list of output variables of the system. """ - outputs::OrderedSet{BasicSymbolic} + outputs::OrderedSet{SymbolicT} """ The `TearingState` of the system post-simplification with `mtkcompile`. """ @@ -264,9 +264,9 @@ struct System <: IntermediateDeprecationSystem tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events, - connector_type, assertions = Dict{BasicSymbolic, String}(), + connector_type, assertions = Dict{SymbolicT, String}(), metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [], - inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(), + inputs = Set{SymbolicT}(), outputs = Set{SymbolicT}(), tearing_state = nothing, namespacing = true, complete = false, index_cache = nothing, ignored_connections = nothing, preface = nothing, parent = nothing, initializesystem = nothing, @@ -278,30 +278,38 @@ struct System <: IntermediateDeprecationSystem variable $iv. """)) end - jumps = Vector{JumpType}(jumps) - if (checks == true || (checks & CheckComponents) > 0) && iv !== nothing - check_independent_variables([iv]) + @assert iv === nothing || symtype(iv) === Real + if (checks isa Bool && checks === true || checks isa Int && (checks & CheckComponents) > 0) && iv !== nothing + check_independent_variables((iv,)) check_variables(unknowns, iv) check_parameters(ps, iv) check_equations(eqs, iv) - if noise_eqs !== nothing && size(noise_eqs, 1) != length(eqs) - throw(IllFormedNoiseEquationsError(size(noise_eqs, 1), length(eqs))) + Neq = length(eqs) + if noise_eqs isa Matrix{SymbolicT} + N1 = size(noise_eqs, 1) + elseif noise_eqs isa Vector{SymbolicT} + N1 = length(noise_eqs) + elseif noise_eqs === nothing + N1 = Neq + else + error() end + N1 == Neq || throw(IllFormedNoiseEquationsError(N1, Neq)) check_equations(equations(continuous_events), iv) check_subsystems(systems) end - if checks == true || (checks & CheckUnits) > 0 - u = __get_unit_type(unknowns, ps, iv) - if noise_eqs === nothing - check_units(u, eqs) - else - check_units(u, eqs, noise_eqs) - end - if iv !== nothing - check_units(u, jumps, iv) - end - isempty(constraints) || check_units(u, constraints) - end + # if checks == true || (checks & CheckUnits) > 0 + # u = __get_unit_type(unknowns, ps, iv) + # if noise_eqs === nothing + # check_units(u, eqs) + # else + # check_units(u, eqs, noise_eqs) + # end + # if iv !== nothing + # check_units(u, jumps, iv) + # end + # isempty(constraints) || check_units(u, constraints) + # end new(tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps, brownians, iv, observed, parameter_dependencies, var_to_name, name, description, defaults, @@ -320,6 +328,24 @@ function default_consolidate(costs, subcosts) return reduce(+, costs; init = 0.0) + reduce(+, subcosts; init = 0.0) end +unwrap_vars(vars::AbstractArray{SymbolicT}) = vars +function unwrap_vars(vars::AbstractArray) + result = similar(vars, SymbolicT) + for i in eachindex(vars) + result[i] = SU.Const{VartypeT}(vars[i]) + end + return result +end + +defsdict(x::SymmapT) = x +function defsdict(x::AbstractDict) + result = SymmapT() + for (k, v) in x + result[unwrap(k)] = SU.Const{VartypeT}(v) + end + return result +end + """ $(TYPEDSIGNATURES) @@ -336,71 +362,77 @@ for time-independent systems, unknowns `dvs`, parameters `ps` and brownian varia All other keyword arguments are named identically to the corresponding fields in [`System`](@ref). """ -function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; - constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = [], - costs = BasicSymbolic[], consolidate = default_consolidate, - observed = Equation[], parameter_dependencies = Equation[], defaults = Dict(), - guesses = Dict(), systems = System[], initialization_eqs = Equation[], +function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[]; + constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = JumpType[], + costs = SymbolicT[], consolidate = default_consolidate, + observed = Equation[], parameter_dependencies = Equation[], defaults = SymmapT(), + guesses = SymmapT(), systems = System[], initialization_eqs = Equation[], continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[], - connector_type = nothing, assertions = Dict{BasicSymbolic, String}(), + connector_type = nothing, assertions = Dict{SymbolicT, String}(), metadata = MetadataT(), gui_metadata = nothing, - is_dde = nothing, tstops = [], inputs = OrderedSet{BasicSymbolic}(), - outputs = OrderedSet{BasicSymbolic}(), tearing_state = nothing, + is_dde = nothing, tstops = [], inputs = OrderedSet{SymbolicT}(), + outputs = OrderedSet{SymbolicT}(), tearing_state = nothing, ignored_connections = nothing, parent = nothing, description = "", name = nothing, discover_from_metadata = true, initializesystem = nothing, is_initializesystem = false, is_discrete = false, preface = [], checks = true) name === nothing && throw(NoNameError()) + + if !(eqs isa Vector{Equation}) + eqs = Equation[eqs] + end + eqs = eqs::Vector{Equation} + if !isempty(parameter_dependencies) - @warn """ - The `parameter_dependencies` keyword argument is deprecated. Please provide all - such equations as part of the normal equations of the system. - """ - eqs = Equation[eqs; parameter_dependencies] + @invokelatest warn_pdeps() + append!(eqs, parameter_dependencies) end iv = unwrap(iv) - ps = unwrap.(ps) - dvs = unwrap.(dvs) - filter!(!Base.Fix2(isdelay, iv), dvs) - brownians = unwrap.(brownians) - - if !(eqs isa AbstractArray) - eqs = [eqs] + ps = vec(unwrap_vars(ps)) + dvs = vec(unwrap_vars(dvs)) + if iv !== nothing + filter!(!Base.Fix2(isdelay, iv), dvs) end + brownians = unwrap_vars(brownians) if noise_eqs !== nothing - noise_eqs = unwrap.(noise_eqs) + noise_eqs = unwrap_vars(noise_eqs) end - costs = unwrap.(costs) - if isempty(costs) - costs = Union{BasicSymbolic, Real}[] - end + costs = vec(unwrap_vars(costs)) - defaults = anydict(defaults) - guesses = anydict(guesses) - inputs = OrderedSet{BasicSymbolic}(inputs) - outputs = OrderedSet{BasicSymbolic}(outputs) + defaults = defsdict(defaults) + guesses = defsdict(guesses) + if !(inputs isa OrderedSet{SymbolicT}) + inputs = OrderedSet{SymbolicT}(inputs) + end + if !(outputs isa OrderedSet{SymbolicT}) + outputs = OrderedSet{SymbolicT}(outputs) + end for subsys in systems - for var in ModelingToolkit.inputs(subsys) + for var in get_inputs(subsys) push!(inputs, renamespace(subsys, var)) end - for var in ModelingToolkit.outputs(subsys) + for var in get_outputs(subsys) push!(outputs, renamespace(subsys, var)) end end - var_to_name = anydict() + var_to_name = Dict{Symbol, SymbolicT}() - let defaults = discover_from_metadata ? defaults : Dict(), - guesses = discover_from_metadata ? guesses : Dict(), - inputs = discover_from_metadata ? inputs : Set(), - outputs = discover_from_metadata ? outputs : Set() + let defaults = discover_from_metadata ? defaults : SymmapT(), + guesses = discover_from_metadata ? guesses : SymmapT(), + inputs = discover_from_metadata ? inputs : OrderedSet{SymbolicT}(), + outputs = discover_from_metadata ? outputs : OrderedSet{SymbolicT}() process_variables!(var_to_name, defaults, guesses, dvs) process_variables!(var_to_name, defaults, guesses, ps) - process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed]) - process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed]) + buffer = SymbolicT[] + for eq in observed + push!(buffer, eq.lhs) + push!(buffer, eq.rhs) + end + process_variables!(var_to_name, defaults, guesses, buffer) for var in dvs if isinput(var) @@ -410,15 +442,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; end end end - filter!(!(isnothing ∘ last), defaults) - filter!(!(isnothing ∘ last), guesses) - defaults = anydict([unwrap(k) => unwrap(v) for (k, v) in defaults]) - guesses = anydict([unwrap(k) => unwrap(v) for (k, v) in guesses]) + filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), defaults) + filter!(!(Base.Fix1(===, COMMON_NOTHING) ∘ last), guesses) - sysnames = nameof.(systems) - unique_sysnames = Set(sysnames) - if length(unique_sysnames) != length(sysnames) - throw(NonUniqueSubsystemsError(sysnames, unique_sysnames)) + + if !allunique(map(nameof, systems)) + nonunique_subsystems(systems) end continuous_events, discrete_events = create_symbolic_events( @@ -432,7 +461,10 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; is_dde = _check_if_dde(eqs, iv, systems) end - assertions = Dict{BasicSymbolic, String}(unwrap(k) => v for (k, v) in assertions) + _assertions = Dict{SymbolicT, String} + for (k, v) in assertions + _assertions[unwrap(k)::SymbolicT] = v + end if isempty(metadata) metadata = MetadataT() @@ -446,6 +478,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; metadata = meta end metadata = refreshed_metadata(metadata) + jumps = Vector{JumpType}(jumps) System(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, noise_eqs, jumps, constraints, costs, consolidate, dvs, ps, brownians, iv, observed, Equation[], var_to_name, name, description, defaults, guesses, systems, initialization_eqs, @@ -455,6 +488,21 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = []; initializesystem, is_initializesystem, is_discrete; checks) end +@noinline function nonunique_subsystems(systems) + sysnames = nameof.(systems) + unique_sysnames = Set(sysnames) + throw(NonUniqueSubsystemsError(sysnames, unique_sysnames)) +end + +@noinline function warn_pdeps() + @warn """ + The `parameter_dependencies` keyword argument is deprecated. Please provide all + such equations as part of the normal equations of the system. + """ +end + +SymbolicIndexingInterface.getname(x::System) = nameof(x) + """ $(TYPEDSIGNATURES) @@ -481,7 +529,7 @@ function System(eqs::Vector{Equation}, iv; kwargs...) diffeqs = Equation[] othereqs = Equation[] for eq in eqs - if !(eq.lhs isa Union{Symbolic, Number, AbstractArray}) + if !(eq.lhs isa Union{SymbolicT, Number, AbstractArray}) push!(othereqs, eq) continue end @@ -608,15 +656,13 @@ function gather_array_params(ps) for p in ps if iscall(p) && operation(p) === getindex par = arguments(p)[begin] - if Symbolics.shape(Symbolics.unwrap(par)) !== Symbolics.Unknown() && - all(par[i] in ps for i in eachindex(par)) + if symbolic_has_known_size(p) && all(par[i] in ps for i in eachindex(par)) push!(new_ps, par) else push!(new_ps, p) end else - if symbolic_type(p) == ArraySymbolic() && - Symbolics.shape(unwrap(p)) != Symbolics.Unknown() + if symbolic_type(p) == ArraySymbolic() && symbolic_has_known_size(p) for i in eachindex(p) delete!(new_ps, p[i]) end @@ -677,7 +723,7 @@ function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv) for var in auxvars if !iscall(var) - occursin(iv, var) && (var ∈ sts || + SU.query!(isequal(iv), var) && (var ∈ sts || throw(ArgumentError("Time-dependent variable $var is not an unknown of the system."))) elseif length(arguments(var)) > 1 throw(ArgumentError("Too many arguments for variable $var.")) @@ -722,19 +768,15 @@ differential equations. """ is_dde(sys::AbstractSystem) = has_is_dde(sys) && get_is_dde(sys) -function _check_if_dde(eqs, iv, subsystems) - is_dde = any(ModelingToolkit.is_dde, subsystems) - if !is_dde - vs = Set() - for eq in eqs - vars!(vs, eq) - is_dde = any(vs) do sym - isdelay(unwrap(sym), iv) - end - is_dde && break - end +_check_if_dde(eqs::Vector{Equation}, iv::Nothing, subsystems::Vector{System}) = false +function _check_if_dde(eqs::Vector{Equation}, iv::SymbolicT, subsystems::Vector{System}) + any(ModelingToolkit.is_dde, subsystems) && return true + pred = Base.Fix2(isdelay, iv) + for eq in eqs + SU.query!(pred, eq.lhs) && return true + SU.query!(pred, eq.rhs) && return true end - return is_dde + return false end """ @@ -896,7 +938,7 @@ function NonlinearSystem(sys::System) subrules[var] = 0.0 end eqs = map(eqs) do eq - fast_substitute(eq, subrules) + substitute(eq, subrules) end nsys = System(eqs, unknowns(sys), [parameters(sys); get_iv(sys)]; defaults = merge(defaults(sys), Dict(get_iv(sys) => Inf)), guesses = guesses(sys), diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 4c52300239..6aa4166d03 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -195,7 +195,7 @@ function simplify_optimization_system(sys::System; split = true, kwargs...) dvs[i] = irrvar end end - econs = fast_substitute.(econs, (irreducible_subs,)) + econs = substitute.(econs, (irreducible_subs,)) nlsys = System(econs, dvs, parameters(sys); name = :___tmp_nlsystem) snlsys = mtkcompile(nlsys; kwargs..., fully_determined = false) obs = observed(snlsys) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index a1460731cb..77f4f144c4 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -1,11 +1,11 @@ using DataStructures using Symbolics: linear_expansion, unwrap -using SymbolicUtils: iscall, operation, arguments, Symbolic +using SymbolicUtils: iscall, operation, arguments using SymbolicUtils: quick_cancel, maketerm using ..ModelingToolkit import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten, value, InvalidSystemException, isdifferential, _iszero, - isparameter, Connection, + isparameter, Connection, SymbolicT independent_variables, SparseMatrixCLIL, AbstractSystem, equations, isirreducible, input_timedomain, TimeDomain, InferredTimeDomain, @@ -304,7 +304,7 @@ end function symbolic_contains(var, set) var in set || symbolic_type(var) == ArraySymbolic() && - Symbolics.shape(var) != Symbolics.Unknown() && + symbolic_has_known_size(var) && all(x -> x in set, Symbolics.scalarize(var)) end @@ -372,10 +372,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) union!(dvs, xx) end end - ps = Set{Symbolic}() + ps = Set{SymbolicT}() for x in full_parameters(sys) push!(ps, x) - if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() + if symbolic_type(x) == ArraySymbolic() && symbolic_has_known_size(x) xx = Symbolics.scalarize(x) union!(ps, xx) end @@ -408,7 +408,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) if iscall(eq.lhs) && (op = operation(eq.lhs)) isa Differential && isequal(op.x, iv) && is_time_dependent_parameter(only(arguments(eq.lhs)), ps, iv) # parameter derivatives are opted out by specifying `D(p) ~ missing`, but - # we want to store `nothing` in the map because that means `fast_substitute` + # we want to store `nothing` in the map because that means `substitute` # will ignore the rule. We will this identify the presence of `eq′.lhs` in # the differentiated expression and error. param_derivative_map[eq.lhs] = coalesce(eq.rhs, nothing) @@ -680,7 +680,7 @@ function trivial_tearing!(ts::TearingState) end isvalid || continue # skip if the LHS is present in the RHS, since then this isn't explicit - if occursin(eq.lhs, eq.rhs) + if SU.query!(isequal(eq.lhs), eq.rhs) push!(blacklist, i) continue end @@ -751,12 +751,12 @@ function shift_discrete_system(ts::TearingState) if any(isequal(k), fullvars) && !isa(operation(k), Union{Sample, Hold, Pre})) for i in eachindex(fullvars) - fullvars[i] = StructuralTransformations.simplify_shifts(fast_substitute( - fullvars[i], discmap; operator = Union{Sample, Hold, Pre})) + fullvars[i] = StructuralTransformations.simplify_shifts(substitute( + fullvars[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}())) end for i in eachindex(eqs) - eqs[i] = StructuralTransformations.simplify_shifts(fast_substitute( - eqs[i], discmap; operator = Union{Sample, Hold, Pre})) + eqs[i] = StructuralTransformations.simplify_shifts(substitute( + eqs[i], discmap; filterer = Symbolics.FPSubFilterer{Union{Sample, Hold, Pre}}())) end @set! ts.sys.eqs = eqs @set! ts.fullvars = fullvars @@ -846,7 +846,7 @@ function Base.show(io::IO, mime::MIME"text/plain", s::SystemStructure) " variables\n") Base.print_matrix(io, SystemStructurePrintMatrix(s)) else - S = incidence_matrix(s.graph, Num(Sym{Real}(:×))) + S = incidence_matrix(s.graph, Num(SSym(:×; type = Real, shape = SU.ShapeVecT()))) print(io, "Incidence matrix:") show(io, mime, S) end diff --git a/src/systems/unit_check.jl b/src/systems/unit_check.jl index acf7451065..839b77094f 100644 --- a/src/systems/unit_check.jl +++ b/src/systems/unit_check.jl @@ -12,7 +12,7 @@ function __get_literal_unit(x) if x isa Pair x = x[1] end - if !(x isa Union{Num, Symbolic}) + if !(x isa Union{Num, SymbolicT}) return nothing end v = value(x) @@ -71,7 +71,6 @@ get_unit(x::AbstractArray) = map(get_unit, x) get_unit(x::Num) = get_unit(unwrap(x)) get_unit(x::Symbolics.Arr) = get_unit(unwrap(x)) get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) -get_unit(op::Difference, args) = get_unit(args[1]) / get_unit(op.t) get_unit(op::typeof(getindex), args) = get_unit(args[1]) get_unit(x::SciMLBase.NullParameters) = unitless get_unit(op::typeof(instream), args) = get_unit(args[1]) @@ -114,7 +113,7 @@ function get_unit(op::Conditional, args) return terms[2] end -function get_unit(op::typeof(Symbolics._mapreduce), args) +function get_unit(op::typeof(mapreduce), args) if args[2] == + get_unit(args[3]) else @@ -129,7 +128,7 @@ function get_unit(op::Comparison, args) return unitless end -function get_unit(x::Symbolic) +function get_unit(x::SymbolicT) if (u = __get_literal_unit(x)) !== nothing screen_unit(u) elseif issym(x) @@ -156,8 +155,8 @@ function get_unit(x::Symbolic) op = operation(x) if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && !iscall(operation(op)) - gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) + elseif iscall(op) && operation(op) === getindex + gp = arguments(op)[1] return screen_unit(getmetadata(gp, VariableUnit, unitless)) end # Actual function calls: args = arguments(x) @@ -249,14 +248,14 @@ function _validate(conn::Connection; info::String = "") end function validate(jump::Union{VariableRateJump, - ConstantRateJump}, t::Symbolic; + ConstantRateJump}, t::SymbolicT; info::String = "") newinfo = replace(info, "eq." => "jump") _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units validate(jump.affect!, info = newinfo) end -function validate(jump::MassActionJump, t::Symbolic; info::String = "") +function validate(jump::MassActionJump, t::SymbolicT; info::String = "") left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols net_symbols = [x[1] for x in jump.net_stoch] all_symbols = vcat(left_symbols, net_symbols) @@ -267,7 +266,7 @@ function validate(jump::MassActionJump, t::Symbolic; info::String = "") ["scaled_rates", "1/(t*reactants^$n))"]; info) end -function validate(jumps::Vector{JumpType}, t::Symbolic) +function validate(jumps::Vector{JumpType}, t::SymbolicT) labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] majs = filter(x -> x isa MassActionJump, jumps) crjs = filter(x -> x isa ConstantRateJump, jumps) @@ -284,7 +283,7 @@ function validate(eq::Union{Inequality, Equation}; info::String = "") end end function validate(eq::Equation, - term::Union{Symbolic, DQ.AbstractQuantity, Num}; info::String = "") + term::Union{SymbolicT, DQ.AbstractQuantity, Num}; info::String = "") _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) end function validate(eq::Equation, terms::Vector; info::String = "") @@ -306,10 +305,10 @@ function validate(eqs::Vector, noise::Matrix; info::String = "") all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -function validate(eqs::Vector, term::Symbolic; info::String = "") +function validate(eqs::Vector, term::SymbolicT; info::String = "") all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing +validate(term::SymbolicT) = safe_get_unit(term, "") !== nothing """ Throws error if units of equations are invalid. diff --git a/src/systems/validation.jl b/src/systems/validation.jl index d416a02ea2..ecd98b1d43 100644 --- a/src/systems/validation.jl +++ b/src/systems/validation.jl @@ -6,11 +6,11 @@ using ..ModelingToolkit: ValidationError, get_systems, Conditional, Comparison using JumpProcesses: MassActionJump, ConstantRateJump, VariableRateJump -using Symbolics: Symbolic, value, issym, isadd, ismul, ispow +using Symbolics: SymbolicT, value, issym, isadd, ismul, ispow, CallAndWrap const MT = ModelingToolkit -Base.:*(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x * y -Base.:/(x::Union{Num, Symbolic}, y::Unitful.AbstractQuantity) = x / y +Base.:*(x::Union{Num, SymbolicT}, y::Unitful.AbstractQuantity) = x * y +Base.:/(x::Union{Num, SymbolicT}, y::Unitful.AbstractQuantity) = x / y """ Throw exception on invalid unit types, otherwise return argument. @@ -49,7 +49,7 @@ get_unit(x::Real) = unitless get_unit(x::Unitful.Quantity) = screen_unit(Unitful.unit(x)) get_unit(x::AbstractArray) = map(get_unit, x) get_unit(x::Num) = get_unit(value(x)) -function get_unit(x::Union{Symbolics.ArrayOp, Symbolics.Arr, Symbolics.CallWithMetadata}) +function get_unit(x::Union{Symbolics.Arr, CallAndWrap}) get_literal_unit(x) end get_unit(op::Differential, args) = get_unit(args[1]) / get_unit(op.x) @@ -89,7 +89,7 @@ function get_unit(op::Conditional, args) return terms[2] end -function get_unit(op::typeof(Symbolics._mapreduce), args) +function get_unit(op::typeof(mapreduce), args) if args[2] == + get_unit(args[3]) else @@ -104,7 +104,7 @@ function get_unit(op::Comparison, args) return unitless end -function get_unit(x::Symbolic) +function get_unit(x::SymbolicT) if issym(x) get_literal_unit(x) elseif isadd(x) @@ -129,8 +129,8 @@ function get_unit(x::Symbolic) op = operation(x) if issym(op) || (iscall(op) && iscall(operation(op))) # Dependent variables, not function calls return screen_unit(getmetadata(x, VariableUnit, unitless)) # Like x(t) or x[i] - elseif iscall(op) && !iscall(operation(op)) - gp = getmetadata(x, Symbolics.GetindexParent, nothing) # Like x[1](t) + elseif iscall(op) && operation(op) === getindex + gp = arguments(op)[1] return screen_unit(getmetadata(gp, VariableUnit, unitless)) end # Actual function calls: args = arguments(x) @@ -214,14 +214,14 @@ function _validate(conn::Connection; info::String = "") end function validate(jump::Union{MT.VariableRateJump, - MT.ConstantRateJump}, t::Symbolic; + MT.ConstantRateJump}, t::SymbolicT; info::String = "") newinfo = replace(info, "eq." => "jump") _validate([jump.rate, 1 / t], ["rate", "1/t"], info = newinfo) && # Assuming the rate is per time units validate(jump.affect!, info = newinfo) end -function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "") +function validate(jump::MT.MassActionJump, t::SymbolicT; info::String = "") left_symbols = [x[1] for x in jump.reactant_stoch] #vector of pairs of symbol,int -> vector symbols net_symbols = [x[1] for x in jump.net_stoch] all_symbols = vcat(left_symbols, net_symbols) @@ -232,7 +232,7 @@ function validate(jump::MT.MassActionJump, t::Symbolic; info::String = "") ["scaled_rates", "1/(t*reactants^$n))"]; info) end -function validate(jumps::Vector{JumpType}, t::Symbolic) +function validate(jumps::Vector{JumpType}, t::SymbolicT) labels = ["in Mass Action Jumps,", "in Constant Rate Jumps,", "in Variable Rate Jumps,"] majs = filter(x -> x isa MassActionJump, jumps) crjs = filter(x -> x isa ConstantRateJump, jumps) @@ -249,7 +249,7 @@ function validate(eq::MT.Equation; info::String = "") end end function validate(eq::MT.Equation, - term::Union{Symbolic, Unitful.Quantity, Num}; info::String = "") + term::Union{SymbolicT, Unitful.Quantity, Num}; info::String = "") _validate([eq.lhs, eq.rhs, term], ["left", "right", "noise"]; info) end function validate(eq::MT.Equation, terms::Vector; info::String = "") @@ -271,10 +271,10 @@ function validate(eqs::Vector, noise::Matrix; info::String = "") all([validate(eqs[idx], noise[idx, :], info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -function validate(eqs::Vector, term::Symbolic; info::String = "") +function validate(eqs::Vector, term::SymbolicT; info::String = "") all([validate(eqs[idx], term, info = info * " in eq. #$idx") for idx in 1:length(eqs)]) end -validate(term::Symbolics.SymbolicUtils.Symbolic) = safe_get_unit(term, "") !== nothing +validate(term::SymbolicT) = safe_get_unit(term, "") !== nothing """ Throws error if units of equations are invalid. diff --git a/src/utils.jl b/src/utils.jl index 0da7e4860b..fc5fbda207 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -20,7 +20,7 @@ function detime_dvs(op) if !iscall(op) op elseif issym(operation(op)) - Sym{Real}(nameof(operation(op))) + SSym(nameof(operation(op)); type = Real, shape = SU.ShapeVecT()) else maketerm(typeof(op), operation(op), detime_dvs.(arguments(op)), metadata(op)) @@ -33,7 +33,7 @@ end Reverse `detime_dvs` for the given `dvs` using independent variable `iv`. """ function retime_dvs(op, dvs, iv) - issym(op) && return Sym{FnType{Tuple{symtype(iv)}, Real}}(nameof(op))(iv) + issym(op) && return SSym(nameof(op); type = FnType{Tuple{symtype(iv)}, Real}, shape = SU.ShapeVecT())(iv) iscall(op) ? maketerm(typeof(op), operation(op), retime_dvs.(arguments(op), (dvs,), (iv,)), metadata(op)) : @@ -84,7 +84,7 @@ end function readable_code(expr) expr = Base.remove_linenums!(_readable_code(expr)) rec_remove_macro_linenums!(expr) - JuliaFormatter.format_text(string(expr), JuliaFormatter.SciMLStyle()) + return string(expr) end # System validation enums @@ -117,11 +117,14 @@ const CheckUnits = 1 << 2 function check_independent_variables(ivs) for iv in ivs - isparameter(iv) || - @warn "Independent variable $iv should be defined with @independent_variables $iv." + isparameter(iv) || @invokelatest warn_indepvar(iv) end end +@noinline function warn_indepvar(iv::SymbolicT) + @warn "Independent variable $iv should be defined with @independent_variables $iv." +end + function check_parameters(ps, iv) for p in ps isequal(iv, p) && @@ -129,29 +132,23 @@ function check_parameters(ps, iv) end end -function is_delay_var(iv, var) - if Symbolics.isarraysymbolic(var) - return is_delay_var(iv, first(collect(var))) - end - args = nothing - try - args = arguments(var) - catch - return false +function is_delay_var(iv::SymbolicT, var::SymbolicT) + Moshi.Match.@match var begin + BSImpl.Term(; f, args) => begin + length(args) > 1 && return false + arg = args[1] + isequal(arg, iv) && return false + return symtype(arg) <: Real + end + _ => false end - length(args) > 1 && return false - isequal(first(args), iv) && return false - delay = iv - first(args) - delay isa Integer || - delay isa AbstractFloat || - (delay isa Num && isreal(value(delay))) end function check_variables(dvs, iv) for dv in dvs isequal(iv, dv) && throw(ArgumentError("Independent variable $iv not allowed in dependent variables.")) - (is_delay_var(iv, dv) || occursin(iv, dv)) || + (is_delay_var(iv, dv) || SU.query!(isequal(iv), dv)) || throw(ArgumentError("Variable $dv is not a function of independent variable $iv.")) end end @@ -187,20 +184,35 @@ function collect_ivs(eqs, op = Differential) return ivs end +struct IndepvarCheckPredicate + iv::SymbolicT +end + +function (icp::IndepvarCheckPredicate)(ex::SymbolicT) + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Differential end => begin + f = f::Differential + isequal(f.x, icp.iv) || throw_multiple_iv(icp.iv, f.x) + return false + end + _ => false + end +end + +@noinline function throw_multiple_iv(iv, newiv) + throw(ArgumentError("Differential w.r.t. variable ($newiv) other than the independent variable ($iv) are not allowed.")) +end + """ check_equations(eqs, iv) Assert that equations are well-formed when building ODE, i.e., only containing a single independent variable. """ -function check_equations(eqs, iv) - ivs = collect_ivs(eqs) - display = collect(ivs) - length(ivs) <= 1 || - throw(ArgumentError("Differential w.r.t. multiple variables $display are not allowed.")) - if length(ivs) == 1 - single_iv = pop!(ivs) - isequal(single_iv, iv) || - throw(ArgumentError("Differential w.r.t. variable ($single_iv) other than the independent variable ($iv) are not allowed.")) +function check_equations(eqs::Vector{Equation}, iv::SymbolicT) + icp = IndepvarCheckPredicate(iv) + for eq in eqs + SU.query!(icp, eq.lhs) + SU.query!(icp, eq.rhs) end end @@ -211,10 +223,12 @@ Assert that the subsystems have the appropriate namespacing behavior. """ function check_subsystems(systems) idxs = findall(!does_namespacing, systems) - if !isempty(idxs) - names = join(" " .* string.(nameof.(systems[idxs])), "\n") - throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)")) - end + isempty(idxs) || throw_bad_namespacing(systems, idxs) +end + +@noinline function throw_bad_namespacing(systems, idxs) + names = join(" " .* string.(nameof.(systems[idxs])), "\n") + throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)")) end """ @@ -271,57 +285,51 @@ function setdefault(v, val) val === nothing ? v : wrap(setdefaultval(unwrap(v), value(val))) end -function process_variables!(var_to_name, defs, guesses, vars) +function process_variables!(var_to_name::Dict{Symbol, SymbolicT}, defs::SymmapT, guesses::SymmapT, vars::Vector{SymbolicT}) collect_defaults!(defs, vars) collect_guesses!(guesses, vars) collect_var_to_name!(var_to_name, vars) return nothing end -function process_variables!(var_to_name, defs, vars) +function process_variables!(var_to_name::Dict{Symbol, SymbolicT}, defs::SymmapT, vars::Vector{SymbolicT}) collect_defaults!(defs, vars) collect_var_to_name!(var_to_name, vars) return nothing end -function collect_defaults!(defs, vars) +function collect_defaults!(defs::SymmapT, vars::Vector{SymbolicT}) for v in vars - symbolic_type(v) == NotSymbolic() && continue - if haskey(defs, v) || !hasdefault(unwrap(v)) || (def = getdefault(v)) === nothing + isconst(v) && continue + if haskey(defs, v) || (def = Symbolics.getdefaultval(v, nothing)) === nothing continue end - defs[v] = getdefault(v) + defs[v] = SU.Const{VartypeT}(def) end return defs end -function collect_guesses!(guesses, vars) +function collect_guesses!(guesses::SymmapT, vars::Vector{SymbolicT}) for v in vars + isconst(v) && continue symbolic_type(v) == NotSymbolic() && continue - if haskey(guesses, v) || !hasguess(unwrap(v)) || (def = getguess(v)) === nothing + if haskey(guesses, v) || (def = getguess(v)) === nothing continue end - guesses[v] = getguess(v) + guesses[v] = SU.Const{VartypeT}(def) end return guesses end -function collect_var_to_name!(vars, xs) +function collect_var_to_name!(vars::Dict{Symbol, SymbolicT}, xs::Vector{SymbolicT}) for x in xs - symbolic_type(x) == NotSymbolic() && continue - x = unwrap(x) - if hasmetadata(x, Symbolics.GetindexParent) - xarr = getmetadata(x, Symbolics.GetindexParent) - hasname(xarr) || continue - vars[Symbolics.getname(xarr)] = xarr - else - if iscall(x) && operation(x) === getindex - x = arguments(x)[1] - end - x = unwrap(x) - hasname(x) || continue - vars[Symbolics.getname(unwrap(x))] = x + x = Moshi.Match.@match x begin + BSImpl.Const(;) => continue + BSImpl.Term(; f, args) && if f === getindex end => args[1] + _ => x end + hasname(x) || continue + vars[getname(x)] = x end end @@ -329,9 +337,7 @@ end Throw error when difference/derivative operation occurs in the R.H.S. """ @noinline function throw_invalid_operator(opvar, eq, op::Type) - if op === Difference - error("The Difference operator is deprecated, use ShiftIndex instead") - elseif op === Differential + if op === Differential optext = "derivative" end msg = "The $optext variable must be isolated to the left-hand " * @@ -388,11 +394,9 @@ isdifferential(expr) = isoperator(expr, Differential) isdiffeq(eq) = isdifferential(eq.lhs) || isoperator(eq.lhs, Shift) isvariable(x::Num)::Bool = isvariable(value(x)) -function isvariable(x)::Bool - x isa Symbolic || return false - p = getparent(x, nothing) - p === nothing || (x = p) - hasmetadata(x, VariableSource) +function isvariable(x) + x isa SymbolicT || return false + hasmetadata(x, VariableSource) || iscall(x) && operation(x) === getindex && isvariable(arguments(x)[1])::Bool end """ @@ -412,7 +416,7 @@ v = ModelingToolkit.vars(D(y) ~ u) v == Set([D(y), u]) ``` """ -function vars(exprs::Symbolic; op = Differential) +function vars(exprs::SymbolicT; op = Differential) iscall(exprs) ? vars([exprs]; op = op) : Set([exprs]) end vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op) @@ -522,13 +526,10 @@ ModelingToolkit.collect_applied_operators(eq, Differential) == Set([D(y)]) The difference compared to `collect_operator_variables` is that `collect_operator_variables` returns the variable without the operator applied. """ -function collect_applied_operators(x, op) - v = vars(x, op = op) - filter(v) do x - issym(x) && return false - iscall(x) && return operation(x) isa op - false - end +function collect_applied_operators(x::SymbolicT, ::Type{op}) where {op} + v = Set{SymbolicT}() + SU.search_variables!(v, x; is_atomic = OnlyOperatorIsAtomic{op}()) + return v end """ @@ -539,12 +540,12 @@ Search through equations and parameter dependencies of `sys`, where sys is at a recursively searches through all subsystems of `sys`, increasing the depth if it is not `-1`. A depth of `-1` indicates searching for variables with `GlobalScope`. """ -function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential) +function collect_scoped_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, sys::AbstractSystem, iv::Union{SymbolicT, Nothing}; depth = 1, op = Differential) if has_eqs(sys) for eq in equations(sys) eqtype_supports_collect_vars(eq) || continue if eq isa Equation - eq.lhs isa Union{Symbolic, Number} || continue + symtype(eq.lhs) <: Number || continue end collect_vars!(unknowns, parameters, eq, iv; depth, op) end @@ -618,6 +619,24 @@ function Base.showerror(io::IO, err::OperatorIndepvarMismatchError) end end +struct OnlyOperatorIsAtomic{O} end + +function (::OnlyOperatorIsAtomic{O})(ex::SymbolicT) where {O} + Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa O end => true + _ => false + end +end + +struct OperatorIsAtomic{O} end + +function (::OperatorIsAtomic{O})(ex::SymbolicT) where {O} + SU.default_is_atomic(ex) && Moshi.Match.@match ex begin + BSImpl.Term(; f) && if f isa Operator end => f isa O + _ => true + end +end + """ $(TYPEDSIGNATURES) @@ -632,11 +651,15 @@ can be checked using `check_scope_depth`. This function should return `nothing`. """ -function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics.Operator) - if issym(expr) - return collect_var!(unknowns, parameters, expr, iv; depth) +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::SymbolicT, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + Moshi.Match.@match expr begin + BSImpl.Const(;) => return + BSImpl.Sym(;) => return collect_var!(unknowns, parameters, expr, iv; depth) + _ => nothing end - for var in vars(expr; op) + vars = Set{SymbolicT}() + SU.search_variables!(vars, expr; is_atomic = OperatorIsAtomic{op}()) + for var in vars while iscall(var) && operation(var) isa op validate_operator(operation(var), arguments(var), iv; context = expr) var = arguments(var)[1] @@ -646,6 +669,13 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics return nothing end +function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::AbstractArray{SymbolicT}, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator) + for var in expr + collect_vars!(unknowns, parameters, var, iv; depth, op) + end + return nothing +end + """ $(TYPEDSIGNATURES) @@ -687,11 +717,11 @@ function collect_var!(unknowns, parameters, var, iv; depth = 0) Encountered a wrapped value in `collect_var!`. This function should only ever \ receive unwrapped symbolic variables. This is likely a bug in the code generating \ an expression passed to `collect_vars!` or `collect_scoped_vars!`. A common cause \ - is using `substitute` or `fast_substitute` with rules where the values are \ + is using `substitute` with rules where the values are \ wrapped symbolic variables. """) end - check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing + check_scope_depth(getmetadata(var, SymScope, LocalScope())::AllScopes, depth) || return nothing var = setmetadata(var, SymScope, LocalScope()) if iscalledparameter(var) callable = getcalledparameter(var) @@ -719,7 +749,7 @@ function check_scope_depth(scope, depth) if scope isa LocalScope return depth == 0 elseif scope isa ParentScope - return depth > 0 && check_scope_depth(scope.parent, depth - 1) + return depth > 0 && check_scope_depth(scope.parent, depth - 1)::Bool elseif scope isa GlobalScope return depth == -1 end @@ -803,7 +833,7 @@ end function _with_unit(f, x, t, args...) x = f(x, args...) - if hasmetadata(x, VariableUnit) && (t isa Symbolic && hasmetadata(t, VariableUnit)) + if hasmetadata(x, VariableUnit) && (t isa SymbolicT && hasmetadata(t, VariableUnit)) xu = getmetadata(x, VariableUnit) tu = getmetadata(t, VariableUnit) x = setmetadata(x, VariableUnit, xu / tu) @@ -833,8 +863,8 @@ end Check if `T` is an appropriate symtype for a symbolic variable representing a floating point number or array of such numbers. """ -function is_floatingpoint_symtype(T::Type) - return T == Real || T == Number || T == Complex || T <: AbstractFloat || +function is_floatingpoint_symtype(T) + return T === Real || T === Number || T === Complex || T <: AbstractFloat || T <: AbstractArray && is_floatingpoint_symtype(eltype(T)) end @@ -978,7 +1008,7 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) end any(isequal(expr), vars) && return expr iscall(expr) || return expr - Symbolics.shape(expr) == Symbolics.Unknown() && return expr + symbolic_has_known_size(expr) || return expr haskey(state, expr) && return state[expr] op = operation(expr) args = arguments(expr) @@ -1069,7 +1099,7 @@ function var_in_varlist(var, varlist::AbstractSet, iv) # indexed array symbolic, unscalarized array present (iscall(var) && operation(var) === getindex && arguments(var)[1] in varlist) || # unscalarized sized array symbolic, all scalarized elements present - (symbolic_type(var) == ArraySymbolic() && is_sized_array_symbolic(var) && + (symbolic_type(var) == ArraySymbolic() && symbolic_has_known_size(var) && all(x -> x in varlist, collect(var))) || # delayed variables (isdelay(var, iv) && var_in_varlist(operation(var)(iv), varlist, iv)) @@ -1184,4 +1214,4 @@ function wrap_with_D(n, D, repeats) else wrap_with_D(D(n), D, repeats - 1) end -end \ No newline at end of file +end diff --git a/src/variables.jl b/src/variables.jl index 46c9c95bc6..543617e7de 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -179,7 +179,7 @@ struct Stream <: AbstractConnectType end # special stream connector Get the connect type of x. See also [`hasconnect`](@ref). """ getconnect(x::Num) = getconnect(unwrap(x)) -getconnect(x::Symbolic) = Symbolics.getmetadata(x, VariableConnectType, nothing) +getconnect(x::SymbolicT) = Symbolics.getmetadata(x, VariableConnectType, nothing) """ hasconnect(x) @@ -190,13 +190,13 @@ function setconnect(x, t::Type{T}) where {T <: AbstractConnectType} setmetadata(x, VariableConnectType, t) end -### Input, Output, Irreducible -isvarkind(m, x::Union{Num, Symbolics.Arr}) = isvarkind(m, value(x)) -function isvarkind(m, x) - iskind = getmetadata(x, m, nothing) - iskind !== nothing && return iskind - x = getparent(x, x) - getmetadata(x, m, false) +### Input, Output, Irreducible +isvarkind(m, x, def = false) = safe_getmetadata(m, x, def) +safe_getmetadata(m, x::Union{Num, Symbolics.Arr}, def) = safe_getmetadata(m, value(x), def) +function safe_getmetadata(m, x, default) + hasmetadata(x, m) && return getmetadata(x, m) + iscall(x) && operation(x) === getindex && return safe_getmetadata(m, arguments(x)[1], default) + return default end """ @@ -218,13 +218,13 @@ setio(x, i::Bool, o::Bool) = setoutput(setinput(x, i), o) Check if variable `x` is marked as an input. """ -isinput(x) = isvarkind(VariableInput, x) +isinput(x) = isvarkind(VariableInput, x)::Bool """ $(TYPEDSIGNATURES) Check if variable `x` is marked as an output. """ -isoutput(x) = isvarkind(VariableOutput, x) +isoutput(x) = isvarkind(VariableOutput, x)::Bool # Before the solvability check, we already have handled IO variables, so # irreducibility is independent from IO. @@ -234,7 +234,7 @@ isoutput(x) = isvarkind(VariableOutput, x) Check if `x` is marked as irreducible. This prevents it from being eliminated as an observed variable in `mtkcompile`. """ -isirreducible(x) = isvarkind(VariableIrreducible, x) +isirreducible(x) = isvarkind(VariableIrreducible, x)::Bool setirreducible(x, v::Bool) = setmetadata(x, VariableIrreducible, v) state_priority(x::Union{Num, Symbolics.Arr}) = state_priority(unwrap(x)) """ @@ -245,19 +245,30 @@ chosen as a state in `mtkcompile`. """ state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64 -normalize_to_differential(x) = x +function normalize_to_differential(@nospecialize(op)) + if op isa Shift && op.t isa SymbolicT + return Differential(op.t) ^ op.steps + else + return op + end +end -function default_toterm(x) - if iscall(x) && (op = operation(x)) isa Operator - if !(op isa Differential) - if op isa Shift && op.steps < 0 +default_toterm(x) = x +function default_toterm(x::SymbolicT) + Moshi.Match.@match x begin + BSImpl.Term(; f, args, shape, type, metadata) && if f isa Operator end => begin + if f isa Shift && f.steps < 0 return shift2term(x) + elseif f isa Differential + return Symbolics.diff2term(x) + else + newf = normalize_to_differential(f) + f === newf && return x + x = BSImpl.Term{VartypeT}(newf, args; type, shape, metadata) + return Symbolics.diff2term(x) end - x = normalize_to_differential(op)(arguments(x)...) end - Symbolics.diff2term(x) - else - x + _ => return x end end @@ -280,12 +291,12 @@ Create parameters with bounds like this @parameters p [bounds=(-1, 1)] ``` """ -function getbounds(x::Union{Num, Symbolics.Arr, SymbolicUtils.Symbolic}) +function getbounds(x::Union{Num, Symbolics.Arr, SymbolicT}) x = unwrap(x) - p = Symbolics.getparent(x, nothing) - if p === nothing + if operation(p) === getindex + p = arguments(p)[1] bounds = Symbolics.getmetadata(x, VariableBounds, (-Inf, Inf)) - if symbolic_type(x) == ArraySymbolic() && Symbolics.shape(x) != Symbolics.Unknown() + if symbolic_type(x) == ArraySymbolic() && symbolic_has_known_size(x) bounds = map(bounds) do b b isa AbstractArray && return b return fill(b, size(x)) @@ -297,7 +308,7 @@ function getbounds(x::Union{Num, Symbolics.Arr, SymbolicUtils.Symbolic}) idxs = arguments(x)[2:end] bounds = map(bounds) do b if b isa AbstractArray - if Symbolics.shape(p) != Symbolics.Unknown() && size(p) != size(b) + if symbolic_has_known_size(p) && size(p) != size(b) throw(DimensionMismatch("Expected array variable $p with shape $(size(p)) to have bounds of identical size. Found $bounds of size $(size(bounds)).")) end return b[idxs...] @@ -339,9 +350,7 @@ isdisturbance(x::Num) = isdisturbance(Symbolics.unwrap(x)) Determine whether symbolic variable `x` is marked as a disturbance input. """ function isdisturbance(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDisturbance, false) + isvarkind(VariableDisturbance, x)::Bool end setdisturbance(x, v) = setmetadata(x, VariableDisturbance, v) @@ -372,9 +381,7 @@ Create a tunable parameter by See also [`tunable_parameters`](@ref), [`getbounds`](@ref) """ function istunable(x, default = true) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableTunable, default) + isvarkind(VariableTunable, x, default)::Bool end ## Dist ======================================================================== @@ -398,9 +405,7 @@ getdist(u) # retrieve distribution ``` """ function getdist(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDistribution, nothing) + safe_getmetadata(VariableDistribution, x, nothing) end """ @@ -492,9 +497,7 @@ getdescription(x::Symbolics.Arr) = getdescription(Symbolics.unwrap(x)) Return any description attached to variables `x`. If no description is attached, an empty string is returned. """ function getdescription(x) - p = Symbolics.getparent(x, nothing) - p === nothing || (x = p) - Symbolics.getmetadata(x, VariableDescription, "") + safe_getmetadata(VariableDescription, x, "") end """ @@ -512,7 +515,7 @@ end Maps the brownianiable to an unknown. """ -tobrownian(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, BROWNIAN) +tobrownian(s::SymbolicT) = setmetadata(s, MTKVariableTypeCtx, BROWNIAN) tobrownian(s::Num) = Num(tobrownian(value(s))) isbrownian(s) = getvariabletype(s) === BROWNIAN @@ -526,10 +529,10 @@ macro brownians(xs...) x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$ || Meta.isexpr(x, :$), xs) || error("@brownians only takes scalar expressions!") - Symbolics._parse_vars(:brownian, + Symbolics.parse_vars(:brownian, Real, xs, - tobrownian) |> esc + tobrownian) end ## Guess ====================================================================== @@ -587,7 +590,7 @@ Fetch any miscellaneous data associated with symbolic variable `x`. See also [`hasmisc(x)`](@ref). """ getmisc(x::Num) = getmisc(unwrap(x)) -getmisc(x::Symbolic) = Symbolics.getmetadata(x, VariableMisc, nothing) +getmisc(x::SymbolicT) = Symbolics.getmetadata(x, VariableMisc, nothing) """ hasmisc(x) @@ -606,7 +609,7 @@ setmisc(x, miscdata) = setmetadata(x, VariableMisc, miscdata) Fetch the unit associated with variable `x`. This function is a metadata getter for an individual variable, while `get_unit` is used for unit inference on more complicated sdymbolic expressions. """ getunit(x::Num) = getunit(unwrap(x)) -getunit(x::Symbolic) = Symbolics.getmetadata(x, VariableUnit, nothing) +getunit(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnit, nothing) """ hasunit(x) @@ -615,10 +618,10 @@ Check if the variable `x` has a unit. hasunit(x) = getunit(x) !== nothing getunshifted(x::Num) = getunshifted(unwrap(x)) -getunshifted(x::Symbolic) = Symbolics.getmetadata(x, VariableUnshifted, nothing) +getunshifted(x::SymbolicT) = Symbolics.getmetadata(x, VariableUnshifted, nothing)::Union{SymbolicT, Nothing} getshift(x::Num) = getshift(unwrap(x)) -getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) +getshift(x::SymbolicT) = Symbolics.getmetadata(x, VariableShift, 0)::Int ################### ### Evaluate at ### @@ -629,7 +632,7 @@ getshift(x::Symbolic) = Symbolics.getmetadata(x, VariableShift, 0) An operator that evaluates time-dependent variables at a specific absolute time point `t`. # Fields -- `t::Union{Symbolic, Number}`: The absolute time at which to evaluate the variable. +- `t::Union{SymbolicT, Number}`: The absolute time at which to evaluate the variable. # Description `EvalAt` is used to evaluate time-dependent variables at a specific time point. This is particularly @@ -677,12 +680,12 @@ end See also: [`Differential`](@ref) """ struct EvalAt <: Symbolics.Operator - t::Union{Symbolic, Number} + t::Union{SymbolicT, Number} end -function (A::EvalAt)(x::Symbolic) +function (A::EvalAt)(x::SymbolicT) if symbolic_type(x) == NotSymbolic() || !iscall(x) - if x isa Symbolics.CallWithMetadata + if x isa CallAndWrap return x(A.t) else return x diff --git a/test/model_parsing.jl b/test/model_parsing.jl index 2c713d4149..eefd155db6 100644 --- a/test/model_parsing.jl +++ b/test/model_parsing.jl @@ -201,8 +201,7 @@ resistor = getproperty(rc, :resistor; namespace = false) @named pi_model = PiModel() - @test typeof(ModelingToolkit.getdefault(pi_model.p)) <: - SymbolicUtils.BasicSymbolic{Irrational} + @test symtype(ModelingToolkit.getdefault(pi_model.p)) <: Irrational @test getdefault(getdefault(pi_model.p)) == π end @@ -1007,7 +1006,7 @@ end vars = Symbolics.get_variables(only(equations(ex))) @test length(vars) == 2 for u in Symbolics.unwrap.(unknowns(ex)) - @test !Symbolics.hasmetadata(u, Symbolics.CallWithParent) + @test !SymbolicUtils.is_function_symbolic(u) @test any(isequal(u), vars) end end diff --git a/test/odesystem.jl b/test/odesystem.jl index cf16b31a42..417eaa70a7 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -535,7 +535,7 @@ sys = complete(sys) us = map(s -> (@variables $s(t))[1], syms) ps = map(s -> (@variables $s(t))[1], syms_p) buffer, = @variables $buffername[1:length(u0)] - dummy_var = Sym{Any}(:_) # this is safe because _ cannot be a rvalue in Julia + dummy_var = Symbolics.SSym(:_; type = Any) # this is safe because _ cannot be a rvalue in Julia ss = Iterators.flatten((us, ps)) vv = Iterators.flatten((u0, p0)) diff --git a/test/optimizationsystem.jl b/test/optimizationsystem.jl index 30e80e1cec..c29b68b9be 100644 --- a/test/optimizationsystem.jl +++ b/test/optimizationsystem.jl @@ -399,7 +399,7 @@ end prob = OptimizationProblem(sys, [x => [42.0, 12.37]]; hess = true, sparse = true) symbolic_hess = Symbolics.hessian(cost(sys), x) - symbolic_hess_value = Symbolics.fast_substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]])) + symbolic_hess_value = substitute(symbolic_hess, Dict(x[1] => prob[x[1]], x[2] => prob[x[2]])) oop_hess = prob.f.hess(prob.u0, prob.p) @test oop_hess ≈ symbolic_hess_value diff --git a/test/sciml_problem_inputs.jl b/test/sciml_problem_inputs.jl index a91a8d8c7c..cdb83ad0e2 100644 --- a/test/sciml_problem_inputs.jl +++ b/test/sciml_problem_inputs.jl @@ -2,7 +2,7 @@ # Fetch packages using ModelingToolkit, JumpProcesses, NonlinearSolve, OrdinaryDiffEq, StaticArrays, - SteadyStateDiffEq, StochasticDiffEq, SciMLBase, Test + SteadyStateDiffEq, StochasticDiffEq, SciMLBase, Test, SymbolicUtils using ModelingToolkit: t_nounits as t, D_nounits as D # Sets rnd number. @@ -29,7 +29,7 @@ begin ] noise_eqs = fill(0.01, 3, 6) jumps = [ - MassActionJump(kp, Pair{Symbolics.BasicSymbolic{Real}, Int64}[], [X => 1]), + MassActionJump(kp, Pair{Symbolics.SymbolicT, Int64}[], [X => 1]), MassActionJump(kd, [X => 1], [X => -1]), MassActionJump(k1, [X => 1], [X => -1, Y => 1]), MassActionJump(k2, [Y => 1], [X => 1, Y => -1]), diff --git a/test/simplify.jl b/test/simplify.jl index 4252e3262e..6968883132 100644 --- a/test/simplify.jl +++ b/test/simplify.jl @@ -1,5 +1,7 @@ using ModelingToolkit using ModelingToolkit: value +using Symbolics: STerm +import SymbolicUtils using Test @independent_variables t @@ -11,7 +13,7 @@ null_op = 0 * t one_op = 1 * t @test isequal(simplify(one_op), t) -identity_op = Num(Term(identity, [value(x)])) +identity_op = Num(STerm(identity, [value(x)]; type = Real, shape = SymbolicUtils.ShapeVecT())) @test isequal(simplify(identity_op), x) minus_op = -x diff --git a/test/variable_parsing.jl b/test/variable_parsing.jl index 60b4e24d64..a00aa2f729 100644 --- a/test/variable_parsing.jl +++ b/test/variable_parsing.jl @@ -2,15 +2,16 @@ using ModelingToolkit using Test using ModelingToolkit: value, Flow -using SymbolicUtils: FnType +using Symbolics: SSym +using SymbolicUtils: FnType, ShapeVecT @independent_variables t @variables x(t) y(t) # test multi-arg @variables z(t) # test single-arg -x1 = Num(Sym{FnType{Tuple{Any}, Real}}(:x)(value(t))) -y1 = Num(Sym{FnType{Tuple{Any}, Real}}(:y)(value(t))) -z1 = Num(Sym{FnType{Tuple{Any}, Real}}(:z)(value(t))) +x1 = Num(SSym(:x; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) +y1 = Num(SSym(:y; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) +z1 = Num(SSym(:z; type = FnType{Tuple{Any}, Real}, shape = ShapeVecT())(value(t))) @test isequal(x1, x) @test isequal(y1, y) @@ -22,9 +23,9 @@ z1 = Num(Sym{FnType{Tuple{Any}, Real}}(:z)(value(t))) end @parameters σ(..) -t1 = Num(Sym{Real}(:t)) -s1 = Num(Sym{Real}(:s)) -σ1 = Num(Sym{FnType{Tuple, Real}}(:σ)) +t1 = Num(SSym(:t; type = Real, shape = ShapeVecT())) +s1 = Num(SSym(:s; type = Real, shape = ShapeVecT())) +σ1 = Num(SSym(:σ; type = FnType{Tuple, Real}, shape = ShapeVecT())) @test isequal(t1, t) @test isequal(s1, s) @test isequal(σ1(t), σ(t))