@@ -105,42 +105,72 @@ func (wg *WorkerGenerator) GenerateWorkerPod(
105105 }, nil
106106}
107107
108- func SelectWorker (ctx context.Context , k8sClient client.Client , workloadName string , workerStatuses []tfv1.WorkerStatus ) (* tfv1.WorkerStatus , error ) {
109- if len (workerStatuses ) == 0 {
108+ func SelectWorker (
109+ ctx context.Context ,
110+ k8sClient client.Client ,
111+ workload * tfv1.TensorFusionWorkload ,
112+ maxSkew int32 ,
113+ ) (* tfv1.WorkerStatus , error ) {
114+ if len (workload .Status .WorkerStatuses ) == 0 {
110115 return nil , fmt .Errorf ("no available worker" )
111116 }
112- usageMapping := make (map [string ]int , len (workerStatuses ))
113- for _ , workerStatus := range workerStatuses {
117+ usageMapping := make (map [string ]int , len (workload . Status . WorkerStatuses ))
118+ for _ , workerStatus := range workload . Status . WorkerStatuses {
114119 usageMapping [workerStatus .WorkerName ] = 0
115120 }
116121
117122 connectionList := tfv1.TensorFusionConnectionList {}
118- if err := k8sClient .List (ctx , & connectionList , client.MatchingLabels {constants .WorkloadKey : workloadName }); err != nil {
123+ if err := k8sClient .List (ctx , & connectionList , client.MatchingLabels {constants .WorkloadKey : workload . Name }); err != nil {
119124 return nil , fmt .Errorf ("list TensorFusionConnection: %w" , err )
120125 }
121126
122127 for _ , connection := range connectionList .Items {
123128 if connection .Status .WorkerName != "" {
124- continue
129+ usageMapping [ connection . Status . WorkerName ] ++
125130 }
126- usageMapping [connection .Status .WorkerName ]++
127131 }
128132
129- var minUsageWorker * tfv1.WorkerStatus
130- // Initialize with max int value
133+ // First find the minimum usage
131134 minUsage := int (^ uint (0 ) >> 1 )
132- for _ , workerStatus := range workerStatuses {
135+ // Initialize with max int value
136+ for _ , workerStatus := range workload .Status .WorkerStatuses {
133137 if workerStatus .WorkerPhase == tfv1 .WorkerFailed {
134138 continue
135139 }
136140 usage := usageMapping [workerStatus .WorkerName ]
137141 if usage < minUsage {
138142 minUsage = usage
139- minUsageWorker = & workerStatus
140143 }
141144 }
142- if minUsageWorker == nil {
145+
146+ // Collect all eligible workers that are within maxSkew of the minimum usage
147+ var eligibleWorkers []* tfv1.WorkerStatus
148+ for _ , workerStatus := range workload .Status .WorkerStatuses {
149+ if workerStatus .WorkerPhase == tfv1 .WorkerFailed {
150+ continue
151+ }
152+ usage := usageMapping [workerStatus .WorkerName ]
153+ // Worker is eligible if its usage is within maxSkew of the minimum usage
154+ if usage <= minUsage + int (maxSkew ) {
155+ eligibleWorkers = append (eligibleWorkers , & workerStatus )
156+ }
157+ }
158+
159+ if len (eligibleWorkers ) == 0 {
143160 return nil , fmt .Errorf ("no available worker" )
144161 }
145- return minUsageWorker , nil
162+
163+ // Choose the worker with the minimum usage among eligible workers
164+ selectedWorker := eligibleWorkers [0 ]
165+ selectedUsage := usageMapping [selectedWorker .WorkerName ]
166+ for i := 1 ; i < len (eligibleWorkers ); i ++ {
167+ worker := eligibleWorkers [i ]
168+ usage := usageMapping [worker .WorkerName ]
169+ if usage < selectedUsage {
170+ selectedWorker = worker
171+ selectedUsage = usage
172+ }
173+ }
174+
175+ return selectedWorker , nil
146176}
0 commit comments