diff --git a/src/matchers.jl b/src/matchers.jl index 6b36ca8c..263b352e 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -169,26 +169,66 @@ function term_matcher_constructor(term, acSets) return pow_term_matcher # if we want to do commutative checks, i.e. call matcher with different order of the arguments elseif acSets!==nothing && operation(term) in [+, *] + cond = false + if operation(term)==* + for a in arguments(term) + if iscall(a) && operation(a)===^ + cond = true + end + end + end + function commutative_term_matcher(success, data, bindings) !islist(data) && return nothing # if data is not a list, return nothing - !iscall(car(data)) && return nothing # if first element is not a call, return nothing - operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try + data = car(data) + !iscall(data) && return nothing # if first element is not a call, return nothing + if cond && (operation(data)===/) + nnmrt = arguments(data)[1] + den = arguments(data)[2] + if iscall(den) && operation(den)===^ + new_den = Term{symtype(nnmrt)}(^,[arguments(den)[1], -1*arguments(den)[2]]) + elseif iscall(den) && operation(den)===sqrt + new_den = Term{symtype(nnmrt)}(^,[arguments(den)[1], -1//2]) + elseif iscall(den) && operation(den)===* + new_den=[] + for a in arguments(den) + if iscall(a) && operation(a)===^ + push!(new_den, Term{symtype(nnmrt)}(^,[arguments(a)[1], -1*arguments(a)[2]])) + elseif iscall(a) && operation(a)===sqrt + push!(new_den, Term{symtype(nnmrt)}(^,[arguments(a)[1], -1//2])) + else + push!(new_den, Term{symtype(nnmrt)}(^,[a, -1])) + end + end + new_den = *(new_den...) + else + new_den = Term{symtype(nnmrt)}(^,[den, -1]) + end + if length(nnmrt) == 1 + data = *(new_den, nnmrt) + else + data = *(new_den, nnmrt...) + end + println("frankestein data: $data") + elseif operation(term) !== operation(data) + return nothing # if the operation of data is not the correct one, don't even try + end - T = symtype(car(data)) + T = symtype(data) if T <: Number - f = operation(car(data)) - data_args = arguments(car(data)) + f = operation(data) + data_args = arguments(data) for inds in acSets(eachindex(data_args), length(data_args)) candidate = Term{T}(f, @views data_args[inds]) - result = loop(candidate, bindings, matchers) + result = loop(candidate, bindings, matchers) result !== nothing && return success(result,1) end - # if car(data) does not subtype to number, it might not be commutative + # if data does not subtype to number, it might not be commutative else # call the normal matcher - result = loop(car(data), bindings, matchers) + result = loop(data, bindings, matchers) result !== nothing && return success(result, 1) end return nothing diff --git a/test/rewrite.jl b/test/rewrite.jl index 4b637eef..95662fd8 100644 --- a/test/rewrite.jl +++ b/test/rewrite.jl @@ -144,6 +144,20 @@ end r = @rule (~x + ~y)^(~m) => (~x, ~y, ~m) # rule to match (1/...)^(...) @test r((1/(a+b))^3) === (a,b,-3) + + + # neim problem + r_one = @rule a*b^(~n) => ~ + @test r_one(a/b)[:n] === -1 + @test r_one(a/b^3)[:n] === -3 + @test r_one(a/sqrt(b))[:n] === -1//2 + + r_two = @rule b^(~n)*c^(~m) => ~ + @test r_two(b^2/c)[:m] === -1 + @test r_two(b^2/c)[:n] === 2 + @test r_two(1/(b*sqrt(c)))[:n] === -1 + @test r_two(1/(b*sqrt(c)))[:m] === -1//2 + end @testset "Return the matches dictionary" begin