diff --git a/src/autodetection.c b/src/autodetection.c
index f96bc06..e4e121f 100644
--- a/src/autodetection.c
+++ b/src/autodetection.c
@@ -96,12 +96,6 @@ const char * autodetect_symbol_suffix(void * handle, const char * suffix_hint) {
  * incorrect `N` to cause it to change its return value based on how it is interpreting arugments.
  */
 int32_t autodetect_blas_interface(void * isamax_addr) {
-    if (env_lowercase_match("LBT_FORCE_INTERFACE", "ilp64")) {
-        return LBT_INTERFACE_ILP64;
-    }
-    if (env_lowercase_match("LBT_FORCE_INTERFACE", "lp64")) {
-        return LBT_INTERFACE_LP64;
-    }
     // Typecast to function pointer for easier usage below
     int64_t (*isamax)(int64_t *, float *, int64_t *) = isamax_addr;
 
@@ -145,19 +139,42 @@ int32_t autodetect_blas_interface(void * isamax_addr) {
     return LBT_INTERFACE_UNKNOWN;
 }
 
+
 /*
- * If this is an LAPACK library, we'll check interface type by invoking `dpotrf` with a
- * purposefully incorrect `lda` to cause it to store an error code that we can inspect
- * and determine if the internal pointer dereferences were 32-bit or 64-bit.
+ * Attempt to figure out the integer size based on the ilaver() LAPACK function.
  */
-int32_t autodetect_lapack_interface(void * dpotrf_addr) {
-    if (env_lowercase_match("LBT_FORCE_INTERFACE", "ilp64")) {
+int32_t autodetect_lapack_interface_ilaver(void * ilaver_addr) {
+    // Use the ilaver function to find the integer type since it will not
+    // print an error message or terminate the program when we are testing it.
+    void (*ilaver)(int64_t *, int64_t *, int64_t *) = ilaver_addr;
+
+    // Force all 64 bits to be set to 1
+    int64_t major = -1;
+    int64_t minor = -1;
+    int64_t patch = -1;
+
+    ilaver(&major, &minor, &patch);
+
+    if (major > 0) {
+        // The version number should be positive, so this means that the entire number was overwritten
+        // with the version number, so it stored 64-bits in a 64-bit slot.
         return LBT_INTERFACE_ILP64;
     }
-    if (env_lowercase_match("LBT_FORCE_INTERFACE", "lp64")) {
+    if (major < 0) {
+        // The version number should be positive, so this means that the upper bits of
+        // the integer weren't written to, leaving it a negative number. So it stored
+        // 32-bits in a 64-bit slot.
         return LBT_INTERFACE_LP64;
     }
-    // Typecast to function pointer for easier usage below
+
+    // We have no idea what happened
+    return LBT_INTERFACE_UNKNOWN;
+}
+
+/*
+ * Attempt to figure out the integer size based on the dpotrf() LAPACK function.
+ */
+int32_t autodetect_lapack_interface_dpotrf(void * dpotrf_addr) {
     void (*dpotrf)(char *, int64_t *, double *, int64_t *, int64_t *) = dpotrf_addr;
 
     // This `dpotrf` invocation should result in an error code stored into `info`
@@ -177,7 +194,8 @@ int32_t autodetect_lapack_interface(void * dpotrf_addr) {
         // This is what it looks like when a library stores a 32-bit value in a 64-bit slot.
         return LBT_INTERFACE_LP64;
     }
-    // We have no idea what happened; `info` isn't any of the options we thought it would be.
+
+    // We have no idea what happened
     return LBT_INTERFACE_UNKNOWN;
 }
 
@@ -186,8 +204,38 @@ int32_t autodetect_lapack_interface(void * dpotrf_addr) {
  * Returns the values "32", "64" or "0", denoting the bitwidth of the internal index representation.
  */
 int32_t autodetect_interface(void * handle, const char * suffix) {
+    if (env_lowercase_match("LBT_FORCE_INTERFACE", "ilp64")) {
+        return LBT_INTERFACE_ILP64;
+    }
+    if (env_lowercase_match("LBT_FORCE_INTERFACE", "lp64")) {
+        return LBT_INTERFACE_LP64;
+    }
+
     char symbol_name[MAX_SYMBOL_LEN];
 
+    /*
+     * The detection logic works as follows:
+     *   1) Invoke the ilaver function with a pointer to a 64-bit integer to see if the internal pointer
+     *      dereferences were 32-bit or 64-bit.
+     *      Requires LAPACK symbol ilaver()
+     *   2) Try giving a potentially-bad input (negative length) to the BLAS isamax() function and see what
+     *      it does.
+     *      Requires BLAS symbol isamax()
+     *   3) Invoke `dpotrf` with a purposefully incorrect `lda` to cause it to
+     *      store an error code that we can inspect and determine if the internal pointer
+     *      dereferences were 32-bit or 64-bit.
+     *      Requires LAPACK symbol dpotrf()
+     */
+
+    // Attempt LAPACK `ilaver()` test
+    // We test this first because this will not rely on possibly bad behavior in the library
+    // (e.g., it won't make an error condition that we then test).
+    build_symbol_name(symbol_name, "ilaver_", suffix);
+    void * ilaver = lookup_symbol(handle, symbol_name);
+    if (ilaver != NULL) {
+        return autodetect_lapack_interface_ilaver(ilaver);
+    }
+
     // Attempt BLAS `isamax()` test
     build_symbol_name(symbol_name, "isamax_", suffix);
     void * isamax = lookup_symbol(handle, symbol_name);
@@ -196,10 +244,13 @@ int32_t autodetect_interface(void * handle, const char * suffix) {
     }
 
     // Attempt LAPACK `dpotrf()` test
+    // This is last because some LAPACK libraries have a STOP call in their
+    // error handler, and also will print an error. So we want to avoid triggering
+    // those.
     build_symbol_name(symbol_name, "dpotrf_", suffix);
     void * dpotrf = lookup_symbol(handle, symbol_name);
-    if (dpotrf != NULL) {
-        return autodetect_lapack_interface(dpotrf);
+    if (dpotrf != NULL) {
+        return autodetect_lapack_interface_dpotrf(dpotrf);
     }
 
     // Otherwise, this is probably not an LAPACK or BLAS library?!
@@ -244,7 +295,7 @@ int32_t autodetect_complex_return_style(void * handle, const char * suffix) {
      * First, check to see if `zdotc` zeros out the first argument if all arguments are zero.
      * Supposedly, most well-behaved implementations will return `0 + 0*I` if the length of
      * the inputs is zero; so if it is using a "return argument", that's a good way to find out.
-     * 
+     *
      * We detect this by setting `retval` to an initial value of `-1` typecast to a complex
      * value. The floating-point values are unimportant as they will be written to, but if
      * it is interpreted as an `int{32,64}_t`, it will be a negative value (which is not
diff --git a/src/libblastrampoline_internal.h b/src/libblastrampoline_internal.h
index e6ca762..b5deef0 100644
--- a/src/libblastrampoline_internal.h
+++ b/src/libblastrampoline_internal.h
@@ -85,7 +85,8 @@ uint8_t env_match_bool(const char * env_name, uint8_t default_value);
 void build_symbol_name(char * out, const char *symbol_name, const char *suffix);
 const char * autodetect_symbol_suffix(void * handle, const char * suffix_hint);
 int32_t autodetect_blas_interface(void * isamax_addr);
-int32_t autodetect_lapack_interface(void * dpotrf_addr);
+int32_t autodetect_lapack_interface_dpotrf(void * dpotrf_addr);
+int32_t autodetect_lapack_interface_ilaver(void * ilaver_addr);
 int32_t autodetect_interface(void * handle, const char * suffix);
 int32_t autodetect_complex_return_style(void * handle, const char * suffix);
 int32_t autodetect_f2c(void * handle, const char * suffix);