@@ -2,7 +2,6 @@ use std::sync::{Arc, Condvar, Mutex};
22use tonic:: { Code , Request , Response , Status } ;
33use triggered:: Trigger ;
44
5- use crate :: extended_appointment:: UUID ;
65use crate :: protos as msgs;
76use crate :: protos:: private_tower_services_server:: PrivateTowerServices ;
87use crate :: protos:: public_tower_services_server:: PublicTowerServices ;
@@ -280,27 +279,40 @@ impl PrivateTowerServices for Arc<InternalAPI> {
280279 . map_or( "an unknown address" . to_owned( ) , |a| a. to_string( ) )
281280 ) ;
282281
283- let mut matching_appointments = vec ! [ ] ;
284- let locator = Locator :: from_slice ( & request . into_inner ( ) . locator ) . map_err ( |_| {
282+ let req_data = request . into_inner ( ) ;
283+ let locator = Locator :: from_slice ( & req_data . locator ) . map_err ( |_| {
285284 Status :: new (
286285 Code :: InvalidArgument ,
287286 "The provided locator does not match the expected format (16-byte hexadecimal string)" ,
288287 )
289288 } ) ?;
290289
291- for ( _ , appointment ) in self
290+ let mut appointments : Vec < ( UserId , Appointment ) > = self
292291 . watcher
293292 . get_watcher_appointments_with_locator ( locator)
293+ . into_values ( )
294+ . map ( |appointment| ( appointment. user_id , appointment. inner ) )
295+ . collect ( ) ;
296+
297+ let user_id_slice = req_data. user_id ;
298+ if !( & user_id_slice. is_empty ( ) ) {
299+ let user_id = UserId :: from_slice ( & user_id_slice) . map_err ( |_| {
300+ Status :: new (
301+ Code :: InvalidArgument ,
302+ "The Provided user_id does not match expected format (33-byte hex string)" ,
303+ )
304+ } ) ?;
305+ appointments. retain ( |( appointment_user_id, _) | * appointment_user_id == user_id) ;
306+ }
307+
308+ let mut matching_appointments: Vec < common_msgs:: AppointmentData > = appointments
294309 . into_iter ( )
295- {
296- matching_appointments. push ( common_msgs:: AppointmentData {
310+ . map ( |( _, appointment) | common_msgs:: AppointmentData {
297311 appointment_data : Some (
298- common_msgs:: appointment_data:: AppointmentData :: Appointment (
299- appointment. inner . into ( ) ,
300- ) ,
312+ common_msgs:: appointment_data:: AppointmentData :: Appointment ( appointment. into ( ) ) ,
301313 ) ,
302314 } )
303- }
315+ . collect ( ) ;
304316
305317 for ( _, tracker) in self
306318 . watcher
@@ -390,7 +402,6 @@ impl PrivateTowerServices for Arc<InternalAPI> {
390402 Some ( ( info, locators) ) => Ok ( Response :: new ( msgs:: GetUserResponse {
391403 available_slots : info. available_slots ,
392404 subscription_expiry : info. subscription_expiry ,
393- // TODO: Should make `get_appointments` queryable using the (user_id, locator) pair for consistency.
394405 appointments : locators
395406 . into_iter ( )
396407 . map ( |locator| locator. to_vec ( ) )
@@ -511,14 +522,97 @@ mod tests_private_api {
511522
512523 let locator = Locator :: new ( get_random_tx ( ) . txid ( ) ) . to_vec ( ) ;
513524 let response = internal_api
514- . get_appointments ( Request :: new ( msgs:: GetAppointmentsRequest { locator } ) )
525+ . get_appointments ( Request :: new ( msgs:: GetAppointmentsRequest {
526+ locator,
527+ user_id : Vec :: new ( ) ,
528+ } ) )
515529 . await
516530 . unwrap ( )
517531 . into_inner ( ) ;
518532
519533 assert ! ( matches!( response, msgs:: GetAppointmentsResponse { .. } ) ) ;
520534 }
521535
536+ #[ tokio:: test]
537+ async fn test_get_appointments_with_and_without_user_id ( ) {
538+ // setup
539+ let ( internal_api, _s) = create_api ( ) . await ;
540+ let random_txn = get_random_tx ( ) ;
541+ let ( user_sk1, user_pk1) = get_random_keypair ( ) ;
542+ let user_id1 = UserId ( user_pk1) ;
543+ let ( user_sk2, user_pk2) = get_random_keypair ( ) ;
544+ let user_id2 = UserId ( user_pk2) ;
545+ internal_api. watcher . register ( user_id1) . unwrap ( ) ;
546+ internal_api. watcher . register ( user_id2) . unwrap ( ) ;
547+ let appointment1 =
548+ generate_dummy_appointment_with_user ( user_id1, Some ( & random_txn. clone ( ) . txid ( ) ) )
549+ . 1
550+ . inner ;
551+ let signature1 = cryptography:: sign ( & appointment1. to_vec ( ) , & user_sk1) . unwrap ( ) ;
552+ let appointment2 =
553+ generate_dummy_appointment_with_user ( user_id2, Some ( & random_txn. clone ( ) . txid ( ) ) )
554+ . 1
555+ . inner ;
556+ let signature2 = cryptography:: sign ( & appointment2. to_vec ( ) , & user_sk2) . unwrap ( ) ;
557+ internal_api
558+ . watcher
559+ . add_appointment ( appointment1. clone ( ) , signature1)
560+ . unwrap ( ) ;
561+ internal_api
562+ . watcher
563+ . add_appointment ( appointment2. clone ( ) , signature2)
564+ . unwrap ( ) ;
565+
566+ let locator = & appointment1. locator ;
567+
568+ // returns all appointments if user_id is absent
569+ let response = internal_api
570+ . get_appointments ( Request :: new ( msgs:: GetAppointmentsRequest {
571+ locator : locator. clone ( ) . to_vec ( ) ,
572+ user_id : Vec :: new ( ) ,
573+ } ) )
574+ . await
575+ . unwrap ( )
576+ . into_inner ( ) ;
577+ let dummy_appointments = response. appointments ;
578+ assert_eq ! ( & dummy_appointments. len( ) , & 2 ) ;
579+ let responses: Vec < Vec < u8 > > = dummy_appointments
580+ . into_iter ( )
581+ . filter_map ( |data| {
582+ if let Some ( common_msgs:: appointment_data:: AppointmentData :: Appointment (
583+ appointment,
584+ ) ) = data. appointment_data
585+ {
586+ return Some ( appointment. locator ) ;
587+ }
588+ return None ;
589+ } )
590+ . collect ( ) ;
591+ assert_eq ! ( responses[ 0 ] , locator. clone( ) . to_vec( ) ) ;
592+ assert_eq ! ( responses[ 1 ] , locator. clone( ) . to_vec( ) ) ;
593+
594+ // returns specific appointments if user_id is absent
595+ let response = internal_api
596+ . get_appointments ( Request :: new ( msgs:: GetAppointmentsRequest {
597+ locator : locator. clone ( ) . to_vec ( ) ,
598+ user_id : user_id1. clone ( ) . to_vec ( ) ,
599+ } ) )
600+ . await
601+ . unwrap ( )
602+ . into_inner ( ) ;
603+ let dummy_appointments = response. appointments ;
604+ assert_eq ! ( & dummy_appointments. len( ) , & 1 ) ;
605+ let dummy_appointmnets_data = & dummy_appointments[ 0 ] . appointment_data ;
606+ assert ! (
607+ matches!( dummy_appointmnets_data. clone( ) , Some ( common_msgs:: appointment_data:: AppointmentData :: Appointment (
608+ common_msgs:: Appointment {
609+ locator: ref app_loc,
610+ ..
611+ }
612+ ) ) if app_loc. clone( ) == locator. to_vec( ) )
613+ )
614+ }
615+
522616 #[ tokio:: test]
523617 async fn test_get_appointments_watcher ( ) {
524618 let ( internal_api, _s) = create_api ( ) . await ;
@@ -548,6 +642,7 @@ mod tests_private_api {
548642 let response = internal_api
549643 . get_appointments ( Request :: new ( msgs:: GetAppointmentsRequest {
550644 locator : locator. to_vec ( ) ,
645+ user_id : Vec :: new ( ) ,
551646 } ) )
552647 . await
553648 . unwrap ( )
@@ -599,6 +694,7 @@ mod tests_private_api {
599694 let response = internal_api
600695 . get_appointments ( Request :: new ( msgs:: GetAppointmentsRequest {
601696 locator : locator. to_vec ( ) ,
697+ user_id : Vec :: new ( ) ,
602698 } ) )
603699 . await
604700 . unwrap ( )
@@ -747,7 +843,10 @@ mod tests_private_api {
747843
748844 assert_eq ! ( response. available_slots, SLOTS - 1 ) ;
749845 assert_eq ! ( response. subscription_expiry, START_HEIGHT as u32 + DURATION ) ;
750- assert_eq ! ( response. appointments, Vec :: from( [ appointment. inner. locator. to_vec( ) ] ) ) ;
846+ assert_eq ! (
847+ response. appointments,
848+ Vec :: from( [ appointment. inner. locator. to_vec( ) ] )
849+ ) ;
751850 }
752851
753852 #[ tokio:: test]
0 commit comments