@@ -33,6 +33,12 @@ type ElectionPayload struct {
3333 SessionID string
3434}
3535
36+ type commandExecutor struct {}
37+
38+ func (c * commandExecutor ) Command (name string , arg ... string ) * exec.Cmd {
39+ return exec .Command (name , arg ... )
40+ }
41+
3642// New returns a new Ballot instance.
3743func New (ctx context.Context , name string ) (b * Ballot , err error ) {
3844 if ctx == nil {
@@ -55,6 +61,7 @@ func New(ctx context.Context, name string) (b *Ballot, err error) {
5561 b .leader .Store (false )
5662 b .Token = consulConfig .Token
5763 b .ctx = ctx
64+ b .executor = & commandExecutor {}
5865
5966 b .Name = name
6067 if b .LockDelay == 0 {
@@ -84,7 +91,7 @@ type Ballot struct {
8491 leader atomic.Bool `mapstructure:"-"`
8592 client ConsulClient `mapstructure:"-"`
8693 ctx context.Context `mapstructure:"-"`
87- exec CommandExecutor `mapstructure:"-"`
94+ executor CommandExecutor `mapstructure:"-"`
8895}
8996
9097// Copy *api.AgentService to *api.AgentServiceRegistration
@@ -154,15 +161,15 @@ func (b *Ballot) runCommand(command string, electionPayload *ElectionPayload) ([
154161 if err != nil {
155162 return nil , err
156163 }
157- cmd := b .exec .Command (args [0 ], args [1 :]... )
164+ cmd := b .executor .Command (args [0 ], args [1 :]... )
158165 cmd .Env = append (cmd .Env , fmt .Sprintf ("ADDRESS=%s" , electionPayload .Address ))
159166 cmd .Env = append (cmd .Env , fmt .Sprintf ("PORT=%d" , electionPayload .Port ))
160167 cmd .Env = append (cmd .Env , fmt .Sprintf ("SESSIONID=%s" , electionPayload .SessionID ))
161168 return cmd .Output ()
162169}
163170
164171// updateServiceTags updates the service tags.
165- func (b * Ballot ) updateServiceTags () error {
172+ func (b * Ballot ) updateServiceTags (isLeader bool ) error {
166173 service , _ , err := b .getService ()
167174 if err != nil {
168175 return err
@@ -175,10 +182,10 @@ func (b *Ballot) updateServiceTags() error {
175182 hasPrimaryTag := slices .Contains (registration .Tags , b .PrimaryTag )
176183
177184 // Update tags based on leadership status
178- if b . IsLeader () && ! hasPrimaryTag {
185+ if isLeader && ! hasPrimaryTag {
179186 // Add primary tag if not present and this node is the leader
180187 registration .Tags = append (registration .Tags , b .PrimaryTag )
181- } else if ! b . IsLeader () && hasPrimaryTag {
188+ } else if ! isLeader && hasPrimaryTag {
182189 // Remove primary tag if present and this node is not the leader
183190 index := slices .Index (registration .Tags , b .PrimaryTag )
184191 registration .Tags = append (registration .Tags [:index ], registration .Tags [index + 1 :]... )
@@ -187,6 +194,35 @@ func (b *Ballot) updateServiceTags() error {
187194 return nil
188195 }
189196
197+ // Run the command associated with the new leadership status
198+ var command string
199+ if isLeader {
200+ command = b .ExecOnPromote
201+ } else {
202+ command = b .ExecOnDemote
203+ }
204+ if command != "" && b .executor != nil {
205+ go func (isLeader bool , command string ) {
206+ // Run the command in a separate goroutine
207+ ctx , cancel := context .WithTimeout (b .ctx , (b .TTL + b .LockDelay )* 2 )
208+ defer cancel ()
209+ payload , err := b .waitForNextValidSessionData (ctx )
210+ output , err := b .runCommand (command , payload )
211+ if err != nil {
212+ log .WithFields (log.Fields {
213+ "caller" : "updateLeadershipStatus" ,
214+ "isLeader" : isLeader ,
215+ "error" : err ,
216+ }).Error ("failed to run command" )
217+ }
218+ log .WithFields (log.Fields {
219+ "caller" : "updateLeadershipStatus" ,
220+ "isLeader" : isLeader ,
221+ "output" : string (output ),
222+ }).Info ("ran command" )
223+ }(isLeader , command )
224+ }
225+
190226 // Log the updated tags
191227 log .WithFields (log.Fields {
192228 "caller" : "updateServiceTags" ,
@@ -345,7 +381,7 @@ func (b *Ballot) updateLeadershipStatus(isLeader bool) error {
345381 b .leader .Store (isLeader )
346382
347383 // Update service tags based on leadership status
348- err := b .updateServiceTags ()
384+ err := b .updateServiceTags (isLeader )
349385 if err != nil {
350386 return err
351387 }
@@ -447,6 +483,25 @@ func (b *Ballot) IsLeader() bool {
447483 return b .leader .Load () && b .sessionID .Load () != nil
448484}
449485
486+ func (b * Ballot ) waitForNextValidSessionData (ctx context.Context ) (data * ElectionPayload , err error ) {
487+ ticker := time .NewTicker (1 * time .Second )
488+ defer ticker .Stop ()
489+ for {
490+ select {
491+ case <- ticker .C :
492+ data , err := b .getSessionData ()
493+ if err != nil {
494+ return data , err
495+ }
496+ if data != nil {
497+ return data , nil
498+ }
499+ case <- ctx .Done ():
500+ return data , ctx .Err ()
501+ }
502+ }
503+ }
504+
450505func (b * Ballot ) getSessionData () (data * ElectionPayload , err error ) {
451506 sessionKey , _ , err := b .client .KV ().Get (b .Key , nil )
452507 if err != nil {
0 commit comments