Skip to content

Commit 1b74420

Browse files
committed
Fix
1 parent 77c3e43 commit 1b74420

File tree

2 files changed

+25
-28
lines changed

2 files changed

+25
-28
lines changed

src/device/intrinsics/math.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,24 @@ using Base.Math: throw_complex_domainerror
1010

1111
### Constants
1212
# π
13-
@device_override Core.Float32(::typeof(π), ::RoundingMode) = reinterpret(Float32, 0x40490fdb) # 3.1415927f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000)))
14-
@device_override Core.Float32(::typeof(π), ::RoundingMode{:Down}) = reinterpret(Float32, 0x40490fda) # 3.1415925f0 prevfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x400921FB60000000))))
15-
@device_override Core.Float16(::typeof(π), ::RoundingMode{:Up}) = reinterpret(Float16, 0x4249) # Float16(3.143)
16-
@device_override Core.Float16(::typeof(π), ::RoundingMode) = reinterpret(Float16, 0x4248) # Float16(3.14)
13+
const M_PI_F = Float32(reinterpret(Float64, 0x400921FB60000000))
14+
const M_PI_H = reinterpret(Float16, 0x4248)
15+
@eval begin
16+
@device_override Core.Float32(::typeof(π), ::RoundingMode) = $M_PI_F
17+
@device_override Core.Float32(::typeof(π), ::RoundingMode{:Down}) = $(prevfloat(M_PI_F))
18+
@device_override Core.Float16(::typeof(π), ::RoundingMode{:Up}) = $(nextfloat(M_PI_H))
19+
@device_override Core.Float16(::typeof(π), ::RoundingMode) = $M_PI_H
20+
end
1721

1822
#
19-
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode{:Up}) = reinterpret(Float32, 0x402df855) # 2.718282f0 nextfloat(reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000))))
20-
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode) = reinterpret(Float32, 0x402df854) # 2.7182817f0 reinterpret(UInt32,Float32(reinterpret(Float64,0x4005BF0A80000000)))
21-
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode) = reinterpret(Float16, 0x4170) # Float16(2.719)
22-
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode{:Down}) = reinterpret(Float16, 0x416f) # Float16(2.717)
23+
const M_E_F = Float32(reinterpret(Float64, 0x4005BF0A80000000))
24+
const M_E_H = reinterpret(Float16, 0x4170)
25+
@eval begin
26+
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode{:Up}) = $(nextfloat(M_E_F))
27+
@device_override Core.Float32(::typeof(ℯ), ::RoundingMode) = $M_E_F
28+
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode) = $M_E_H
29+
@device_override Core.Float16(::typeof(ℯ), ::RoundingMode{:Down}) = $(prevfloat(M_E_H))
30+
end
2331

2432
### Common Intrinsics
2533
@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)

test/device/intrinsics/math.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -314,29 +314,18 @@ end
314314

315315
# Borrowed from the Julia "Irrationals compared with Rationals and Floats" testset
316316
@testset "Comparisons with $irr" for irr in (π, ℯ)
317-
function convert_test_32(res)
318-
res[1] = Float32(irr,RoundDown) < irr
319-
res[2] = Float32(irr,RoundUp) > irr
320-
res[3] = !(Float32(irr,RoundDown) > irr)
321-
res[4] = !(Float32(irr,RoundUp) < irr)
317+
@eval function convert_test(res)
318+
res[1] = $T($irr, RoundDown) < $irr
319+
res[2] = $T($irr, RoundUp) > $irr
320+
res[3] = !($T($irr, RoundDown) > $irr)
321+
res[4] = !($T($irr, RoundUp) < $irr)
322322
return nothing
323323
end
324324

325-
res_32 = MtlArray(zeros(Bool,4))
326-
Metal.@sync @metal convert_test_32(res_32)
327-
@test all(Array(res_32))
328-
329-
function convert_test_16(res)
330-
res[1] = Float16(irr,RoundDown) < irr
331-
res[2] = Float16(irr,RoundUp) > irr
332-
res[3] = !(Float16(irr,RoundDown) > irr)
333-
res[4] = !(Float16(irr,RoundUp) < irr)
334-
return nothing
335-
end
336-
337-
res_16 = MtlArray(zeros(Bool,4))
338-
Metal.@sync @metal convert_test_16(res_16)
339-
@test all(Array(res_16))
325+
res = MtlArray(zeros(Bool, 4))
326+
@device_code_llvm @metal launch = false convert_test(res)
327+
Metal.@sync @metal convert_test(res)
328+
@test all(Array(res))
340329
end
341330
end
342331
end

0 commit comments

Comments
 (0)