Skip to content

Commit ebd90c5

Browse files
Pi and e to Float32 and Float16 (#559)
* pi and e to Float32 and Float16 * Mimic Julia behaviour * Fix * OOPS * Automated definitions
1 parent de0e3f9 commit ebd90c5

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

src/device/intrinsics/math.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,19 @@
22

33
using Base: FastMath
44
using Base.Math: throw_complex_domainerror
5+
import Core: Float16, Float32
56

67
# TODO:
78
# - wrap all intrinsics from include/metal/metal_math
89
# - add support for vector types
910
# - consider emitting LLVM intrinsics and lowering those in the back-end
1011

12+
### Constants
13+
# π and ℯ
14+
for T in (:Float16,:Float32), R in (RoundUp, RoundDown), irr in (π, ℯ)
15+
@eval @device_override $T(::typeof($irr), ::typeof($R)) = $@eval($T($irr,$R))
16+
end
17+
1118
### Common Intrinsics
1219
@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)
1320
@device_override Base.clamp(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)

test/device/intrinsics/math.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,21 @@ end
311311
ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal metal = v"3.0" nextafter_out_test()))
312312
@test occursin(Regex("@air\\.sign\\.f$(8*sizeof(T))"), ir)
313313
end
314+
315+
# Borrowed from the Julia "Irrationals compared with Rationals and Floats" testset
316+
@testset "Comparisons with $irr" for irr in (π, ℯ)
317+
@eval function convert_test(res)
318+
res[1] = $T($irr, RoundDown) < $irr
319+
res[2] = $T($irr, RoundUp) > $irr
320+
res[3] = !($T($irr, RoundDown) > $irr)
321+
res[4] = !($T($irr, RoundUp) < $irr)
322+
return nothing
323+
end
324+
325+
res = MtlArray(zeros(Bool, 4))
326+
Metal.@sync @metal convert_test(res)
327+
@test all(Array(res))
328+
end
314329
end
315330
end
316331

0 commit comments

Comments
 (0)