Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,66 @@ function term_matcher_constructor(term, acSets)
return pow_term_matcher
# if we want to do commutative checks, i.e. call matcher with different order of the arguments
elseif acSets!==nothing && operation(term) in [+, *]
cond = false
if operation(term)==*
for a in arguments(term)
if iscall(a) && operation(a)===^
cond = true
end
end
end

function commutative_term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try
data = car(data)
!iscall(data) && return nothing # if first element is not a call, return nothing
if cond && (operation(data)===/)
nnmrt = arguments(data)[1]
den = arguments(data)[2]
if iscall(den) && operation(den)===^
new_den = Term{symtype(nnmrt)}(^,[arguments(den)[1], -1*arguments(den)[2]])
elseif iscall(den) && operation(den)===sqrt
new_den = Term{symtype(nnmrt)}(^,[arguments(den)[1], -1//2])
elseif iscall(den) && operation(den)===*
new_den=[]
for a in arguments(den)
if iscall(a) && operation(a)===^
push!(new_den, Term{symtype(nnmrt)}(^,[arguments(a)[1], -1*arguments(a)[2]]))
elseif iscall(a) && operation(a)===sqrt
push!(new_den, Term{symtype(nnmrt)}(^,[arguments(a)[1], -1//2]))
else
push!(new_den, Term{symtype(nnmrt)}(^,[a, -1]))
end
end
new_den = *(new_den...)
else
new_den = Term{symtype(nnmrt)}(^,[den, -1])
end
if length(nnmrt) == 1
data = *(new_den, nnmrt)
else
data = *(new_den, nnmrt...)
end
println("frankestein data: $data")
elseif operation(term) !== operation(data)
return nothing # if the operation of data is not the correct one, don't even try
end

T = symtype(car(data))
T = symtype(data)
if T <: Number
f = operation(car(data))
data_args = arguments(car(data))
f = operation(data)
data_args = arguments(data)

for inds in acSets(eachindex(data_args), length(data_args))
candidate = Term{T}(f, @views data_args[inds])

result = loop(candidate, bindings, matchers)
result = loop(candidate, bindings, matchers)
result !== nothing && return success(result,1)
end
# if car(data) does not subtype to number, it might not be commutative
# if data does not subtype to number, it might not be commutative
else
# call the normal matcher
result = loop(car(data), bindings, matchers)
result = loop(data, bindings, matchers)
result !== nothing && return success(result, 1)
end
return nothing
Expand Down
14 changes: 14 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,20 @@ end

r = @rule (~x + ~y)^(~m) => (~x, ~y, ~m) # rule to match (1/...)^(...)
@test r((1/(a+b))^3) === (a,b,-3)


# neim problem
r_one = @rule a*b^(~n) => ~
@test r_one(a/b)[:n] === -1
@test r_one(a/b^3)[:n] === -3
@test r_one(a/sqrt(b))[:n] === -1//2

r_two = @rule b^(~n)*c^(~m) => ~
@test r_two(b^2/c)[:m] === -1
@test r_two(b^2/c)[:n] === 2
@test r_two(1/(b*sqrt(c)))[:n] === -1
@test r_two(1/(b*sqrt(c)))[:m] === -1//2

end

@testset "Return the matches dictionary" begin
Expand Down
Loading