Skip to content

Commit 5aa065e

Browse files
authored
Improve reliability of inverse estimation (#50)
* add test with challenging inverse estimation for TPS * fix: better default beta parameter. way more step size updates * perf: allocate fewer small arrays * rm unused variables
1 parent ce8f984 commit 5aa065e

File tree

2 files changed

+112
-16
lines changed

2 files changed

+112
-16
lines changed

src/main/java/net/imglib2/realtransform/inverse/InverseRealTransformGradientDescent.java

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141
import org.slf4j.Logger;
4242
import org.slf4j.LoggerFactory;
4343

44-
import java.util.Arrays;
45-
4644
public class InverseRealTransformGradientDescent implements RealTransform
4745
{
4846
int ndims;
@@ -72,7 +70,7 @@ public class InverseRealTransformGradientDescent implements RealTransform
7270

7371
double stepSz = 1.0;
7472

75-
double beta = 0.7;
73+
double beta = 0.5;
7674

7775
double tolerance = 0.5;
7876

@@ -84,18 +82,22 @@ public class InverseRealTransformGradientDescent implements RealTransform
8482

8583
double jacobianRegularizationEps = 0.1;
8684

87-
int stepSizeMaxTries = 10;
85+
int stepSizeMaxTries = 1000;
8886

8987
double maxStepSize = Double.MAX_VALUE;
9088

9189
double minStepSize = 1e-9;
9290

9391
private DifferentiableRealTransform xfm;
9492

95-
private double[] guess; // initialization for iterative inverse
96-
9793
protected static Logger logger = LoggerFactory.getLogger( InverseRealTransformGradientDescent.class );
9894

95+
private double[] srcd;
96+
private double[] tgtd;
97+
98+
private double[] x_ap;
99+
private double[] phix_ap;
100+
99101
public InverseRealTransformGradientDescent( int ndims, DifferentiableRealTransform xfm )
100102
{
101103
this.ndims = ndims;
@@ -109,6 +111,11 @@ public InverseRealTransformGradientDescent( int ndims, DifferentiableRealTransfo
109111
target = new double[ ndims ];
110112
estimate = new double[ ndims ];
111113
estimateXfm = new double[ ndims ];
114+
115+
srcd = new double[ ndims ];
116+
tgtd = new double[ ndims ];
117+
x_ap = new double[ ndims ];
118+
phix_ap = new double[ ndims ];
112119
}
113120

114121
public void setBeta( double beta )
@@ -223,9 +230,16 @@ public RealTransform copy()
223230
return copy;
224231
}
225232

233+
/**
234+
* Unused. Pass guess as parameter.
235+
*
236+
* @param guess
237+
* initial guess.
238+
*/
239+
@Deprecated
226240
public void setGuess( final double[] guess )
227241
{
228-
this.guess = guess;
242+
// no op
229243
}
230244

231245
public void apply( final double[] s, final double[] t )
@@ -240,8 +254,6 @@ public void apply( final double[] s, final double[] t )
240254
@Deprecated
241255
public void apply( final float[] src, final float[] tgt )
242256
{
243-
double[] srcd = new double[ src.length ];
244-
double[] tgtd = new double[ tgt.length ];
245257
for ( int i = 0; i < src.length; i++ )
246258
srcd[ i ] = src[ i ];
247259

@@ -253,8 +265,6 @@ public void apply( final float[] src, final float[] tgt )
253265

254266
public void apply( final RealLocalizable src, final RealPositionable tgt )
255267
{
256-
double[] srcd = new double[ src.numDimensions() ];
257-
double[] tgtd = new double[ tgt.numDimensions() ];
258268
src.localize( srcd );
259269
apply( srcd, tgtd );
260270
tgt.setPosition( tgtd );
@@ -284,7 +294,6 @@ public double inverseTol( final double[] target, final double[] guess, final dou
284294
int k = 0;
285295
while ( error >= tolerance && k < maxIters )
286296
{
287-
288297
/*
289298
* xfm.jacobian( estimate );
290299
*
@@ -305,9 +314,7 @@ public double inverseTol( final double[] target, final double[] guess, final dou
305314
updateEstimate( t ); // go in negative direction to reduce cost
306315
xfm.apply( estimate, estimateXfm );
307316
updateError();
308-
309317
error = getError();
310-
311318
k++;
312319
}
313320

@@ -412,15 +419,13 @@ public boolean armijoCondition( double c, double t )
412419
double[] d = dir;
413420
double[] x = estimate; // give a convenient name
414421

415-
double[] x_ap = new double[ ndims ];
416422
for ( int i = 0; i < ndims; i++ )
417423
x_ap[ i ] = x[ i ] + t * d[ i ];
418424

419425
// don't have to do this in here - this should be reused
420426
// double[] phix = xfm.apply( x );
421427
// TODO make sure estimateXfm is updated at the correct time
422428
double[] phix = estimateXfm;
423-
double[] phix_ap = new double[ this.ndims ];
424429
xfm.apply( x_ap, phix_ap );
425430

426431
double fx = squaredError( phix );
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package net.imglib2.realtransform.inverse;
2+
3+
import static org.junit.Assert.assertArrayEquals;
4+
5+
import java.util.Arrays;
6+
7+
import org.junit.Test;
8+
9+
import net.imglib2.realtransform.InvertibleRealTransform;
10+
import net.imglib2.realtransform.ThinplateSplineTransform;
11+
12+
public class ChallengingInverseTests
13+
{
14+
final int nd = 2;
15+
16+
double[][] srcPts = transpose( new double[][] {
17+
{64.93826968043624,136.5169914263442},
18+
{949.2718628215123,196.43904910366302},
19+
{481.4932190179267,94.95814497271988},
20+
{842.9585346843338,934.8334372564303},
21+
{1180.2617303195639,887.4756819953235},
22+
{1148.3677318784103,724.1397505845674},
23+
{403.87246158700066,1121.4631843388372},
24+
{282.79379326150405,659.0099372622876}});
25+
26+
double[][] tgtPts = transpose(new double[][] {
27+
{1.2160852376668474,2.982585528555182},
28+
{12.643983273336763,4.749085819443516},
29+
{6.74658999452494,2.792347035690284},
30+
{10.537771388046828,14.179479680032008},
31+
{15.130672144356495,13.75823730297402},
32+
{14.790960549954892,11.652025417684083},
33+
{4.577866506453839,16.017301399163536},
34+
{3.544833825697882,9.943901208947302}});
35+
36+
@Test
37+
public void testInv() {
38+
39+
final ThinplateSplineTransform tps = new ThinplateSplineTransform( tgtPts, srcPts );
40+
final WrappedIterativeInvertibleRealTransform tf = new WrappedIterativeInvertibleRealTransform( tps );
41+
InverseRealTransformGradientDescent opt = tf.getOptimzer();
42+
opt.setBeta(0.5);
43+
opt.setMaxIters(200000);
44+
opt.setTolerance(0.001);
45+
46+
int idx = 0;
47+
final double[] p0 = new double[] { srcPts[0][idx], srcPts[1][idx] };
48+
final double[] q0 = new double[] { tgtPts[0][idx], tgtPts[1][idx] };
49+
final double[] res = new double[nd];
50+
final double[] resXfm = new double[nd];
51+
52+
testInverseAndRoundTrip( tf, p0, q0, res, resXfm );
53+
54+
q0[0] = 0;
55+
q0[1] = 0;
56+
testInverseAndRoundTrip( tf, null, q0, res, resXfm );
57+
58+
q0[0] = 50;
59+
q0[1] = 120;
60+
testInverseAndRoundTrip( tf, null, q0, res, resXfm );
61+
}
62+
63+
private static void testInverseAndRoundTrip( InvertibleRealTransform tf, double[] p0, double[] q0, double[] res, double[] resXfm ) {
64+
65+
tf.applyInverse( res, q0 );
66+
tf.apply( res, resXfm );
67+
assertArrayEquals( q0, resXfm, 0.5 );
68+
}
69+
70+
private static double[][] transpose(double[][] matrix) {
71+
72+
if (matrix == null || matrix.length == 0) {
73+
return new double[0][0];
74+
}
75+
76+
// Create transposed matrix with swapped dimensions
77+
int rows = matrix.length;
78+
int cols = matrix[0].length;
79+
final double[][] transposed = new double[cols][rows];
80+
81+
// Fill the transposed matrix
82+
for (int i = 0; i < rows; i++) {
83+
for (int j = 0; j < cols; j++) {
84+
transposed[j][i] = matrix[i][j];
85+
}
86+
}
87+
88+
return transposed;
89+
}
90+
91+
}

0 commit comments

Comments
 (0)