From 22974ecb36688a21b3aa3ebbe038e7be87d963d1 Mon Sep 17 00:00:00 2001 From: wuwbobo2021 Date: Tue, 15 Jul 2025 10:29:11 +0800 Subject: [PATCH 1/5] Use macros for `Env` methods --- java-spaghetti/src/env.rs | 479 ++++++++++---------------------------- 1 file changed, 119 insertions(+), 360 deletions(-) diff --git a/java-spaghetti/src/env.rs b/java-spaghetti/src/env.rs index bb945dd..6b79614 100644 --- a/java-spaghetti/src/env.rs +++ b/java-spaghetti/src/env.rs @@ -159,6 +159,11 @@ impl<'env> Env<'env> { } } + pub fn throw(self, throwable: &Ref) { + let res = unsafe { ((**self.env).v1_2.Throw)(self.env, throwable.as_raw()) }; + assert_eq!(res, 0); + } + unsafe fn exception_to_string(self, exception: jobject) -> String { static METHOD_GET_MESSAGE: OnceLock = OnceLock::new(); let throwable_get_message = *METHOD_GET_MESSAGE.get_or_init(|| { @@ -322,42 +327,76 @@ impl<'env> Env<'env> { } res } +} - // Multi-Query Methods - // XXX: Remove these unused functions. +macro_rules! call_primitive_method_a { + ($name:ident, $ret_type:ident, $call:ident) => { + pub unsafe fn $name( + self, + this: jobject, + method: jmethodID, + args: *const jvalue, + ) -> Result<$ret_type, Local<'env, E>> { + let result = ((**self.env).v1_2.$call)(self.env, this, method, args); + self.exception_check()?; + Ok(result) + } + }; +} - pub unsafe fn require_class_method(self, class: &CStr, method: &CStr, descriptor: &CStr) -> (jclass, jmethodID) { - let class = self.require_class(class); - (class, self.require_method(class, method, descriptor)) - } +macro_rules! call_static_primitive_method_a { + ($name:ident, $ret_type:ident, $call:ident) => { + pub unsafe fn $name( + self, + class: jclass, + method: jmethodID, + args: *const jvalue, + ) -> Result<$ret_type, Local<'env, E>> { + let result = ((**self.env).v1_2.$call)(self.env, class, method, args); + self.exception_check()?; + Ok(result) + } + }; +} - pub unsafe fn require_class_static_method( - self, - class: &CStr, - method: &CStr, - descriptor: &CStr, - ) -> (jclass, jmethodID) { - let class = self.require_class(class); - (class, self.require_static_method(class, method, descriptor)) - } +macro_rules! get_primitive_field { + ($name:ident, $ret_type:ident, $call:ident) => { + pub unsafe fn $name(self, this: jobject, field: jfieldID) -> $ret_type { + ((**self.env).v1_2.$call)(self.env, this, field) + } + }; +} - pub unsafe fn require_class_field(self, class: &CStr, method: &CStr, descriptor: &CStr) -> (jclass, jfieldID) { - let class = self.require_class(class); - (class, self.require_field(class, method, descriptor)) - } +macro_rules! set_primitive_field { + ($name:ident, $arg_type:ident, $call:ident) => { + pub unsafe fn $name(self, this: jobject, field: jfieldID, value: $arg_type) { + ((**self.env).v1_2.$call)(self.env, this, field, value); + } + }; +} - pub unsafe fn require_class_static_field( - self, - class: &CStr, - method: &CStr, - descriptor: &CStr, - ) -> (jclass, jfieldID) { - let class = self.require_class(class); - (class, self.require_static_field(class, method, descriptor)) - } +macro_rules! get_static_primitive_field { + ($name:ident, $ret_type:ident, $call:ident) => { + pub unsafe fn $name(self, class: jclass, field: jfieldID) -> $ret_type { + ((**self.env).v1_2.$call)(self.env, class, field) + } + }; +} + +macro_rules! set_static_primitive_field { + ($name:ident, $arg_type:ident, $call:ident) => { + pub unsafe fn $name(self, class: jclass, field: jfieldID, value: $arg_type) { + ((**self.env).v1_2.$call)(self.env, class, field, value); + } + }; +} - // Constructor Methods +#[allow(non_camel_case_types)] +type void = (); +#[allow(clippy::missing_safety_doc)] +#[allow(unsafe_op_in_unsafe_fn)] +impl<'env> Env<'env> { pub unsafe fn new_object_a( self, class: jclass, @@ -386,104 +425,16 @@ impl<'env> Env<'env> { Ok(Some(Local::from_raw(self, result))) } } - - pub unsafe fn call_boolean_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallBooleanMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result != JNI_FALSE) - } - - pub unsafe fn call_byte_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallByteMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_char_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallCharMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_short_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallShortMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_int_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallIntMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_long_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallLongMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_float_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallFloatMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_double_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallDoubleMethodA)(self.env, this, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_void_method_a( - self, - this: jobject, - method: jmethodID, - args: *const jvalue, - ) -> Result<(), Local<'env, E>> { - ((**self.env).v1_2.CallVoidMethodA)(self.env, this, method, args); - self.exception_check() - } + // See `pub type jboolean = bool;` in `jni_sys` 0.4. + call_primitive_method_a! { call_boolean_method_a, bool, CallBooleanMethodA } + call_primitive_method_a! { call_byte_method_a, jbyte, CallByteMethodA } + call_primitive_method_a! { call_char_method_a, jchar, CallCharMethodA } + call_primitive_method_a! { call_short_method_a, jshort, CallShortMethodA } + call_primitive_method_a! { call_int_method_a, jint, CallIntMethodA } + call_primitive_method_a! { call_long_method_a, jlong, CallLongMethodA } + call_primitive_method_a! { call_float_method_a, jfloat, CallFloatMethodA } + call_primitive_method_a! { call_double_method_a, jdouble, CallDoubleMethodA } + call_primitive_method_a! { call_void_method_a, void, CallVoidMethodA } // Static Methods @@ -501,104 +452,15 @@ impl<'env> Env<'env> { Ok(Some(Local::from_raw(self, result))) } } - - pub unsafe fn call_static_boolean_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticBooleanMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result != JNI_FALSE) - } - - pub unsafe fn call_static_byte_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticByteMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_static_char_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticCharMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_static_short_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticShortMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_static_int_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticIntMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_static_long_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticLongMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_static_float_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticFloatMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_static_double_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result> { - let result = ((**self.env).v1_2.CallStaticDoubleMethodA)(self.env, class, method, args); - self.exception_check()?; - Ok(result) - } - - pub unsafe fn call_static_void_method_a( - self, - class: jclass, - method: jmethodID, - args: *const jvalue, - ) -> Result<(), Local<'env, E>> { - ((**self.env).v1_2.CallStaticVoidMethodA)(self.env, class, method, args); - self.exception_check() - } + call_static_primitive_method_a! { call_static_boolean_method_a, bool, CallStaticBooleanMethodA } + call_static_primitive_method_a! { call_static_byte_method_a, jbyte, CallStaticByteMethodA } + call_static_primitive_method_a! { call_static_char_method_a, jchar, CallStaticCharMethodA } + call_static_primitive_method_a! { call_static_short_method_a, jshort, CallStaticShortMethodA } + call_static_primitive_method_a! { call_static_int_method_a, jint, CallStaticIntMethodA } + call_static_primitive_method_a! { call_static_long_method_a, jlong, CallStaticLongMethodA } + call_static_primitive_method_a! { call_static_float_method_a, jfloat, CallStaticFloatMethodA } + call_static_primitive_method_a! { call_static_double_method_a, jdouble, CallStaticDoubleMethodA } + call_static_primitive_method_a! { call_static_void_method_a, void, CallStaticVoidMethodA } // Instance Fields @@ -610,75 +472,26 @@ impl<'env> Env<'env> { Some(Local::from_raw(self, result)) } } - - pub unsafe fn get_boolean_field(self, this: jobject, field: jfieldID) -> bool { - let result = ((**self.env).v1_2.GetBooleanField)(self.env, this, field); - result != JNI_FALSE - } - - pub unsafe fn get_byte_field(self, this: jobject, field: jfieldID) -> jbyte { - ((**self.env).v1_2.GetByteField)(self.env, this, field) - } - - pub unsafe fn get_char_field(self, this: jobject, field: jfieldID) -> jchar { - ((**self.env).v1_2.GetCharField)(self.env, this, field) - } - - pub unsafe fn get_short_field(self, this: jobject, field: jfieldID) -> jshort { - ((**self.env).v1_2.GetShortField)(self.env, this, field) - } - - pub unsafe fn get_int_field(self, this: jobject, field: jfieldID) -> jint { - ((**self.env).v1_2.GetIntField)(self.env, this, field) - } - - pub unsafe fn get_long_field(self, this: jobject, field: jfieldID) -> jlong { - ((**self.env).v1_2.GetLongField)(self.env, this, field) - } - - pub unsafe fn get_float_field(self, this: jobject, field: jfieldID) -> jfloat { - ((**self.env).v1_2.GetFloatField)(self.env, this, field) - } - - pub unsafe fn get_double_field(self, this: jobject, field: jfieldID) -> jdouble { - ((**self.env).v1_2.GetDoubleField)(self.env, this, field) - } + get_primitive_field! { get_boolean_field, bool, GetBooleanField } + get_primitive_field! { get_byte_field, jbyte, GetByteField } + get_primitive_field! { get_char_field, jchar, GetCharField } + get_primitive_field! { get_short_field, jshort, GetShortField } + get_primitive_field! { get_int_field, jint, GetIntField } + get_primitive_field! { get_long_field, jlong, GetLongField } + get_primitive_field! { get_float_field, jfloat, GetFloatField } + get_primitive_field! { get_double_field, jdouble, GetDoubleField } pub unsafe fn set_object_field(self, this: jobject, field: jfieldID, value: impl AsArg) { ((**self.env).v1_2.SetObjectField)(self.env, this, field, value.as_arg()); } - - pub unsafe fn set_boolean_field(self, this: jobject, field: jfieldID, value: bool) { - ((**self.env).v1_2.SetBooleanField)(self.env, this, field, if value { JNI_TRUE } else { JNI_FALSE }); - } - - pub unsafe fn set_byte_field(self, this: jobject, field: jfieldID, value: jbyte) { - ((**self.env).v1_2.SetByteField)(self.env, this, field, value); - } - - pub unsafe fn set_char_field(self, this: jobject, field: jfieldID, value: jchar) { - ((**self.env).v1_2.SetCharField)(self.env, this, field, value); - } - - pub unsafe fn set_short_field(self, this: jobject, field: jfieldID, value: jshort) { - ((**self.env).v1_2.SetShortField)(self.env, this, field, value); - } - - pub unsafe fn set_int_field(self, this: jobject, field: jfieldID, value: jint) { - ((**self.env).v1_2.SetIntField)(self.env, this, field, value); - } - - pub unsafe fn set_long_field(self, this: jobject, field: jfieldID, value: jlong) { - ((**self.env).v1_2.SetLongField)(self.env, this, field, value); - } - - pub unsafe fn set_float_field(self, this: jobject, field: jfieldID, value: jfloat) { - ((**self.env).v1_2.SetFloatField)(self.env, this, field, value); - } - - pub unsafe fn set_double_field(self, this: jobject, field: jfieldID, value: jdouble) { - ((**self.env).v1_2.SetDoubleField)(self.env, this, field, value); - } + set_primitive_field! { set_boolean_field, bool, SetBooleanField } + set_primitive_field! { set_byte_field, jbyte, SetByteField } + set_primitive_field! { set_char_field, jchar, SetCharField } + set_primitive_field! { set_short_field, jshort, SetShortField } + set_primitive_field! { set_int_field, jint, SetIntField } + set_primitive_field! { set_long_field, jlong, SetLongField } + set_primitive_field! { set_float_field, jfloat, SetFloatField } + set_primitive_field! { set_double_field, jdouble, SetDoubleField } // Static Fields @@ -694,39 +507,14 @@ impl<'env> Env<'env> { Some(Local::from_raw(self, result)) } } - - pub unsafe fn get_static_boolean_field(self, class: jclass, field: jfieldID) -> bool { - let result = ((**self.env).v1_2.GetStaticBooleanField)(self.env, class, field); - result != JNI_FALSE - } - - pub unsafe fn get_static_byte_field(self, class: jclass, field: jfieldID) -> jbyte { - ((**self.env).v1_2.GetStaticByteField)(self.env, class, field) - } - - pub unsafe fn get_static_char_field(self, class: jclass, field: jfieldID) -> jchar { - ((**self.env).v1_2.GetStaticCharField)(self.env, class, field) - } - - pub unsafe fn get_static_short_field(self, class: jclass, field: jfieldID) -> jshort { - ((**self.env).v1_2.GetStaticShortField)(self.env, class, field) - } - - pub unsafe fn get_static_int_field(self, class: jclass, field: jfieldID) -> jint { - ((**self.env).v1_2.GetStaticIntField)(self.env, class, field) - } - - pub unsafe fn get_static_long_field(self, class: jclass, field: jfieldID) -> jlong { - ((**self.env).v1_2.GetStaticLongField)(self.env, class, field) - } - - pub unsafe fn get_static_float_field(self, class: jclass, field: jfieldID) -> jfloat { - ((**self.env).v1_2.GetStaticFloatField)(self.env, class, field) - } - - pub unsafe fn get_static_double_field(self, class: jclass, field: jfieldID) -> jdouble { - ((**self.env).v1_2.GetStaticDoubleField)(self.env, class, field) - } + get_static_primitive_field! { get_static_boolean_field, bool, GetStaticBooleanField } + get_static_primitive_field! { get_static_byte_field, jbyte, GetStaticByteField } + get_static_primitive_field! { get_static_char_field, jchar, GetStaticCharField } + get_static_primitive_field! { get_static_short_field, jshort, GetStaticShortField } + get_static_primitive_field! { get_static_int_field, jint, GetStaticIntField } + get_static_primitive_field! { get_static_long_field, jlong, GetStaticLongField } + get_static_primitive_field! { get_static_float_field, jfloat, GetStaticFloatField } + get_static_primitive_field! { get_static_double_field, jdouble, GetStaticDoubleField } pub unsafe fn set_static_object_field( self, @@ -736,41 +524,12 @@ impl<'env> Env<'env> { ) { ((**self.env).v1_2.SetStaticObjectField)(self.env, class, field, value.as_arg()); } - - pub unsafe fn set_static_boolean_field(self, class: jclass, field: jfieldID, value: bool) { - ((**self.env).v1_2.SetStaticBooleanField)(self.env, class, field, if value { JNI_TRUE } else { JNI_FALSE }); - } - - pub unsafe fn set_static_byte_field(self, class: jclass, field: jfieldID, value: jbyte) { - ((**self.env).v1_2.SetStaticByteField)(self.env, class, field, value); - } - - pub unsafe fn set_static_char_field(self, class: jclass, field: jfieldID, value: jchar) { - ((**self.env).v1_2.SetStaticCharField)(self.env, class, field, value); - } - - pub unsafe fn set_static_short_field(self, class: jclass, field: jfieldID, value: jshort) { - ((**self.env).v1_2.SetStaticShortField)(self.env, class, field, value); - } - - pub unsafe fn set_static_int_field(self, class: jclass, field: jfieldID, value: jint) { - ((**self.env).v1_2.SetStaticIntField)(self.env, class, field, value); - } - - pub unsafe fn set_static_long_field(self, class: jclass, field: jfieldID, value: jlong) { - ((**self.env).v1_2.SetStaticLongField)(self.env, class, field, value); - } - - pub unsafe fn set_static_float_field(self, class: jclass, field: jfieldID, value: jfloat) { - ((**self.env).v1_2.SetStaticFloatField)(self.env, class, field, value); - } - - pub unsafe fn set_static_double_field(self, class: jclass, field: jfieldID, value: jdouble) { - ((**self.env).v1_2.SetStaticDoubleField)(self.env, class, field, value); - } - - pub fn throw(self, throwable: &Ref) { - let res = unsafe { ((**self.env).v1_2.Throw)(self.env, throwable.as_raw()) }; - assert_eq!(res, 0); - } + set_static_primitive_field! { set_static_boolean_field, bool, SetStaticBooleanField } + set_static_primitive_field! { set_static_byte_field, jbyte, SetStaticByteField } + set_static_primitive_field! { set_static_char_field, jchar, SetStaticCharField } + set_static_primitive_field! { set_static_short_field, jshort, SetStaticShortField } + set_static_primitive_field! { set_static_int_field, jint, SetStaticIntField } + set_static_primitive_field! { set_static_long_field, jlong, SetStaticLongField } + set_static_primitive_field! { set_static_float_field, jfloat, SetStaticFloatField } + set_static_primitive_field! { set_static_double_field, jdouble, SetStaticDoubleField } } From 39854c8d4525cdb99a05cf7aad1e0264045d3cfa Mon Sep 17 00:00:00 2001 From: wuwbobo2021 Date: Tue, 15 Jul 2025 13:54:44 +0800 Subject: [PATCH 2/5] Simplify exception checking in `Env` methods --- java-spaghetti/src/array.rs | 3 +- java-spaghetti/src/env.rs | 144 +++++++++++++---------------- java-spaghetti/src/refs/ref_.rs | 7 +- java-spaghetti/src/string_chars.rs | 2 + 4 files changed, 71 insertions(+), 85 deletions(-) diff --git a/java-spaghetti/src/array.rs b/java-spaghetti/src/array.rs index f9beb56..e3c4c0f 100644 --- a/java-spaghetti/src/array.rs +++ b/java-spaghetti/src/array.rs @@ -110,8 +110,7 @@ macro_rules! primitive_array { let jnienv = env.as_raw(); unsafe { let object = ((**jnienv).v1_2.$new_array)(jnienv, size); - let exception = ((**jnienv).v1_2.ExceptionOccurred)(jnienv); - assert!(exception.is_null()); // Only sane exception here is an OOM exception + env.exception_check_raw().expect("OOM"); Local::from_raw(env, object) } } diff --git a/java-spaghetti/src/env.rs b/java-spaghetti/src/env.rs index 6b79614..94f1559 100644 --- a/java-spaghetti/src/env.rs +++ b/java-spaghetti/src/env.rs @@ -6,19 +6,19 @@ use std::sync::atomic::{AtomicPtr, Ordering}; use jni_sys::*; -use crate::{AsArg, Local, Ref, ReferenceType, StringChars, ThrowableType, VM}; +use crate::{AsArg, JMethodID, Local, Ref, ReferenceType, StringChars, ThrowableType, VM}; /// FFI: Use **Env** instead of `*const JNIEnv`. This represents a per-thread Java exection environment. /// /// A "safe" alternative to `jni_sys::JNIEnv` raw pointers, with the following caveats: /// -/// 1) A null env will result in **undefined behavior**. Java should not be invoking your native functions with a null -/// *mut JNIEnv, however, so I don't believe this is a problem in practice unless you've bindgened the C header +/// 1) A null `env` will result in **undefined behavior**. Java should not be invoking your native functions with a null +/// `*mut JNIEnv`, however, so I don't believe this is a problem in practice unless you've bindgened the C header /// definitions elsewhere, calling them (requiring `unsafe`), and passing null pointers (generally UB for JNI /// functions anyways, so can be seen as a caller soundness issue.) /// -/// 2) Allowing the underlying JNIEnv to be modified is **undefined behavior**. I don't believe the JNI libraries -/// modify the JNIEnv, so as long as you're not accepting a *mut JNIEnv elsewhere, using unsafe to dereference it, +/// 2) Allowing the underlying `JNIEnv` to be modified is **undefined behavior**. I don't believe the JNI libraries +/// modify the `JNIEnv`, so as long as you're not accepting a `*mut JNIEnv` elsewhere, using `unsafe` to dereference it, /// and mucking with the methods on it yourself, I believe this "should" be fine. /// /// Most methods of `Env` are supposed to be used by generated bindings. @@ -148,13 +148,19 @@ impl<'env> Env<'env> { /// Note that there is `ExceptionCheck` in JNI functions, which does not create a /// local reference to the exception object. pub(crate) fn exception_check(self) -> Result<(), Local<'env, E>> { + self.exception_check_raw() + .map_err(|throwable| unsafe { Local::from_raw(self, throwable) }) + } + + /// The same as `exception_check`, except that it may return a raw local reference of the exception. + pub(crate) fn exception_check_raw(self) -> Result<(), jthrowable> { unsafe { let exception = ((**self.env).v1_2.ExceptionOccurred)(self.env); if exception.is_null() { Ok(()) } else { ((**self.env).v1_2.ExceptionClear)(self.env); - Err(Local::from_raw(self, exception)) + Err(exception as jthrowable) } } } @@ -164,37 +170,32 @@ impl<'env> Env<'env> { assert_eq!(res, 0); } - unsafe fn exception_to_string(self, exception: jobject) -> String { - static METHOD_GET_MESSAGE: OnceLock = OnceLock::new(); - let throwable_get_message = *METHOD_GET_MESSAGE.get_or_init(|| { + unsafe fn raw_exception_to_string(self, exception: jobject) -> String { + static METHOD_GET_MESSAGE: OnceLock = OnceLock::new(); + let throwable_get_message = METHOD_GET_MESSAGE.get_or_init(|| { // use JNI FindClass to avoid infinte recursion. - let throwable_class = self.require_class_jni(c"java/lang/Throwable"); - let method = self.require_method(throwable_class, c"getMessage", c"()Ljava/lang/String;"); - ((**self.env).v1_2.DeleteLocalRef)(self.env, throwable_class); - method.addr() - }) as jmethodID; // it is a global ID + let throwable_class = self.require_class_jni(c"java/lang/Throwable").unwrap(); + JMethodID::from_raw(self.require_method(throwable_class, c"getMessage", c"()Ljava/lang/String;")) + }); let message = - ((**self.env).v1_2.CallObjectMethodA)(self.env, exception, throwable_get_message, ptr::null_mut()); - let e2: *mut _jobject = ((**self.env).v1_2.ExceptionOccurred)(self.env); - if !e2.is_null() { - ((**self.env).v1_2.ExceptionClear)(self.env); - panic!("exception happened calling Throwable.getMessage()"); + ((**self.env).v1_2.CallObjectMethodA)(self.env, exception, throwable_get_message.as_raw(), ptr::null_mut()); + self.exception_check_raw() + .expect("exception happened calling Throwable.getMessage()"); + if message.is_null() { + return "??? (Throwable.getMessage() returned null string)".to_string(); } + let message_string = StringChars::from_env_jstring(self, message).to_string_lossy(); + ((**self.env).v1_2.DeleteLocalRef)(self.env, message); - StringChars::from_env_jstring(self, message).to_string_lossy() + message_string } /// Note: the returned `jclass` is actually a new local reference of the class object. pub unsafe fn require_class(self, class: &CStr) -> jclass { // First try with JNI FindClass. - let c = ((**self.env).v1_2.FindClass)(self.env, class.as_ptr()); - let exception: *mut _jobject = ((**self.env).v1_2.ExceptionOccurred)(self.env); - if !exception.is_null() { - ((**self.env).v1_2.ExceptionClear)(self.env); - } - if !c.is_null() { - return c; + if let Some(class) = self.require_class_jni(class) { + return class; } // If class is not found and we have a classloader set, try that. @@ -208,25 +209,25 @@ impl<'env> Env<'env> { .collect::>(); let string = unsafe { self.new_string(chars.as_ptr(), chars.len() as jsize) }; - static CL_METHOD: OnceLock = OnceLock::new(); - let cl_method = *CL_METHOD.get_or_init(|| { + static CL_METHOD: OnceLock = OnceLock::new(); + let cl_method = CL_METHOD.get_or_init(|| { // We still use JNI FindClass for this, to avoid a chicken-and-egg situation. // If the system class loader cannot find java.lang.ClassLoader, things are pretty broken! - let cl_class = self.require_class_jni(c"java/lang/ClassLoader"); - let cl_method = self.require_method(cl_class, c"loadClass", c"(Ljava/lang/String;)Ljava/lang/Class;"); - ((**self.env).v1_2.DeleteLocalRef)(self.env, cl_class); - cl_method.addr() - }) as jmethodID; // it is a global ID + let cl_class = self.require_class_jni(c"java/lang/ClassLoader").unwrap(); + JMethodID::from_raw(self.require_method( + cl_class, + c"loadClass", + c"(Ljava/lang/String;)Ljava/lang/Class;", + )) + }); let args = [jvalue { l: string }]; let result: *mut _jobject = - ((**self.env).v1_2.CallObjectMethodA)(self.env, classloader, cl_method, args.as_ptr()); - let exception: *mut _jobject = ((**self.env).v1_2.ExceptionOccurred)(self.env); - if !exception.is_null() { - ((**self.env).v1_2.ExceptionClear)(self.env); + ((**self.env).v1_2.CallObjectMethodA)(self.env, classloader, cl_method.as_raw(), args.as_ptr()); + if let Err(exception) = self.exception_check_raw() { panic!( "exception happened calling loadClass(): {}", - self.exception_to_string(exception) + self.raw_exception_to_string(exception) ); } else if result.is_null() { panic!("loadClass() returned null"); @@ -241,51 +242,38 @@ impl<'env> Env<'env> { panic!("couldn't load class {class:?}"); } - unsafe fn require_class_jni(self, class: &CStr) -> jclass { - let res = ((**self.env).v1_2.FindClass)(self.env, class.as_ptr()); - if res.is_null() { - ((**self.env).v1_2.ExceptionClear)(self.env); - panic!("could not find class {class:?}"); + unsafe fn require_class_jni(self, class: &CStr) -> Option { + let cls = ((**self.env).v1_2.FindClass)(self.env, class.as_ptr()); + self.exception_check_raw().ok()?; + if cls.is_null() { + return None; } - res + Some(cls) } // used only for debugging unsafe fn get_class_name(self, class: jclass) -> String { - let classclass = self.require_class_jni(c"java/lang/Class"); - - // don't use self.require_method() here to avoid recursion! - let method = ((**self.env).v1_2.GetMethodID)( - self.env, - classclass, - c"getName".as_ptr(), - c"()Ljava/lang/String;".as_ptr(), - ); - if method.is_null() { - ((**self.env).v1_2.ExceptionClear)(self.env); - ((**self.env).v1_2.DeleteLocalRef)(self.env, classclass); - return "??? (couldn't get class getName method)".to_string(); - } - - let string = ((**self.env).v1_2.CallObjectMethod)(self.env, class, method); - if string.is_null() { - return "??? (getName returned null string)".to_string(); - } - let chars = ((**self.env).v1_2.GetStringUTFChars)(self.env, string, ptr::null_mut()); - if chars.is_null() { - ((**self.env).v1_2.DeleteLocalRef)(self.env, string); - ((**self.env).v1_2.DeleteLocalRef)(self.env, classclass); - return "??? (GetStringUTFChars returned null chars)".to_string(); + static METHOD_GET_NAME: OnceLock = OnceLock::new(); + let method = METHOD_GET_NAME.get_or_init(|| { + // don't use `self.require_method()` here to avoid recursion! + let class_class = self.require_class_jni(c"java/lang/Class").unwrap(); + let method = ((**self.env).v1_2.GetMethodID)( + self.env, + class_class, + c"getName".as_ptr(), + c"()Ljava/lang/String;".as_ptr(), + ); + JMethodID::from_raw(method) + }); + let jstring = ((**self.env).v1_2.CallObjectMethod)(self.env, class, method.as_raw()); + self.exception_check_raw() + .expect("exception happened calling Class.getName()"); + if jstring.is_null() { + return "??? (Class.getName() returned null string)".to_string(); } - - let cchars = CStr::from_ptr(chars); - let res = cchars.to_string_lossy().to_string(); - - ((**self.env).v1_2.ReleaseStringUTFChars)(self.env, string, chars); - ((**self.env).v1_2.DeleteLocalRef)(self.env, string); - ((**self.env).v1_2.DeleteLocalRef)(self.env, classclass); - - res + let string = StringChars::from_env_jstring(self, jstring).to_string_lossy(); + ((**self.env).v1_2.DeleteLocalRef)(self.env, jstring); + string } pub unsafe fn require_method(self, class: jclass, method: &CStr, descriptor: &CStr) -> jmethodID { diff --git a/java-spaghetti/src/refs/ref_.rs b/java-spaghetti/src/refs/ref_.rs index 6163a83..3744138 100644 --- a/java-spaghetti/src/refs/ref_.rs +++ b/java-spaghetti/src/refs/ref_.rs @@ -198,10 +198,7 @@ impl<'env, T: ReferenceType> Drop for Monitor<'env, T> { let jnienv = env.as_raw(); let result = unsafe { ((**jnienv).v1_2.MonitorExit)(jnienv, self.inner.as_raw()) }; assert!(result == jni_sys::JNI_OK); - let exception = unsafe { ((**jnienv).v1_2.ExceptionOccurred)(jnienv) }; - assert!( - exception.is_null(), - "exception happened calling JNI MonitorExit, the monitor is probably broken previously" - ); + env.exception_check_raw() + .expect("exception happened calling JNI MonitorExit, the monitor is probably broken previously"); } } diff --git a/java-spaghetti/src/string_chars.rs b/java-spaghetti/src/string_chars.rs index 75fb3c1..0b7a939 100644 --- a/java-spaghetti/src/string_chars.rs +++ b/java-spaghetti/src/string_chars.rs @@ -28,6 +28,8 @@ impl<'env> StringChars<'env> { let chars = unsafe { env.get_string_chars(string) }; let length = unsafe { env.get_string_length(string) }; + debug_assert!(!chars.is_null() || length == 0); + Self { env, string, From 5f9ad9e526059d53eb18b26c1bb992252456f2c9 Mon Sep 17 00:00:00 2001 From: wuwbobo2021 Date: Wed, 16 Jul 2025 13:09:13 +0800 Subject: [PATCH 3/5] Replace `JniType` with `ReferenceType`; use new types in `Env` and caches --- java-spaghetti-gen/src/emit/classes.rs | 34 +--- java-spaghetti-gen/src/emit/fields.rs | 12 +- java-spaghetti-gen/src/emit/methods.rs | 17 +- java-spaghetti/src/array.rs | 62 ++++--- java-spaghetti/src/env.rs | 228 ++++++++++++++----------- java-spaghetti/src/id_cache.rs | 103 +++++++++-- java-spaghetti/src/jni_type.rs | 79 +++++---- java-spaghetti/src/lib.rs | 51 +++++- java-spaghetti/src/refs/ref_.rs | 5 +- 9 files changed, 366 insertions(+), 225 deletions(-) diff --git a/java-spaghetti-gen/src/emit/classes.rs b/java-spaghetti-gen/src/emit/classes.rs index fc3549d..daf9565 100644 --- a/java-spaghetti-gen/src/emit/classes.rs +++ b/java-spaghetti-gen/src/emit/classes.rs @@ -103,11 +103,6 @@ impl Class { let rust_name = format_ident!("{}", &self.rust.struct_name); - let referencetype_impl = match self.java.is_static() { - true => quote!(), - false => quote!(unsafe impl ::java_spaghetti::ReferenceType for #rust_name {}), - }; - let mut out = TokenStream::new(); let java_path = cstring(self.java.path().as_str()); @@ -117,11 +112,13 @@ impl Class { #attributes #visibility enum #rust_name {} - #referencetype_impl - - unsafe impl ::java_spaghetti::JniType for #rust_name { - fn static_with_jni_type(callback: impl FnOnce(&::std::ffi::CStr) -> R) -> R { - callback(#java_path) + unsafe impl ::java_spaghetti::ReferenceType for #rust_name { + fn jni_reference_type_name() -> ::std::borrow::Cow<'static, ::std::ffi::CStr> { + ::std::borrow::Cow::Borrowed(#java_path) + } + unsafe fn jni_class_cache_once_lock() -> &'static ::std::sync::OnceLock<::java_spaghetti::JClass> { + static CLASS_CACHE: ::std::sync::OnceLock<::java_spaghetti::JClass> = ::std::sync::OnceLock::new(); + &CLASS_CACHE } } )); @@ -147,23 +144,6 @@ impl Class { let mut contents = TokenStream::new(); - let object = context - .java_to_rust_path(Id("java/lang/Object"), &self.rust.mod_) - .unwrap(); - - let class = cstring(self.java.path().as_str()); - - contents.extend(quote!( - fn __class_global_ref(__jni_env: ::java_spaghetti::Env) -> ::java_spaghetti::sys::jobject { - static __CLASS: ::std::sync::OnceLock<::java_spaghetti::Global<#object>> = ::std::sync::OnceLock::new(); - __CLASS - .get_or_init(|| unsafe { - ::java_spaghetti::Local::from_raw(__jni_env, __jni_env.require_class(#class)).as_global() - }) - .as_raw() - } - )); - let mut methods: Vec = self .java .methods() diff --git a/java-spaghetti-gen/src/emit/fields.rs b/java-spaghetti-gen/src/emit/fields.rs index e5cf9ce..1273809 100644 --- a/java-spaghetti-gen/src/emit/fields.rs +++ b/java-spaghetti-gen/src/emit/fields.rs @@ -129,7 +129,7 @@ impl<'a> Field<'a> { let set_field = format_ident!("set{static_fragment}_{field_fragment}_field"); let this_or_class = match self.java.is_static() { - false => quote!(self.as_raw()), + false => quote!(self), true => quote!(__jni_class), }; @@ -142,11 +142,12 @@ impl<'a> Field<'a> { #[doc = #get_docs] #attributes pub fn #get<'env>(#env_param) -> #rust_get_type { + use ::java_spaghetti::ReferenceType; static __FIELD: ::std::sync::OnceLock<::java_spaghetti::JFieldID> = ::std::sync::OnceLock::new(); #env_let - let __jni_class = Self::__class_global_ref(__jni_env); + let __jni_class = Self::jni_get_class(__jni_env).unwrap(); unsafe { - let __jni_field = __FIELD.get_or_init(|| ::java_spaghetti::JFieldID::from_raw(__jni_env.#require_field(__jni_class, #java_name, #descriptor))).as_raw(); + let __jni_field = *__FIELD.get_or_init(|| __jni_env.#require_field(__jni_class, #java_name, #descriptor)); __jni_env.#get_field(#this_or_class, __jni_field) } } @@ -164,11 +165,12 @@ impl<'a> Field<'a> { #[doc = #set_docs] #attributes pub fn #set<#lifetimes>(#env_param, value: #rust_set_type) { + use ::java_spaghetti::ReferenceType; static __FIELD: ::std::sync::OnceLock<::java_spaghetti::JFieldID> = ::std::sync::OnceLock::new(); #env_let - let __jni_class = Self::__class_global_ref(__jni_env); + let __jni_class = Self::jni_get_class(__jni_env).unwrap(); unsafe { - let __jni_field = __FIELD.get_or_init(|| ::java_spaghetti::JFieldID::from_raw(__jni_env.#require_field(__jni_class, #java_name, #descriptor))).as_raw(); + let __jni_field = *__FIELD.get_or_init(|| __jni_env.#require_field(__jni_class, #java_name, #descriptor)); __jni_env.#set_field(#this_or_class, __jni_field, value); } } diff --git a/java-spaghetti-gen/src/emit/methods.rs b/java-spaghetti-gen/src/emit/methods.rs index 87c84af..2b090b4 100644 --- a/java-spaghetti-gen/src/emit/methods.rs +++ b/java-spaghetti-gen/src/emit/methods.rs @@ -141,27 +141,28 @@ impl<'a> Method<'a> { let method_name = format_ident!("{method_name}"); let call = if self.java.is_constructor() { - quote!(__jni_env.new_object_a(__jni_class, __jni_method, __jni_args.as_ptr())) + quote!(__jni_env.new_object_a(__jni_class, __jni_method, __jni_args)) } else if self.java.is_static() { let call = format_ident!("call_static_{ret_method_fragment}_method_a"); - quote!( __jni_env.#call(__jni_class, __jni_method, __jni_args.as_ptr())) + quote!( __jni_env.#call(__jni_class, __jni_method, __jni_args)) } else { let call = format_ident!("call_{ret_method_fragment}_method_a"); - quote!( __jni_env.#call(self.as_raw(), __jni_method, __jni_args.as_ptr())) + quote!( __jni_env.#call(self, __jni_method, __jni_args)) }; out.extend(quote!( #[doc = #docs] #attributes pub fn #method_name<'env>(#params_decl) -> ::std::result::Result<#ret_decl, ::java_spaghetti::Local<'env, #throwable>> { + use ::java_spaghetti::ReferenceType; static __METHOD: ::std::sync::OnceLock<::java_spaghetti::JMethodID> = ::std::sync::OnceLock::new(); unsafe { - let __jni_args = [#params_array]; + let __jni_args = &[#params_array]; #env_let - let __jni_class = Self::__class_global_ref(__jni_env); - let __jni_method = __METHOD.get_or_init(|| - ::java_spaghetti::JMethodID::from_raw(__jni_env.#require_method(__jni_class, #java_name, #descriptor)) - ).as_raw(); + let __jni_class = Self::jni_get_class(__jni_env).unwrap(); + let __jni_method = *__METHOD.get_or_init(|| + __jni_env.#require_method(__jni_class, #java_name, #descriptor) + ); #call } diff --git a/java-spaghetti/src/array.rs b/java-spaghetti/src/array.rs index e3c4c0f..e62acb7 100644 --- a/java-spaghetti/src/array.rs +++ b/java-spaghetti/src/array.rs @@ -1,11 +1,14 @@ +use std::borrow::Cow; +use std::collections::HashMap; use std::ffi::{CStr, CString}; use std::marker::PhantomData; use std::ops::{Bound, RangeBounds}; use std::ptr::null_mut; +use std::sync::{LazyLock, OnceLock, RwLock}; use jni_sys::*; -use crate::{AsArg, Env, JniType, Local, Ref, ReferenceType, ThrowableType}; +use crate::{AsArg, Env, JClass, Local, Ref, ReferenceType, ThrowableType}; /// A Java Array of some POD-like type such as `bool`, `jbyte`, `jchar`, `jshort`, `jint`, `jlong`, `jfloat`, or `jdouble`. /// @@ -96,10 +99,13 @@ macro_rules! primitive_array { /// A [PrimitiveArray] implementation. pub enum $name {} - unsafe impl ReferenceType for $name {} - unsafe impl JniType for $name { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback($type_str) + unsafe impl ReferenceType for $name { + fn jni_reference_type_name() -> Cow<'static, CStr> { + Cow::Borrowed($type_str) + } + unsafe fn jni_class_cache_once_lock() -> &'static OnceLock { + static CLASS_CACHE: OnceLock = OnceLock::new(); + &CLASS_CACHE } } @@ -196,18 +202,33 @@ primitive_array! { DoubleArray, c"[D", jdouble { NewDoubleArray SetDoubleArray /// See also [PrimitiveArray] for arrays of reference types. pub struct ObjectArray(core::convert::Infallible, PhantomData<(T, E)>); -unsafe impl ReferenceType for ObjectArray {} - -unsafe impl JniType for ObjectArray { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - T::static_with_jni_type(|inner| { - let inner = inner.to_bytes(); - let mut buf = Vec::with_capacity(inner.len() + 4); - buf.extend_from_slice(b"[L"); - buf.extend_from_slice(inner); - buf.extend_from_slice(b";"); - callback(&CString::new(buf).unwrap()) - }) +// NOTE: This is a performance compromise for returning `&'static JClass`, still faster than non-cached `FindClass`. +static OBJ_ARR_CLASSES: LazyLock>>> = + LazyLock::new(|| RwLock::new(HashMap::new())); + +unsafe impl ReferenceType for ObjectArray { + fn jni_reference_type_name() -> Cow<'static, CStr> { + let item_type = T::jni_reference_type_name(); + let item_type = item_type.to_string_lossy(); + let array_type = if !item_type.starts_with('[') { + format!("[L{item_type};") + } else { + format!("[{item_type}") + }; + Cow::Owned(CString::new(array_type).unwrap()) + } + + unsafe fn jni_class_cache_once_lock() -> &'static OnceLock { + let t = Self::jni_reference_type_name(); + let class_map_reader = OBJ_ARR_CLASSES.read().unwrap(); + if let Some(&once_lock) = class_map_reader.get(t.as_ref()) { + once_lock + } else { + drop(class_map_reader); + let once_lock: &'static OnceLock<_> = Box::leak(Box::new(OnceLock::new())); + let _ = OBJ_ARR_CLASSES.write().unwrap().insert(t.into_owned(), once_lock); + once_lock + } } } @@ -215,7 +236,7 @@ impl ObjectArray { /// Uses JNI `NewObjectArray` to create a new Java object array. pub fn new<'env>(env: Env<'env>, size: usize) -> Local<'env, Self> { assert!(size <= i32::MAX as usize); // jsize == jint == i32 - let class = T::static_with_jni_type(|t| unsafe { env.require_class(t) }); + let class = T::jni_get_class(env).unwrap().as_raw(); let size = size as jsize; let object = unsafe { @@ -225,11 +246,6 @@ impl ObjectArray { }; // Only sane exception here is an OOM exception env.exception_check::().map_err(|_| "OOM").unwrap(); - - unsafe { - let env = env.as_raw(); - ((**env).v1_2.DeleteLocalRef)(env, class); - } unsafe { Local::from_raw(env, object) } } diff --git a/java-spaghetti/src/env.rs b/java-spaghetti/src/env.rs index 94f1559..98ab11e 100644 --- a/java-spaghetti/src/env.rs +++ b/java-spaghetti/src/env.rs @@ -6,7 +6,9 @@ use std::sync::atomic::{AtomicPtr, Ordering}; use jni_sys::*; -use crate::{AsArg, JMethodID, Local, Ref, ReferenceType, StringChars, ThrowableType, VM}; +use crate::{ + AsArg, ClassLoaderError, JClass, JFieldID, JMethodID, Local, Ref, ReferenceType, StringChars, ThrowableType, VM, +}; /// FFI: Use **Env** instead of `*const JNIEnv`. This represents a per-thread Java exection environment. /// @@ -152,7 +154,7 @@ impl<'env> Env<'env> { .map_err(|throwable| unsafe { Local::from_raw(self, throwable) }) } - /// The same as `exception_check`, except that it may return a raw local reference of the exception. + /// The same as `exception_check`, except that it may return an owned raw local reference of the exception. pub(crate) fn exception_check_raw(self) -> Result<(), jthrowable> { unsafe { let exception = ((**self.env).v1_2.ExceptionOccurred)(self.env); @@ -175,7 +177,7 @@ impl<'env> Env<'env> { let throwable_get_message = METHOD_GET_MESSAGE.get_or_init(|| { // use JNI FindClass to avoid infinte recursion. let throwable_class = self.require_class_jni(c"java/lang/Throwable").unwrap(); - JMethodID::from_raw(self.require_method(throwable_class, c"getMessage", c"()Ljava/lang/String;")) + self.require_method(&throwable_class, c"getMessage", c"()Ljava/lang/String;") }); let message = @@ -191,81 +193,89 @@ impl<'env> Env<'env> { message_string } - /// Note: the returned `jclass` is actually a new local reference of the class object. - pub unsafe fn require_class(self, class: &CStr) -> jclass { + pub unsafe fn require_class(self, class: &CStr) -> Result { // First try with JNI FindClass. - if let Some(class) = self.require_class_jni(class) { - return class; + let required = self.require_class_jni(class); + if let Ok(class) = required { + return Ok(class); } // If class is not found and we have a classloader set, try that. let classloader = CLASS_LOADER.load(Ordering::Relaxed); - if !classloader.is_null() { - let chars = class - .to_str() - .unwrap() - .replace('/', ".") - .encode_utf16() - .collect::>(); - let string = unsafe { self.new_string(chars.as_ptr(), chars.len() as jsize) }; - - static CL_METHOD: OnceLock = OnceLock::new(); - let cl_method = CL_METHOD.get_or_init(|| { - // We still use JNI FindClass for this, to avoid a chicken-and-egg situation. - // If the system class loader cannot find java.lang.ClassLoader, things are pretty broken! - let cl_class = self.require_class_jni(c"java/lang/ClassLoader").unwrap(); - JMethodID::from_raw(self.require_method( - cl_class, - c"loadClass", - c"(Ljava/lang/String;)Ljava/lang/Class;", - )) - }); - - let args = [jvalue { l: string }]; - let result: *mut _jobject = - ((**self.env).v1_2.CallObjectMethodA)(self.env, classloader, cl_method.as_raw(), args.as_ptr()); - if let Err(exception) = self.exception_check_raw() { - panic!( - "exception happened calling loadClass(): {}", - self.raw_exception_to_string(exception) - ); - } else if result.is_null() { - panic!("loadClass() returned null"); - } + if classloader.is_null() { + return Err(required.unwrap_err()); + } - ((**self.env).v1_2.DeleteLocalRef)(self.env, string); + let java_bin_name = class.to_str().unwrap().replace('/', "."); + let chars = java_bin_name.encode_utf16().collect::>(); + let string = unsafe { self.new_string(chars.as_ptr(), chars.len() as jsize) }; + + static CL_METHOD: OnceLock = OnceLock::new(); + let cl_method = CL_METHOD.get_or_init(|| { + // We still use JNI FindClass for this, to avoid a chicken-and-egg situation. + // If the system class loader cannot find java.lang.ClassLoader, things are pretty broken! + let cl_class = self.require_class_jni(c"java/lang/ClassLoader").unwrap(); + self.require_method(&cl_class, c"loadClass", c"(Ljava/lang/String;)Ljava/lang/Class;") + }); - return result as jclass; + let args = [jvalue { l: string }]; + let result: *mut _jobject = + ((**self.env).v1_2.CallObjectMethodA)(self.env, classloader, cl_method.as_raw(), args.as_ptr()); + let ex_check = self.exception_check_raw(); + ((**self.env).v1_2.DeleteLocalRef)(self.env, string); + + if let Err(exception) = ex_check { + let err_msg = format!( + "exception happened calling loadClass(): {}", + self.raw_exception_to_string(exception) + ); + ((**self.env).v1_2.DeleteLocalRef)(self.env, exception); + return Err(ClassLoaderError(err_msg)); + } else if result.is_null() { + return Err(ClassLoaderError(format!( + "loadClass() returned null for {}", + class.to_string_lossy() + ))); } - // If neither found the class, panic. - panic!("couldn't load class {class:?}"); + Ok(JClass::from_raw(self, result as jclass)) } - unsafe fn require_class_jni(self, class: &CStr) -> Option { + unsafe fn require_class_jni(self, class: &CStr) -> Result { + // Note: the returned `cls` is actually a new local reference of the class object. let cls = ((**self.env).v1_2.FindClass)(self.env, class.as_ptr()); - self.exception_check_raw().ok()?; + if let Err(exception) = self.exception_check_raw() { + let err_msg = format!( + "exception happened calling JNI FindClass: {}", + self.raw_exception_to_string(exception) + ); + ((**self.env).v1_2.DeleteLocalRef)(self.env, exception); + return Err(ClassLoaderError(err_msg)); + } if cls.is_null() { - return None; + return Err(ClassLoaderError(format!( + "JNI FindClass returned null for {}", + class.to_string_lossy() + ))); } - Some(cls) + Ok(JClass::from_raw(self, cls)) } // used only for debugging - unsafe fn get_class_name(self, class: jclass) -> String { + unsafe fn get_class_name(self, class: &JClass) -> String { static METHOD_GET_NAME: OnceLock = OnceLock::new(); let method = METHOD_GET_NAME.get_or_init(|| { // don't use `self.require_method()` here to avoid recursion! let class_class = self.require_class_jni(c"java/lang/Class").unwrap(); - let method = ((**self.env).v1_2.GetMethodID)( + let method_raw = ((**self.env).v1_2.GetMethodID)( self.env, - class_class, + class_class.as_raw(), c"getName".as_ptr(), c"()Ljava/lang/String;".as_ptr(), ); - JMethodID::from_raw(method) + JMethodID::from_raw(method_raw) }); - let jstring = ((**self.env).v1_2.CallObjectMethod)(self.env, class, method.as_raw()); + let jstring = ((**self.env).v1_2.CallObjectMethod)(self.env, class.as_raw(), method.as_raw()); self.exception_check_raw() .expect("exception happened calling Class.getName()"); if jstring.is_null() { @@ -276,56 +286,56 @@ impl<'env> Env<'env> { string } - pub unsafe fn require_method(self, class: jclass, method: &CStr, descriptor: &CStr) -> jmethodID { - let res = ((**self.env).v1_2.GetMethodID)(self.env, class, method.as_ptr(), descriptor.as_ptr()); + pub unsafe fn require_method(self, class: &JClass, method: &CStr, descriptor: &CStr) -> JMethodID { + let res = ((**self.env).v1_2.GetMethodID)(self.env, class.as_raw(), method.as_ptr(), descriptor.as_ptr()); if res.is_null() { ((**self.env).v1_2.ExceptionClear)(self.env); let class_name = self.get_class_name(class); panic!("could not find method {method:?} {descriptor:?} on class {class_name:?}"); } - res + JMethodID::from_raw(res) } - pub unsafe fn require_static_method(self, class: jclass, method: &CStr, descriptor: &CStr) -> jmethodID { - let res = ((**self.env).v1_2.GetStaticMethodID)(self.env, class, method.as_ptr(), descriptor.as_ptr()); + pub unsafe fn require_static_method(self, class: &JClass, method: &CStr, descriptor: &CStr) -> JMethodID { + let res = ((**self.env).v1_2.GetStaticMethodID)(self.env, class.as_raw(), method.as_ptr(), descriptor.as_ptr()); if res.is_null() { ((**self.env).v1_2.ExceptionClear)(self.env); let class_name = self.get_class_name(class); panic!("could not find static method {method:?} {descriptor:?} on class {class_name:?}"); } - res + JMethodID::from_raw(res) } - pub unsafe fn require_field(self, class: jclass, field: &CStr, descriptor: &CStr) -> jfieldID { - let res = ((**self.env).v1_2.GetFieldID)(self.env, class, field.as_ptr(), descriptor.as_ptr()); + pub unsafe fn require_field(self, class: &JClass, field: &CStr, descriptor: &CStr) -> JFieldID { + let res = ((**self.env).v1_2.GetFieldID)(self.env, class.as_raw(), field.as_ptr(), descriptor.as_ptr()); if res.is_null() { ((**self.env).v1_2.ExceptionClear)(self.env); let class_name = self.get_class_name(class); panic!("could not find field {field:?} {descriptor:?} on class {class_name:?}"); } - res + JFieldID::from_raw(res) } - pub unsafe fn require_static_field(self, class: jclass, field: &CStr, descriptor: &CStr) -> jfieldID { - let res = ((**self.env).v1_2.GetStaticFieldID)(self.env, class, field.as_ptr(), descriptor.as_ptr()); + pub unsafe fn require_static_field(self, class: &JClass, field: &CStr, descriptor: &CStr) -> JFieldID { + let res = ((**self.env).v1_2.GetStaticFieldID)(self.env, class.as_raw(), field.as_ptr(), descriptor.as_ptr()); if res.is_null() { ((**self.env).v1_2.ExceptionClear)(self.env); let class_name = self.get_class_name(class); panic!("could not find static field {field:?} {descriptor:?} on class {class_name:?}"); } - res + JFieldID::from_raw(res) } } macro_rules! call_primitive_method_a { ($name:ident, $ret_type:ident, $call:ident) => { - pub unsafe fn $name( + pub unsafe fn $name( self, - this: jobject, - method: jmethodID, - args: *const jvalue, + this: &Ref<'env, T>, + method: JMethodID, + args: &[jvalue], ) -> Result<$ret_type, Local<'env, E>> { - let result = ((**self.env).v1_2.$call)(self.env, this, method, args); + let result = ((**self.env).v1_2.$call)(self.env, this.as_raw(), method.as_raw(), args.as_ptr()); self.exception_check()?; Ok(result) } @@ -336,11 +346,11 @@ macro_rules! call_static_primitive_method_a { ($name:ident, $ret_type:ident, $call:ident) => { pub unsafe fn $name( self, - class: jclass, - method: jmethodID, - args: *const jvalue, + class: &JClass, + method: JMethodID, + args: &[jvalue], ) -> Result<$ret_type, Local<'env, E>> { - let result = ((**self.env).v1_2.$call)(self.env, class, method, args); + let result = ((**self.env).v1_2.$call)(self.env, class.as_raw(), method.as_raw(), args.as_ptr()); self.exception_check()?; Ok(result) } @@ -349,32 +359,32 @@ macro_rules! call_static_primitive_method_a { macro_rules! get_primitive_field { ($name:ident, $ret_type:ident, $call:ident) => { - pub unsafe fn $name(self, this: jobject, field: jfieldID) -> $ret_type { - ((**self.env).v1_2.$call)(self.env, this, field) + pub unsafe fn $name(self, this: &Ref<'env, T>, field: JFieldID) -> $ret_type { + ((**self.env).v1_2.$call)(self.env, this.as_raw(), field.as_raw()) } }; } macro_rules! set_primitive_field { ($name:ident, $arg_type:ident, $call:ident) => { - pub unsafe fn $name(self, this: jobject, field: jfieldID, value: $arg_type) { - ((**self.env).v1_2.$call)(self.env, this, field, value); + pub unsafe fn $name(self, this: &Ref<'env, T>, field: JFieldID, value: $arg_type) { + ((**self.env).v1_2.$call)(self.env, this.as_raw(), field.as_raw(), value); } }; } macro_rules! get_static_primitive_field { ($name:ident, $ret_type:ident, $call:ident) => { - pub unsafe fn $name(self, class: jclass, field: jfieldID) -> $ret_type { - ((**self.env).v1_2.$call)(self.env, class, field) + pub unsafe fn $name(self, class: &JClass, field: JFieldID) -> $ret_type { + ((**self.env).v1_2.$call)(self.env, class.as_raw(), field.as_raw()) } }; } macro_rules! set_static_primitive_field { ($name:ident, $arg_type:ident, $call:ident) => { - pub unsafe fn $name(self, class: jclass, field: jfieldID, value: $arg_type) { - ((**self.env).v1_2.$call)(self.env, class, field, value); + pub unsafe fn $name(self, class: &JClass, field: JFieldID, value: $arg_type) { + ((**self.env).v1_2.$call)(self.env, class.as_raw(), field.as_raw(), value); } }; } @@ -387,11 +397,11 @@ type void = (); impl<'env> Env<'env> { pub unsafe fn new_object_a( self, - class: jclass, - method: jmethodID, - args: *const jvalue, + class: &JClass, + method: JMethodID, + args: &[jvalue], ) -> Result, Local<'env, E>> { - let result = ((**self.env).v1_2.NewObjectA)(self.env, class, method, args); + let result = ((**self.env).v1_2.NewObjectA)(self.env, class.as_raw(), method.as_raw(), args.as_ptr()); self.exception_check()?; assert!(!result.is_null()); Ok(Local::from_raw(self, result)) @@ -399,13 +409,13 @@ impl<'env> Env<'env> { // Instance Methods - pub unsafe fn call_object_method_a( + pub unsafe fn call_object_method_a( self, - this: jobject, - method: jmethodID, - args: *const jvalue, + this: &Ref<'env, T>, + method: JMethodID, + args: &[jvalue], ) -> Result>, Local<'env, E>> { - let result = ((**self.env).v1_2.CallObjectMethodA)(self.env, this, method, args); + let result = ((**self.env).v1_2.CallObjectMethodA)(self.env, this.as_raw(), method.as_raw(), args.as_ptr()); self.exception_check()?; if result.is_null() { Ok(None) @@ -428,11 +438,12 @@ impl<'env> Env<'env> { pub unsafe fn call_static_object_method_a( self, - class: jclass, - method: jmethodID, - args: *const jvalue, + class: &JClass, + method: JMethodID, + args: &[jvalue], ) -> Result>, Local<'env, E>> { - let result = ((**self.env).v1_2.CallStaticObjectMethodA)(self.env, class, method, args); + let result = + ((**self.env).v1_2.CallStaticObjectMethodA)(self.env, class.as_raw(), method.as_raw(), args.as_ptr()); self.exception_check()?; if result.is_null() { Ok(None) @@ -452,8 +463,12 @@ impl<'env> Env<'env> { // Instance Fields - pub unsafe fn get_object_field(self, this: jobject, field: jfieldID) -> Option> { - let result = ((**self.env).v1_2.GetObjectField)(self.env, this, field); + pub unsafe fn get_object_field( + self, + this: &Ref<'env, T>, + field: JFieldID, + ) -> Option> { + let result = ((**self.env).v1_2.GetObjectField)(self.env, this.as_raw(), field.as_raw()); if result.is_null() { None } else { @@ -469,8 +484,13 @@ impl<'env> Env<'env> { get_primitive_field! { get_float_field, jfloat, GetFloatField } get_primitive_field! { get_double_field, jdouble, GetDoubleField } - pub unsafe fn set_object_field(self, this: jobject, field: jfieldID, value: impl AsArg) { - ((**self.env).v1_2.SetObjectField)(self.env, this, field, value.as_arg()); + pub unsafe fn set_object_field( + self, + this: &Ref<'env, T>, + field: JFieldID, + value: impl AsArg, + ) { + ((**self.env).v1_2.SetObjectField)(self.env, this.as_raw(), field.as_raw(), value.as_arg()); } set_primitive_field! { set_boolean_field, bool, SetBooleanField } set_primitive_field! { set_byte_field, jbyte, SetByteField } @@ -485,10 +505,10 @@ impl<'env> Env<'env> { pub unsafe fn get_static_object_field( self, - class: jclass, - field: jfieldID, + class: &JClass, + field: JFieldID, ) -> Option> { - let result = ((**self.env).v1_2.GetStaticObjectField)(self.env, class, field); + let result = ((**self.env).v1_2.GetStaticObjectField)(self.env, class.as_raw(), field.as_raw()); if result.is_null() { None } else { @@ -506,11 +526,11 @@ impl<'env> Env<'env> { pub unsafe fn set_static_object_field( self, - class: jclass, - field: jfieldID, + class: &JClass, + field: JFieldID, value: impl AsArg, ) { - ((**self.env).v1_2.SetStaticObjectField)(self.env, class, field, value.as_arg()); + ((**self.env).v1_2.SetStaticObjectField)(self.env, class.as_raw(), field.as_raw(), value.as_arg()); } set_static_primitive_field! { set_static_boolean_field, bool, SetStaticBooleanField } set_static_primitive_field! { set_static_byte_field, jbyte, SetStaticByteField } diff --git a/java-spaghetti/src/id_cache.rs b/java-spaghetti/src/id_cache.rs index 8253f23..e83d327 100644 --- a/java-spaghetti/src/id_cache.rs +++ b/java-spaghetti/src/id_cache.rs @@ -1,16 +1,87 @@ -//! New types for `jfieldID` and `jmethodID` that implement `Send` and `Sync`. +//! New type for cached class objects as JNI global references; new types for `jfieldID` and `jmethodID` that +//! implement `Send` and `Sync`. //! //! Inspired by: . -//! -//! According to the JNI spec field IDs may be invalidated when the corresponding class is unloaded: -//! -//! -//! You should generally not be interacting with these types directly, but it must be public for codegen. -use crate::sys::{jfieldID, jmethodID}; +use crate::sys::{jclass, jfieldID, jmethodID, jobject}; +use crate::{Env, VM}; + +/// New type for cached class objects as JNI global references. +/// +/// Holding a `JClass` global reference prevents the corresponding Java class from being unloaded. +#[derive(Debug)] +pub struct JClass { + class: jclass, + vm: VM, +} + +unsafe impl Send for JClass {} +unsafe impl Sync for JClass {} + +impl JClass { + /// Creates a `JClass` from an owned JNI local reference of a class object and *deletes* the + /// local reference. + /// + /// # Safety + /// + /// `class` must be a valid JNI local reference to a `java.lang.Class` object. + /// Do not use the passed `class` local reference after calling this function. + /// + /// It is safe to pass the returned value of JNI `FindClass` to it if no exeception occurred. + pub unsafe fn from_raw<'env>(env: Env<'env>, class: jclass) -> Self { + assert!(!class.is_null(), "from_raw jclass argument is null"); + let jnienv = env.as_raw(); + let class_global = unsafe { ((**jnienv).v1_2.NewGlobalRef)(jnienv, class) }; + unsafe { ((**jnienv).v1_2.DeleteLocalRef)(jnienv, class) } + unsafe { Self::from_raw_global(env.vm(), class_global) } + } + + /// Wraps an owned raw JNI global reference of a class object. + /// + /// # Safety + /// + /// `class` must be a valid JNI global reference to a `java.lang.Class` object. + pub unsafe fn from_raw_global(vm: VM, class: jobject) -> Self { + assert!(!class.is_null(), "from_raw_global jclass argument is null"); + Self { + class: class as jclass, + vm, + } + } + + pub fn as_raw(&self) -> jclass { + self.class + } +} + +impl Clone for JClass { + fn clone(&self) -> Self { + self.vm.with_env(|env| { + let env = env.as_raw(); + let class_global = unsafe { ((**env).v1_2.NewGlobalRef)(env, self.class) }; + assert!(!class_global.is_null()); + unsafe { Self::from_raw_global(self.vm, class_global) } + }) + } +} + +// XXX: Unfortunately, static items (e.g. `OnceLock`) may not call drop() at the end of the Rust program: +// JNI global references may be leaked if `java-spaghetti`-based libraries are unloaded and reloaded by the VM. +impl Drop for JClass { + fn drop(&mut self) { + self.vm.with_env(|env| { + let env = env.as_raw(); + unsafe { ((**env).v1_2.DeleteGlobalRef)(env, self.class) } + }); + } +} -#[doc(hidden)] +/// New type for `jfieldID`, implements `Send` and `Sync`. +/// +/// According to the JNI spec, field IDs may be invalidated when the corresponding class is unloaded: +/// . #[repr(transparent)] +#[derive(Clone, Copy, Debug)] pub struct JFieldID { internal: jfieldID, } @@ -20,13 +91,13 @@ unsafe impl Send for JFieldID {} unsafe impl Sync for JFieldID {} impl JFieldID { - /// Creates a [`JFieldID`] that wraps the given `raw` [`jfieldID`]. + /// Creates a [`JFieldID`] that wraps the given raw [`jfieldID`]. /// /// # Safety /// - /// Expects a valid, non-`null` ID. + /// Expects a valid, non-null ID. pub unsafe fn from_raw(raw: jfieldID) -> Self { - debug_assert!(!raw.is_null(), "from_raw fieldID argument"); + assert!(!raw.is_null(), "from_raw jfieldID argument is null"); Self { internal: raw } } @@ -35,8 +106,12 @@ impl JFieldID { } } -#[doc(hidden)] +/// New type for `jmethodID`, implements `Send` and `Sync`. +/// +/// According to the JNI spec, method IDs may be invalidated when the corresponding class is unloaded: +/// . #[repr(transparent)] +#[derive(Clone, Copy, Debug)] pub struct JMethodID { internal: jmethodID, } @@ -50,9 +125,9 @@ impl JMethodID { /// /// # Safety /// - /// Expects a valid, non-`null` ID. + /// Expects a valid, non-null ID. pub unsafe fn from_raw(raw: jmethodID) -> Self { - debug_assert!(!raw.is_null(), "from_raw methodID argument"); + assert!(!raw.is_null(), "from_raw jmethodID argument is null"); Self { internal: raw } } diff --git a/java-spaghetti/src/jni_type.rs b/java-spaghetti/src/jni_type.rs index 2de0a2a..a949ce3 100644 --- a/java-spaghetti/src/jni_type.rs +++ b/java-spaghetti/src/jni_type.rs @@ -1,71 +1,80 @@ -use std::ffi::CStr; +//! XXX: This type came from the original [jni-glue](https://docs.rs/jni-glue/0.0.10/src/jni_glue/jni_type.rs.html), +//! I'm not sure of its possible funcationality in the future, but it's currently preserved. +//! +//! Side note: While primitive array type signatures like c"[I" can be passed to the JNI `FindClass`, a primitive "class" +//! like `int.class` cannot be obtained by passing c"I" to `FindClass`. Primitive "classes" might be obtained from +//! [java.lang.reflect.Method](https://docs.oracle.com/javase/8/docs/api/java/lang/reflect/Method.html#getParameterTypes). + +use std::borrow::Cow; +use std::ffi::{CStr, CString}; use jni_sys::*; -/// JNI bindings rely on this type being accurate. -/// -/// # Safety -/// -/// **unsafe**: Passing the wrong type can cause unsoundness, since the code that interacts with JNI blindly trusts it's correct. -/// -/// Why the awkward callback style instead of returning `&'static CStr`? Arrays of arrays may need to dynamically -/// construct their type strings, which would need to leak. Worse, we can't easily intern those strings via -/// lazy_static without running into: -/// -/// ```text -/// error[E0401]: can't use generic parameters from outer function -/// ``` +use crate::ReferenceType; + +#[doc(hidden)] pub unsafe trait JniType { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R; + fn jni_type_name() -> Cow<'static, CStr>; } unsafe impl JniType for () { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"V") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"V") } } unsafe impl JniType for bool { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"Z") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"Z") } } unsafe impl JniType for jbyte { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"B") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"B") } } unsafe impl JniType for jchar { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"C") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"C") } } unsafe impl JniType for jshort { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"S") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"S") } } unsafe impl JniType for jint { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"I") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"I") } } unsafe impl JniType for jlong { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"J") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"J") } } unsafe impl JniType for jfloat { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"F") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"F") } } unsafe impl JniType for jdouble { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"D") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"D") } } unsafe impl JniType for &CStr { - fn static_with_jni_type(callback: impl FnOnce(&CStr) -> R) -> R { - callback(c"Ljava/lang/String;") + fn jni_type_name() -> Cow<'static, CStr> { + Cow::Borrowed(c"Ljava/lang/String;") + } +} + +unsafe impl JniType for T { + fn jni_type_name() -> Cow<'static, CStr> { + let type_name = Self::jni_reference_type_name(); + if type_name.to_bytes()[0] != b'[' { + Cow::Owned(CString::new(format!("L{};", type_name.to_string_lossy())).unwrap()) + } else { + type_name + } } } diff --git a/java-spaghetti/src/lib.rs b/java-spaghetti/src/lib.rs index 517fff4..61ddb7b 100644 --- a/java-spaghetti/src/lib.rs +++ b/java-spaghetti/src/lib.rs @@ -8,6 +8,8 @@ #![feature(arbitrary_self_types)] +use std::borrow::Cow; +use std::ffi::CStr; use std::fmt; /// public jni-sys reexport. @@ -47,7 +49,7 @@ pub use string_chars::*; pub use vm::*; /// Error returned on failed `.cast()`.` -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub struct CastError; impl std::error::Error for CastError {} @@ -57,13 +59,52 @@ impl fmt::Display for CastError { } } -/// A marker type indicating this is a valid exception type that all exceptions thrown by java should be compatible with +/// Error returned on failed [Env::require_class]. +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct ClassLoaderError(String); + +impl std::error::Error for ClassLoaderError {} +impl fmt::Display for ClassLoaderError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "ClassLoader failed: {}", &self.0) + } +} + +/// A marker type indicating this is a valid exception type that all exceptions thrown by Java should be compatible with. pub trait ThrowableType: ReferenceType {} +/// A marker type indicating this is a Java reference type. JNI bindings rely on this type being accurate. +/// /// You should generally not be interacting with this type directly, but it must be public for codegen. -#[doc(hidden)] -#[warn(clippy::missing_safety_doc)] -pub unsafe trait ReferenceType: JniType + Sized + 'static {} +/// +/// # Safety +/// +/// **unsafe**: Passing the wrong type name may be a soundness bug as although the Android JVM will simply panic and abort, +/// I have no idea if that is a guarantee or not. +pub unsafe trait ReferenceType: JniType + Sized + 'static { + /// Returns a string value compatible with JNI + /// [FindClass](https://docs.oracle.com/javase/8/docs/technotes/guides/jni/spec/functions.html#FindClass). + fn jni_reference_type_name() -> Cow<'static, CStr>; + + /// Returns the reference to the `OnceLock` dedicated to this reference type. + /// + /// This should be initialized manually if the class is loaded dynamically with `dalvik.system.DexClassLoader`. + /// + /// # Safety + /// + /// It must be initialized by the class with the binary name returned by `jni_reference_type_name()`. + unsafe fn jni_class_cache_once_lock() -> &'static std::sync::OnceLock; + + /// Returns a cached `JClass` of the class object for this reference type. + fn jni_get_class<'env>(env: Env<'env>) -> Result<&'static JClass, ClassLoaderError> { + let once_lock = unsafe { Self::jni_class_cache_once_lock() }; + if let Some(cls) = once_lock.get() { + return Ok(cls); + } + let required = unsafe { env.require_class(&Self::jni_reference_type_name()) }?; + Ok(once_lock.get_or_init(|| required)) + } +} /// Marker trait indicating `Self` can be assigned to `T`. /// diff --git a/java-spaghetti/src/refs/ref_.rs b/java-spaghetti/src/refs/ref_.rs index 3744138..029d024 100644 --- a/java-spaghetti/src/refs/ref_.rs +++ b/java-spaghetti/src/refs/ref_.rs @@ -87,11 +87,8 @@ impl<'env, T: ReferenceType> Ref<'env, T> { pub(crate) fn check_assignable(&self) -> Result<(), crate::CastError> { let env = self.env(); let jnienv = env.as_raw(); - let class = U::static_with_jni_type(|t| unsafe { env.require_class(t) }); + let class = U::jni_get_class(env).unwrap().as_raw(); let assignable = unsafe { ((**jnienv).v1_2.IsInstanceOf)(jnienv, self.as_raw(), class) }; - unsafe { - ((**jnienv).v1_2.DeleteLocalRef)(jnienv, class); - } if assignable { Ok(()) } else { Err(crate::CastError) } } From ea16d52d15f7268394bdee6457ce3afda6bfdab3 Mon Sep 17 00:00:00 2001 From: wuwbobo2021 Date: Thu, 17 Jul 2025 04:41:28 +0800 Subject: [PATCH 4/5] Make proxy generation working again --- java-spaghetti-gen/src/emit/class_proxy.rs | 72 +++++++++++++++------- java-spaghetti-gen/src/emit/classes.rs | 4 +- java-spaghetti/src/env.rs | 48 +++++++++++++-- java-spaghetti/src/id_cache.rs | 8 +++ 4 files changed, 104 insertions(+), 28 deletions(-) diff --git a/java-spaghetti-gen/src/emit/class_proxy.rs b/java-spaghetti-gen/src/emit/class_proxy.rs index 4641445..2d6574f 100644 --- a/java-spaghetti-gen/src/emit/class_proxy.rs +++ b/java-spaghetti-gen/src/emit/class_proxy.rs @@ -10,7 +10,6 @@ use super::fields::RustTypeFlavor; use super::methods::Method; use crate::emit::Context; use crate::emit::fields::emit_type; -use crate::parser_util::Id; impl Class { #[allow(clippy::vec_init_then_push)] @@ -22,9 +21,6 @@ impl Class { let rust_name = format_ident!("{}", &self.rust.struct_name); - let object = context - .java_to_rust_path(Id("java/lang/Object"), &self.rust.mod_) - .unwrap(); let throwable = context.throwable_rust_path(&self.rust.mod_); let rust_proxy_name = format_ident!("{}Proxy", &self.rust.struct_name); @@ -36,6 +32,7 @@ impl Class { self.java.path().as_str().replace("$", "_") ); + let mut native_regs = Vec::new(); for method in methods { let Some(rust_name) = method.rust_name() else { continue }; if method.java.is_static() @@ -61,6 +58,14 @@ impl Class { let native_name = format_ident!("{native_name}"); let rust_name = format_ident!("{rust_name}"); + let mut native_method_desc = method.java.descriptor().to_string(); + native_method_desc.insert(1, 'J'); + native_regs.push(( + cstring(&format!("native_{}", method.java.name())), + cstring(&native_method_desc), + native_name.clone(), + )); + let ret = match &method.java.descriptor.return_type { ReturnDescriptor::Void => quote!(()), ReturnDescriptor::Return(desc) => emit_type( @@ -157,34 +162,59 @@ impl Class { pub fn new_proxy<'env>( env: ::java_spaghetti::Env<'env>, proxy: ::std::sync::Arc, + proxy_class: ::std::option::Option<::java_spaghetti::JClass>, ) -> Result<::java_spaghetti::Local<'env, Self>, ::java_spaghetti::Local<'env, #throwable>> { - static __CLASS: ::std::sync::OnceLock<::java_spaghetti::Global<#object>> = - ::std::sync::OnceLock::new(); + static __CLASS: ::std::sync::OnceLock<::java_spaghetti::JClass> = ::std::sync::OnceLock::new(); let __jni_class = __CLASS .get_or_init(|| unsafe { - ::java_spaghetti::Local::from_raw(env, env.require_class(#java_proxy_path),) - .as_global() - }) - .as_raw(); + let required = env.require_class(#java_proxy_path); + if let Ok(proxy_class) = required { + proxy_class + } else if let Some(proxy_class) = proxy_class { + let bin_name = env.get_class_name(&proxy_class).replace('.', "/"); + let expected = #java_proxy_path.to_string_lossy(); + if bin_name != expected { + panic!("wrong proxy_class, expected: {}, provided: {}", expected, bin_name) + } + Self::register_proxy_methods(env, &proxy_class); + proxy_class + } else { + panic!("{}", required.unwrap_err()) + } + }); let b = ::std::boxed::Box::new(proxy); let ptr = ::std::boxed::Box::into_raw(b); static __METHOD: ::std::sync::OnceLock<::java_spaghetti::JMethodID> = ::std::sync::OnceLock::new(); unsafe { - let __jni_args = [::java_spaghetti::sys::jvalue { + let __jni_args = &[::java_spaghetti::sys::jvalue { j: ptr.expose_provenance() as i64, }]; - let __jni_method = __METHOD - .get_or_init(|| { - ::java_spaghetti::JMethodID::from_raw(env.require_method( - __jni_class, - c"", - c"(J)V", - )) - }) - .as_raw(); - env.new_object_a(__jni_class, __jni_method, __jni_args.as_ptr()) + let __jni_method = *__METHOD.get_or_init(|| env.require_method(__jni_class, c"", c"(J)V")); + env.new_object_a(__jni_class, __jni_method, __jni_args) + } + } + )); + + let mut register_calls = TokenStream::new(); + for (native_method_name, descriptor, extern_name) in native_regs { + register_calls.extend(quote!( + { + let method_name = #native_method_name; + let descriptor = #descriptor; + let fn_ptr = #extern_name as *mut _; + let _ = env.register_native_method(proxy_class, method_name, descriptor, fn_ptr); + } + )); + } + contents.extend(quote!( + fn register_proxy_methods<'env>( + env: ::java_spaghetti::Env<'env>, + proxy_class: &::java_spaghetti::JClass, + ) { + unsafe { + #register_calls } } )); diff --git a/java-spaghetti-gen/src/emit/classes.rs b/java-spaghetti-gen/src/emit/classes.rs index daf9565..571f952 100644 --- a/java-spaghetti-gen/src/emit/classes.rs +++ b/java-spaghetti-gen/src/emit/classes.rs @@ -117,8 +117,8 @@ impl Class { ::std::borrow::Cow::Borrowed(#java_path) } unsafe fn jni_class_cache_once_lock() -> &'static ::std::sync::OnceLock<::java_spaghetti::JClass> { - static CLASS_CACHE: ::std::sync::OnceLock<::java_spaghetti::JClass> = ::std::sync::OnceLock::new(); - &CLASS_CACHE + static __CLASS: ::std::sync::OnceLock<::java_spaghetti::JClass> = ::std::sync::OnceLock::new(); + &__CLASS } } )); diff --git a/java-spaghetti/src/env.rs b/java-spaghetti/src/env.rs index 98ab11e..111ab02 100644 --- a/java-spaghetti/src/env.rs +++ b/java-spaghetti/src/env.rs @@ -1,4 +1,4 @@ -use std::ffi::CStr; +use std::ffi::{CStr, c_char, c_void}; use std::marker::PhantomData; use std::ptr::{self, null_mut}; use std::sync::OnceLock; @@ -172,7 +172,7 @@ impl<'env> Env<'env> { assert_eq!(res, 0); } - unsafe fn raw_exception_to_string(self, exception: jobject) -> String { + pub(crate) unsafe fn raw_exception_to_string(self, exception: jobject) -> String { static METHOD_GET_MESSAGE: OnceLock = OnceLock::new(); let throwable_get_message = METHOD_GET_MESSAGE.get_or_init(|| { // use JNI FindClass to avoid infinte recursion. @@ -241,7 +241,7 @@ impl<'env> Env<'env> { Ok(JClass::from_raw(self, result as jclass)) } - unsafe fn require_class_jni(self, class: &CStr) -> Result { + pub(crate) unsafe fn require_class_jni(self, class: &CStr) -> Result { // Note: the returned `cls` is actually a new local reference of the class object. let cls = ((**self.env).v1_2.FindClass)(self.env, class.as_ptr()); if let Err(exception) = self.exception_check_raw() { @@ -261,8 +261,8 @@ impl<'env> Env<'env> { Ok(JClass::from_raw(self, cls)) } - // used only for debugging - unsafe fn get_class_name(self, class: &JClass) -> String { + /// Gets the binary name (not internal form) of the class with `Class.getName()`. Returns "??? (details)" on error. + pub unsafe fn get_class_name(self, class: &JClass) -> String { static METHOD_GET_NAME: OnceLock = OnceLock::new(); let method = METHOD_GET_NAME.get_or_init(|| { // don't use `self.require_method()` here to avoid recursion! @@ -286,6 +286,44 @@ impl<'env> Env<'env> { string } + /// Binds the function pointer to the native method of `class` according to method name and signature. + /// Returns `false` if the method is not found or the JNI `RegisterNatives` returns a negative value. + /// + /// # Safety + /// + /// The native method pointer must be a valid, non-null pointer to a function that match the signature + /// of the corresponding Java method. + pub unsafe fn register_native_method( + &self, + class: &JClass, + method: &CStr, + descriptor: &CStr, + fn_ptr: *mut c_void, + ) -> bool { + // `RegisterNatives` shouldn't modify `name` and `signature`, but still clone them. + let (method, descriptor) = (method.to_owned(), descriptor.to_owned()); + let mut native_methods = [JNINativeMethod { + name: method.as_ptr() as *mut c_char, + signature: descriptor.as_ptr() as *mut c_char, + fnPtr: fn_ptr, + }]; + let jnienv = self.as_raw(); + let res = ((**jnienv).v1_2.RegisterNatives)(jnienv, class.as_raw(), native_methods.as_mut_ptr(), 1); + + if let Err(exception) = self.exception_check_raw() { + eprintln!( + "exception happened calling JNI RegisterNatives: {}", + self.raw_exception_to_string(exception) + ); + ((**jnienv).v1_2.DeleteLocalRef)(jnienv, exception); + return false; + } else if res < 0 { + eprintln!("JNI RegisterNatives failed: returned value is {res}"); + return false; + } + true + } + pub unsafe fn require_method(self, class: &JClass, method: &CStr, descriptor: &CStr) -> JMethodID { let res = ((**self.env).v1_2.GetMethodID)(self.env, class.as_raw(), method.as_ptr(), descriptor.as_ptr()); if res.is_null() { diff --git a/java-spaghetti/src/id_cache.rs b/java-spaghetti/src/id_cache.rs index e83d327..24d29d3 100644 --- a/java-spaghetti/src/id_cache.rs +++ b/java-spaghetti/src/id_cache.rs @@ -49,9 +49,17 @@ impl JClass { } } + /// Returns the raw JNI reference pointer. pub fn as_raw(&self) -> jclass { self.class } + + /// Turns it into a raw global reference; prevents `DeleteGlobalRef` from being called on dropping. + pub fn into_raw(self) -> jclass { + let class = self.class; + std::mem::forget(self); // Don't delete the object. + class + } } impl Clone for JClass { From f91a7ec1c1328fb2d08ff9a1b59e3c6920b0340b Mon Sep 17 00:00:00 2001 From: wuwbobo2021 Date: Sat, 19 Jul 2025 00:58:26 +0800 Subject: [PATCH 5/5] Add safe `JClass::from_ref`; Close #13 --- java-spaghetti-gen/src/emit/class_proxy.rs | 10 ++++-- java-spaghetti-gen/src/emit/methods.rs | 11 ++++--- java-spaghetti/src/env.rs | 36 +++++++++++++++++----- java-spaghetti/src/id_cache.rs | 27 ++++++++++++++-- 4 files changed, 68 insertions(+), 16 deletions(-) diff --git a/java-spaghetti-gen/src/emit/class_proxy.rs b/java-spaghetti-gen/src/emit/class_proxy.rs index 2d6574f..bbc120f 100644 --- a/java-spaghetti-gen/src/emit/class_proxy.rs +++ b/java-spaghetti-gen/src/emit/class_proxy.rs @@ -151,13 +151,14 @@ impl Class { _class: *mut (), // self class, ignore ptr: i64, ) { - let ptr: *mut std::sync::Arc = ::std::ptr::with_exposed_provenance_mut(ptr as usize); + let ptr: *mut ::std::sync::Arc = ::std::ptr::with_exposed_provenance_mut(ptr as usize); let _ = unsafe { Box::from_raw(ptr) }; } )); let java_proxy_path = cstring(&java_proxy_path); + // XXX: use `OnceLock::get_or_try_init` for `__METHOD` when it becomes stable. contents.extend(quote!( pub fn new_proxy<'env>( env: ::java_spaghetti::Env<'env>, @@ -191,7 +192,12 @@ impl Class { let __jni_args = &[::java_spaghetti::sys::jvalue { j: ptr.expose_provenance() as i64, }]; - let __jni_method = *__METHOD.get_or_init(|| env.require_method(__jni_class, c"", c"(J)V")); + let __jni_method = if let Some(&__jni_method) = __METHOD.get() { + __jni_method + } else { + let __jni_method = env.require_method(__jni_class, c"", c"(J)V")?; + *__METHOD.get_or_init(|| __jni_method) + }; env.new_object_a(__jni_class, __jni_method, __jni_args) } } diff --git a/java-spaghetti-gen/src/emit/methods.rs b/java-spaghetti-gen/src/emit/methods.rs index 2b090b4..d0d423c 100644 --- a/java-spaghetti-gen/src/emit/methods.rs +++ b/java-spaghetti-gen/src/emit/methods.rs @@ -150,6 +150,7 @@ impl<'a> Method<'a> { quote!( __jni_env.#call(self, __jni_method, __jni_args)) }; + // XXX: use `OnceLock::get_or_try_init` when it becomes stable. out.extend(quote!( #[doc = #docs] #attributes @@ -160,10 +161,12 @@ impl<'a> Method<'a> { let __jni_args = &[#params_array]; #env_let let __jni_class = Self::jni_get_class(__jni_env).unwrap(); - let __jni_method = *__METHOD.get_or_init(|| - __jni_env.#require_method(__jni_class, #java_name, #descriptor) - ); - + let __jni_method = if let Some(&__jni_method) = __METHOD.get() { + __jni_method + } else { + let __jni_method = __jni_env.#require_method(__jni_class, #java_name, #descriptor)?; + *__METHOD.get_or_init(|| __jni_method) + }; #call } } diff --git a/java-spaghetti/src/env.rs b/java-spaghetti/src/env.rs index 111ab02..2b5ea34 100644 --- a/java-spaghetti/src/env.rs +++ b/java-spaghetti/src/env.rs @@ -177,7 +177,7 @@ impl<'env> Env<'env> { let throwable_get_message = METHOD_GET_MESSAGE.get_or_init(|| { // use JNI FindClass to avoid infinte recursion. let throwable_class = self.require_class_jni(c"java/lang/Throwable").unwrap(); - self.require_method(&throwable_class, c"getMessage", c"()Ljava/lang/String;") + self.require_method_forced(&throwable_class, c"getMessage", c"()Ljava/lang/String;") }); let message = @@ -215,7 +215,7 @@ impl<'env> Env<'env> { // We still use JNI FindClass for this, to avoid a chicken-and-egg situation. // If the system class loader cannot find java.lang.ClassLoader, things are pretty broken! let cl_class = self.require_class_jni(c"java/lang/ClassLoader").unwrap(); - self.require_method(&cl_class, c"loadClass", c"(Ljava/lang/String;)Ljava/lang/Class;") + self.require_method_forced(&cl_class, c"loadClass", c"(Ljava/lang/String;)Ljava/lang/Class;") }); let args = [jvalue { l: string }]; @@ -265,7 +265,7 @@ impl<'env> Env<'env> { pub unsafe fn get_class_name(self, class: &JClass) -> String { static METHOD_GET_NAME: OnceLock = OnceLock::new(); let method = METHOD_GET_NAME.get_or_init(|| { - // don't use `self.require_method()` here to avoid recursion! + // don't use `self.require_method_forced()` here to avoid recursion! let class_class = self.require_class_jni(c"java/lang/Class").unwrap(); let method_raw = ((**self.env).v1_2.GetMethodID)( self.env, @@ -324,24 +324,44 @@ impl<'env> Env<'env> { true } - pub unsafe fn require_method(self, class: &JClass, method: &CStr, descriptor: &CStr) -> JMethodID { + pub(crate) unsafe fn require_method_forced(self, class: &JClass, method: &CStr, descriptor: &CStr) -> JMethodID { let res = ((**self.env).v1_2.GetMethodID)(self.env, class.as_raw(), method.as_ptr(), descriptor.as_ptr()); + let _ = self.exception_check_raw(); if res.is_null() { - ((**self.env).v1_2.ExceptionClear)(self.env); let class_name = self.get_class_name(class); panic!("could not find method {method:?} {descriptor:?} on class {class_name:?}"); } JMethodID::from_raw(res) } - pub unsafe fn require_static_method(self, class: &JClass, method: &CStr, descriptor: &CStr) -> JMethodID { + pub unsafe fn require_method( + self, + class: &JClass, + method: &CStr, + descriptor: &CStr, + ) -> Result> { + let res = ((**self.env).v1_2.GetMethodID)(self.env, class.as_raw(), method.as_ptr(), descriptor.as_ptr()); + self.exception_check()?; + if res.is_null() { + let class_name = self.get_class_name(class); + panic!("could not find method {method:?} {descriptor:?} on class {class_name:?}"); + } + Ok(JMethodID::from_raw(res)) + } + + pub unsafe fn require_static_method( + self, + class: &JClass, + method: &CStr, + descriptor: &CStr, + ) -> Result> { let res = ((**self.env).v1_2.GetStaticMethodID)(self.env, class.as_raw(), method.as_ptr(), descriptor.as_ptr()); + self.exception_check()?; if res.is_null() { - ((**self.env).v1_2.ExceptionClear)(self.env); let class_name = self.get_class_name(class); panic!("could not find static method {method:?} {descriptor:?} on class {class_name:?}"); } - JMethodID::from_raw(res) + Ok(JMethodID::from_raw(res)) } pub unsafe fn require_field(self, class: &JClass, field: &CStr, descriptor: &CStr) -> JFieldID { diff --git a/java-spaghetti/src/id_cache.rs b/java-spaghetti/src/id_cache.rs index 24d29d3..ea15834 100644 --- a/java-spaghetti/src/id_cache.rs +++ b/java-spaghetti/src/id_cache.rs @@ -3,8 +3,10 @@ //! //! Inspired by: . +use std::sync::OnceLock; + use crate::sys::{jclass, jfieldID, jmethodID, jobject}; -use crate::{Env, VM}; +use crate::{Env, Ref, ReferenceType, VM}; /// New type for cached class objects as JNI global references. /// @@ -19,6 +21,28 @@ unsafe impl Send for JClass {} unsafe impl Sync for JClass {} impl JClass { + /// Creates a `JClass` from a JNI reference of a class object. Returns `None` if it is not a class object. + pub fn from_ref<'env, T: ReferenceType>(class: &Ref<'env, T>) -> Option { + let env = class.env(); + + static METHOD_GET_CLASS: OnceLock = OnceLock::new(); + let method = METHOD_GET_CLASS.get_or_init(|| unsafe { + let obj_class = env.require_class_jni(c"java/lang/Object").unwrap(); + env.require_method_forced(&obj_class, c"getClass", c"()Ljava/lang/Class;") + }); + + let jnienv = env.as_raw(); + let class_class = unsafe { ((**jnienv).v1_2.CallObjectMethod)(jnienv, class.as_raw(), method.as_raw()) }; + if env.exception_check_raw().is_err() || class_class.is_null() { + return None; + } + unsafe { + let class_class = Self::from_raw(env, class_class); + let name = env.get_class_name(&class_class); + (name.trim() == "java.lang.Class").then_some(Self::from_raw_global(env.vm(), class.as_global().into_raw())) + } + } + /// Creates a `JClass` from an owned JNI local reference of a class object and *deletes* the /// local reference. /// @@ -48,7 +72,6 @@ impl JClass { vm, } } - /// Returns the raw JNI reference pointer. pub fn as_raw(&self) -> jclass { self.class