@@ -544,3 +544,280 @@ pub fn log_every_n_sec(input: TokenStream) -> TokenStream {
544544
545545 TokenStream :: from ( expanded)
546546}
547+
548+ // Constant to control the default time limit
549+ const DEFAULT_TIME_LIMIT_MS : u64 = 200 ;
550+
551+ /// A timed version of the `#[test]` attribute that fails if the test exceeds a time limit.
552+ ///
553+ /// # Usage
554+ ///
555+ /// ```rust,ignore
556+ /// #[timed_test]
557+ /// fn my_test() {
558+ /// // test code
559+ /// }
560+ ///
561+ /// // With custom time limit (in milliseconds)
562+ /// #[timed_test(500)]
563+ /// fn my_slower_test() {
564+ /// // test code
565+ /// }
566+ /// ```
567+ ///
568+ /// The default time limit is 200ms. If the test exceeds this limit, it will panic with a message
569+ /// indicating how long it took.
570+ #[ proc_macro_attribute]
571+ pub fn timed_test ( attr : TokenStream , input : TokenStream ) -> TokenStream {
572+ let time_limit_ms = parse_time_limit ( attr, DEFAULT_TIME_LIMIT_MS ) ;
573+ let input_fn = parse_macro_input ! ( input as ItemFn ) ;
574+
575+ let fn_name = & input_fn. sig . ident ;
576+ let fn_attrs = & input_fn. attrs ;
577+ let fn_vis = & input_fn. vis ;
578+ let fn_sig = & input_fn. sig ;
579+ let fn_block = & input_fn. block ;
580+
581+ let expanded = quote ! {
582+ #[ test]
583+ #( #fn_attrs) *
584+ #fn_vis #fn_sig {
585+ let start = std:: time:: Instant :: now( ) ;
586+ #fn_block
587+ let elapsed = start. elapsed( ) ;
588+ let elapsed_ms = elapsed. as_millis( ) as u64 ;
589+ if elapsed_ms > #time_limit_ms {
590+ panic!(
591+ "Test `{}` exceeded time limit: took {}ms, limit is {}ms" ,
592+ stringify!( #fn_name) ,
593+ elapsed_ms,
594+ #time_limit_ms
595+ ) ;
596+ }
597+ }
598+ } ;
599+
600+ TokenStream :: from ( expanded)
601+ }
602+
603+ /// A timed version of the `#[tokio::test]` attribute that fails if the test exceeds a time limit.
604+ ///
605+ /// # Usage
606+ ///
607+ /// ```rust,ignore
608+ /// #[timed_tokio_test]
609+ /// async fn my_test() {
610+ /// // test code
611+ /// }
612+ ///
613+ /// // With custom time limit (in milliseconds)
614+ /// #[timed_tokio_test(1000)]
615+ /// async fn my_slower_test() {
616+ /// // test code
617+ /// }
618+ /// ```
619+ ///
620+ /// The default time limit is 200ms. If the test exceeds this limit, it will panic with a message
621+ /// indicating how long it took.
622+ #[ proc_macro_attribute]
623+ pub fn timed_tokio_test ( attr : TokenStream , input : TokenStream ) -> TokenStream {
624+ let time_limit_ms = parse_time_limit ( attr, DEFAULT_TIME_LIMIT_MS ) ;
625+ let input_fn = parse_macro_input ! ( input as ItemFn ) ;
626+
627+ let fn_name = & input_fn. sig . ident ;
628+ let fn_attrs = & input_fn. attrs ;
629+ let fn_vis = & input_fn. vis ;
630+ let fn_sig = & input_fn. sig ;
631+ let fn_block = & input_fn. block ;
632+
633+ let time_limit_duration = quote ! {
634+ std:: time:: Duration :: from_millis( #time_limit_ms)
635+ } ;
636+
637+ let expanded = quote ! {
638+ #[ tokio:: test]
639+ #( #fn_attrs) *
640+ #fn_vis #fn_sig {
641+ let start = std:: time:: Instant :: now( ) ;
642+ let test_future = async move {
643+ #fn_block
644+ } ;
645+ match tokio:: time:: timeout( #time_limit_duration, test_future) . await {
646+ Ok ( _) => {
647+ let elapsed = start. elapsed( ) ;
648+ let elapsed_ms = elapsed. as_millis( ) as u64 ;
649+ // Even if we completed within timeout, check if we exceeded the limit
650+ // (timeout might have been slightly longer due to scheduling)
651+ if elapsed_ms > #time_limit_ms {
652+ panic!(
653+ "Test `{}` exceeded time limit: took {}ms, limit is {}ms" ,
654+ stringify!( #fn_name) ,
655+ elapsed_ms,
656+ #time_limit_ms
657+ ) ;
658+ }
659+ }
660+ Err ( _) => {
661+ let elapsed = start. elapsed( ) ;
662+ let elapsed_ms = elapsed. as_millis( ) as u64 ;
663+ panic!(
664+ "Test `{}` exceeded time limit: took {}ms, limit is {}ms" ,
665+ stringify!( #fn_name) ,
666+ elapsed_ms,
667+ #time_limit_ms
668+ ) ;
669+ }
670+ }
671+ }
672+ } ;
673+
674+ TokenStream :: from ( expanded)
675+ }
676+
677+ /// A timed version of the `#[rstest]` attribute that fails if the test exceeds a time limit.
678+ ///
679+ /// # Usage
680+ ///
681+ /// ```rust,ignore
682+ /// #[timed_rstest]
683+ /// fn my_test(param: u32) {
684+ /// // test code
685+ /// }
686+ ///
687+ /// // With custom time limit (in milliseconds)
688+ /// #[timed_rstest(500)]
689+ /// fn my_slower_test(param: u32) {
690+ /// // test code
691+ /// }
692+ ///
693+ /// // Can be combined with tokio::test for async tests
694+ /// #[timed_rstest]
695+ /// #[tokio::test]
696+ /// async fn my_async_test(param: u32) {
697+ /// // test code
698+ /// }
699+ /// ```
700+ ///
701+ /// The default time limit is 200ms. If the test exceeds this limit, it will panic with a message
702+ /// indicating how long it took.
703+ ///
704+ /// Note: When combined with `#[tokio::test]`, the timing will work correctly for async tests.
705+ #[ proc_macro_attribute]
706+ pub fn timed_rstest ( attr : TokenStream , input : TokenStream ) -> TokenStream {
707+ let time_limit_ms = parse_time_limit ( attr, DEFAULT_TIME_LIMIT_MS ) ;
708+ let input_fn = parse_macro_input ! ( input as ItemFn ) ;
709+
710+ let fn_name = & input_fn. sig . ident ;
711+ let fn_attrs = & input_fn. attrs ;
712+ let fn_vis = & input_fn. vis ;
713+ let fn_sig = & input_fn. sig ;
714+ let fn_block = & input_fn. block ;
715+
716+ // Check if this is an async function by looking at the signature
717+ let is_async = fn_sig. asyncness . is_some ( ) ;
718+
719+ // Check if tokio::test is in the attributes
720+ // #[tokio::test] has path segments: ["tokio", "test"]
721+ let has_tokio_test = fn_attrs. iter ( ) . any ( |attr| {
722+ let path = attr. path ( ) ;
723+ if path. segments . len ( ) == 2 {
724+ path. segments [ 0 ] . ident == "tokio" && path. segments [ 1 ] . ident == "test"
725+ } else {
726+ false
727+ }
728+ } ) ;
729+
730+ let is_async_test = is_async || has_tokio_test;
731+
732+ if is_async_test {
733+ let time_limit_duration = quote ! {
734+ std:: time:: Duration :: from_millis( #time_limit_ms)
735+ } ;
736+
737+ // Separate tokio::test from other attributes
738+ let ( tokio_test_attrs, other_attrs) : ( Vec < _ > , Vec < _ > ) = fn_attrs. iter ( ) . partition ( |attr| {
739+ let path = attr. path ( ) ;
740+ path. segments . len ( ) == 2
741+ && path. segments [ 0 ] . ident == "tokio"
742+ && path. segments [ 1 ] . ident == "test"
743+ } ) ;
744+
745+ let expanded = quote ! {
746+ #( #tokio_test_attrs) *
747+ #[ :: rstest:: rstest]
748+ #( #other_attrs) *
749+ #fn_vis #fn_sig {
750+ let start = std:: time:: Instant :: now( ) ;
751+ let result = tokio:: time:: timeout( #time_limit_duration, async {
752+ #fn_block
753+ } ) . await ;
754+ match result {
755+ Ok ( _) => {
756+ let elapsed = start. elapsed( ) ;
757+ let elapsed_ms = elapsed. as_millis( ) as u64 ;
758+ if elapsed_ms > #time_limit_ms {
759+ panic!(
760+ "Test `{}` exceeded time limit: took {}ms, limit is {}ms" ,
761+ stringify!( #fn_name) ,
762+ elapsed_ms,
763+ #time_limit_ms
764+ ) ;
765+ }
766+ }
767+ Err ( _) => {
768+ let elapsed = start. elapsed( ) ;
769+ let elapsed_ms = elapsed. as_millis( ) as u64 ;
770+ panic!(
771+ "Test `{}` exceeded time limit: took {}ms, limit is {}ms" ,
772+ stringify!( #fn_name) ,
773+ elapsed_ms,
774+ #time_limit_ms
775+ ) ;
776+ }
777+ }
778+ }
779+ } ;
780+
781+ TokenStream :: from ( expanded)
782+ } else {
783+ let expanded = quote ! {
784+ #[ :: rstest:: rstest]
785+ #( #fn_attrs) *
786+ #fn_vis #fn_sig {
787+ let start = std:: time:: Instant :: now( ) ;
788+ #fn_block
789+ let elapsed = start. elapsed( ) ;
790+ let elapsed_ms = elapsed. as_millis( ) as u64 ;
791+ if elapsed_ms > #time_limit_ms {
792+ panic!(
793+ "Test `{}` exceeded time limit: took {}ms, limit is {}ms" ,
794+ stringify!( #fn_name) ,
795+ elapsed_ms,
796+ #time_limit_ms
797+ ) ;
798+ }
799+ }
800+ } ;
801+
802+ TokenStream :: from ( expanded)
803+ }
804+ }
805+
806+ /// Helper function to parse the time limit from the attribute.
807+ /// Returns the default if the attribute is empty.
808+ fn parse_time_limit ( attr : TokenStream , default : u64 ) -> u64 {
809+ if attr. is_empty ( ) {
810+ return default;
811+ }
812+
813+ let attr_str = attr. to_string ( ) ;
814+ let trimmed = attr_str. trim ( ) ;
815+
816+ if trimmed. is_empty ( ) {
817+ return default;
818+ }
819+
820+ trimmed. parse :: < u64 > ( ) . unwrap_or_else ( |_| {
821+ panic ! ( "Expected a positive integer for time limit in milliseconds, got: {}" , trimmed)
822+ } )
823+ }
0 commit comments