Skip to content

Commit 23ea534

Browse files
committed
Update LocalExtrema to use parallelization framework
1 parent 3dd9ecd commit 23ea534

File tree

1 file changed

+33
-52
lines changed

1 file changed

+33
-52
lines changed

src/main/java/net/imglib2/algorithm/localextrema/LocalExtrema.java

Lines changed: 33 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -35,27 +35,30 @@
3535

3636
import java.util.ArrayList;
3737
import java.util.Arrays;
38+
import java.util.Collection;
3839
import java.util.List;
3940
import java.util.concurrent.Callable;
4041
import java.util.concurrent.ExecutionException;
4142
import java.util.concurrent.ExecutorService;
4243
import java.util.concurrent.Future;
4344
import java.util.stream.IntStream;
4445
import java.util.stream.LongStream;
45-
46-
import net.imglib2.Cursor;
47-
import net.imglib2.FinalInterval;
4846
import net.imglib2.Interval;
4947
import net.imglib2.Localizable;
5048
import net.imglib2.Point;
49+
import net.imglib2.RandomAccess;
5150
import net.imglib2.RandomAccessible;
5251
import net.imglib2.RandomAccessibleInterval;
5352
import net.imglib2.Sampler;
5453
import net.imglib2.algorithm.neighborhood.Neighborhood;
5554
import net.imglib2.algorithm.neighborhood.RectangleShape;
5655
import net.imglib2.algorithm.neighborhood.Shape;
56+
import net.imglib2.converter.readwrite.WriteConvertedRandomAccessible;
57+
import net.imglib2.loops.LoopBuilder;
58+
import net.imglib2.parallel.Parallelization;
59+
import net.imglib2.parallel.TaskExecutor;
60+
import net.imglib2.parallel.TaskExecutors;
5761
import net.imglib2.util.ConstantUtils;
58-
import net.imglib2.util.Intervals;
5962
import net.imglib2.util.ValuePair;
6063
import net.imglib2.view.IntervalView;
6164
import net.imglib2.view.Views;
@@ -306,7 +309,7 @@ public static < P, T > List< P > findLocalExtrema(
306309
* @param numTasks
307310
* Number of tasks for parallel execution
308311
* @param splitDim
309-
* Dimension along which input should be split for parallization
312+
* ignored
310313
* @return {@link List} of extrema
311314
* @throws ExecutionException
312315
* @throws InterruptedException
@@ -320,38 +323,8 @@ public static < P, T > List< P > findLocalExtrema(
320323
final int numTasks,
321324
final int splitDim ) throws InterruptedException, ExecutionException
322325
{
323-
324-
final long[] min = Intervals.minAsLongArray( interval );
325-
final long[] max = Intervals.maxAsLongArray( interval );
326-
327-
final long splitDimSize = interval.dimension( splitDim );
328-
final long splitDimMax = max[ splitDim ];
329-
final long splitDimMin = min[ splitDim ];
330-
final long taskSize = Math.max( splitDimSize / numTasks, 1 );
331-
332-
final ArrayList< Callable< List< P > > > tasks = new ArrayList<>();
333-
334-
for ( long start = splitDimMin, stop = splitDimMin + taskSize - 1; start <= splitDimMax; start += taskSize, stop += taskSize )
335-
{
336-
final long s = start;
337-
// need max here instead of dimension for constructor of
338-
// FinalInterval
339-
final long S = Math.min( stop, splitDimMax );
340-
tasks.add( () -> {
341-
final long[] localMin = min.clone();
342-
final long[] localMax = max.clone();
343-
localMin[ splitDim ] = s;
344-
localMax[ splitDim ] = S;
345-
return findLocalExtrema( source, new FinalInterval( localMin, localMax ), localNeighborhoodCheck, shape );
346-
} );
347-
}
348-
349-
final ArrayList< P > extrema = new ArrayList<>();
350-
final List< Future< List< P > > > futures = service.invokeAll( tasks );
351-
for ( final Future< List< P > > f : futures )
352-
extrema.addAll( f.get() );
353-
return extrema;
354-
326+
TaskExecutor taskExecutor = TaskExecutors.forExecutorServiceAndNumTasks( service, numTasks );
327+
return Parallelization.runWithExecutor( taskExecutor, () -> findLocalExtrema( source, interval, localNeighborhoodCheck, shape ) );
355328
}
356329

357330
/**
@@ -470,22 +443,30 @@ public static < P, T > List< P > findLocalExtrema(
470443
final LocalNeighborhoodCheck< P, T > localNeighborhoodCheck,
471444
final Shape shape )
472445
{
446+
WriteConvertedRandomAccessible< T, RandomAccess< T > > randomAccessible = new WriteConvertedRandomAccessible<>( source, sampler -> (RandomAccess< T >) sampler );
447+
RandomAccessibleInterval< RandomAccess< T > > centers = Views.interval( randomAccessible, interval);
448+
RandomAccessibleInterval< Neighborhood< T > > neighborhoods = Views.interval( shape.neighborhoodsRandomAccessible( source ), interval );
449+
List< List< P > > extremas = LoopBuilder.setImages( centers, neighborhoods ).multiThreaded().forEachChunk( chunk -> {
450+
List< P > extrema = new ArrayList<>();
451+
chunk.forEachPixel( ( center, neighborhood ) -> {
452+
P p = localNeighborhoodCheck.check( center, neighborhood );
453+
if ( p != null )
454+
extrema.add( p );
455+
} );
456+
return extrema;
457+
} );
458+
return concatenate( extremas );
459+
}
473460

474-
final IntervalView< T > sourceInterval = Views.interval( source, interval );
475-
476-
final ArrayList< P > extrema = new ArrayList<>();
477-
478-
final Cursor< T > center = Views.flatIterable( sourceInterval ).cursor();
479-
for ( final Neighborhood< T > neighborhood : shape.neighborhoods( sourceInterval ) )
480-
{
481-
center.fwd();
482-
final P p = localNeighborhoodCheck.check( center, neighborhood );
483-
if ( p != null )
484-
extrema.add( p );
485-
}
486-
487-
return extrema;
488-
461+
private static < P > List<P> concatenate( List<List<P>> lists )
462+
{
463+
if(lists.size() == 1)
464+
return lists.get( 0 );
465+
int size = lists.stream().mapToInt( List::size ).sum();
466+
List< P > result = new ArrayList<>( size );
467+
for ( List< P > list : lists )
468+
result.addAll( list );
469+
return result;
489470
}
490471

491472
/**

0 commit comments

Comments
 (0)