function img=inv_solve_diff_pdipm( inv_model, data1, data2)
% INV_SOLVE_DIFF_PDIPM  Difference-imaging inverse solver using the
% primal/dual interior point method (PDIPM), with an L1 or L2 norm on the
% data misfit and on the image prior.
%
% img = inv_solve_diff_pdipm( inv_model, data1, data2)
%   img       => output image struct with fields name, elem_data, fwd_model
%   inv_model => inverse model struct
%   data1     => first measurement set  } combined into the difference
%   data2     => second measurement set } vector by calc_difference_data
%
% Options read from inv_model (defaults in brackets):
%   inv_model.parameters.max_iterations        [10]
%   inv_model.parameters.min_change            [0]  (not used by the iterations)
%   inv_model.inv_solve_diff_pdipm.beta        [1e-6]
%       smoothing constant in sqrt(v.^2 + beta); reduced every iteration
%   inv_model.inv_solve_diff_pdipm.norm_data   [2]  1 or 2
%   inv_model.inv_solve_diff_pdipm.norm_image  [2]  1 or 2
%
% The L2/L2 case is a single regularized linear solve; the other three
% norm combinations run max_iterations PDIPM Newton steps.

pp= process_parameters(inv_model);

fwd_model= inv_model.fwd_model;

d=calc_difference_data( data1, data2, fwd_model);

img_bkgnd=calc_jacobian_bkgnd( inv_model );
J=calc_jacobian(img_bkgnd);

alpha=calc_hyperparameter( inv_model );
L=calc_R_prior( inv_model );
W= calc_meas_icov( inv_model );
if pp.norm_data==1
   % The L1 data norm weights the residual itself, not its square.
   % NOTE(review): assumes W is a diagonal inverse covariance, so sqrt
   % gives per-measurement 1/sigma weights -- confirm.
   W = sqrt(W);
end

% Every solver receives the prior already scaled by the hyperparameter
aL = alpha*L;

if     pp.norm_data==2 && pp.norm_image==2
   x= pdipm_2_2( J,W,aL,d, pp);
elseif pp.norm_data==2 && pp.norm_image==1
   x= pdipm_2_1( J,W,aL,d, pp);
elseif pp.norm_data==1 && pp.norm_image==2
   x= pdipm_1_2( J,W,aL,d, pp);
elseif pp.norm_data==1 && pp.norm_image==1
   x= pdipm_1_1( J,W,aL,d, pp);
else
   % Without this branch x would be undefined below
   error('inv_solve_diff_pdipm: norm_data and norm_image must each be 1 or 2');
end

img.name = 'inv_solve_diff_pdipm';
img.elem_data = x;
img.fwd_model = fwd_model;
0054
function s= pdipm_2_2( J,W,L,d, pp);
% PDIPM_2_2  L2 data-misfit / L2 image-prior case.
% Starting from the zero vector returned by initial_values, solves once
%   (J'*W*J + L'*L) * ds = J'*W*(d - J*s) - L'*L*s
% and returns s + ds.  pp is only forwarded to initial_values.
s = initial_values(J, L, pp);

RtR = L'*L;          % prior normal matrix
JtW = J'*W;          % weighted transposed Jacobian

lhs  = JtW*J + RtR;
rhs  = JtW*(d - J*s) - RtR*s;
step = lhs \ rhs;

s = s + step;
0061
function m= pdipm_1_2( J,W,L,d, pp);
% PDIPM_1_2  L1 data-misfit / L2 image-prior solve by primal/dual
% interior-point iterations.
%   J  : Jacobian (M x N)            W  : data weighting matrix (M x M)
%   L  : prior matrix, already multiplied by the hyperparameter (D x N)
%   d  : difference data (M x 1)     pp : options (max_iter, beta, ...)
% Returns m, the N x 1 image update.
%
% With f = J*m - d and e = sqrt(f.^2 + beta), each iteration takes a
% Newton step on
%   Fc = f - E*x          = 0      (M equations)
%   Ff = J'*W*x + L'*L*m  = 0      (N equations)
% for the primal m and the data-term dual x (M x 1). x_update limits the
% dual step, and beta is reduced after every iteration by manage_beta.
[m,x,jnk,sz]= initial_values( J, L, pp);

I_M = speye(sz.M, sz.M);
% Blocks of the Newton matrix that do not depend on the iterate
dFf_dm = L'*L;
dFf_dx = J'*W;
for loop = 1:pp.max_iter
   f = J*m - d;
   e = sqrt(f.^2 + pp.beta);  E = spdiag(e);

   % X*inv(E)*F is diagonal with entries x.*f./e; build it directly
   % instead of inverting E (e > 0 whenever beta > 0)
   dFc_dm = (I_M - spdiag(x.*f./e))*J;
   dFc_dx = -E;

   dmdx = -[dFc_dm, dFc_dx; dFf_dm, dFf_dx] \ ...
           [ f-E*x; dFf_dx*x + dFf_dm*m ];

   dm = dmdx( 1:sz.N);
   dx = x_update(x, dmdx(sz.N+(1:sz.M)));

   m= m + dm; x= x + dx;
   loop_display(loop)   % pass the iteration counter (was the imaginary unit i)
   debug([mean(abs([m,dm])) mean(abs([x,dx]))])
   pp = manage_beta(pp);
end
0089
function m= pdipm_2_1( J,W,L,d, pp);
% PDIPM_2_1  L2 data-misfit / L1 image-prior solve by primal/dual
% interior-point iterations.
%   J  : Jacobian (M x N)            W  : data weighting matrix (M x M)
%   L  : prior matrix, already multiplied by the hyperparameter (D x N)
%   d  : difference data (M x 1)     pp : options (max_iter, beta, ...)
% Returns m, the N x 1 image update.
%
% With g = L*m and s = sqrt(g.^2 + beta), each iteration takes a Newton
% step on
%   Ff = 2*J'*W*(J*m-d) + L'*y = 0     (N equations)
%   Fc = g - S*y               = 0     (D equations)
% for the primal m and the prior-term dual y (D x 1). x_update limits the
% dual step, and beta is reduced after every iteration by manage_beta.
[m,jnk,y,sz]= initial_values( J, L, pp);

I_D = speye(sz.D, sz.D);
% Blocks of the Newton matrix that do not depend on the iterate.
% dFf_dm carries the factor 2 of the squared data misfit; the residual
% below uses the same factor so the step is a true Newton step for Ff.
JtW    = J'*W;
dFf_dm = 2*JtW*J;
dFf_dy = L';
for loop = 1:pp.max_iter
   g = L*m;
   s = sqrt(g.^2 + pp.beta);  S = spdiag(s);

   % Y*inv(S)*G is diagonal with entries y.*g./s; no inverse needed
   dFc_dm = (I_D - spdiag(y.*g./s))*L;
   dFc_dy = -S;

   dmdy = -[dFf_dm, dFf_dy; dFc_dm, dFc_dy] \ ...
           [ 2*JtW*(J*m-d) + dFf_dy*y; g-S*y ];

   dm = dmdy( 1:sz.N );
   dy = x_update(y, dmdy(sz.N+(1:sz.D)));

   m= m + dm; y= y + dy;
   loop_display(loop)   % pass the iteration counter (was the imaginary unit i)
   debug([mean(abs([m,dm])), mean(abs([y,dy]))]);
   pp = manage_beta(pp);
end
0117
function m= pdipm_1_1( J,W,L,d, pp);
% PDIPM_1_1  L1 data-misfit / L1 image-prior solve by primal/dual
% interior-point iterations.
%   J  : Jacobian (M x N)            W  : data weighting matrix (M x M)
%   L  : prior matrix, already multiplied by the hyperparameter (D x N)
%   d  : difference data (M x 1)     pp : options (max_iter, beta, ...)
% Returns m, the N x 1 image update.
%
% Unknowns: primal m (N x 1), data-term dual x (M x 1), prior-term dual
% y (D x 1). With f = J*m-d, e = sqrt(f.^2+beta), g = L*m and
% r = sqrt(g.^2+beta), each iteration takes a Newton step on
%   B1 = J'*W*x + L'*y = 0
%   B2 = f - E*x       = 0
%   B3 = g - R*y       = 0
% x_update limits both dual steps, and beta is reduced by manage_beta.
[m,x,y,sz]= initial_values( J, L, pp);

I_M = speye(sz.M,sz.M);
I_D = speye(sz.D,sz.D);
% Blocks of the Newton matrix that do not depend on the iterate
% (block columns are ordered dm, dx, dy)
JtW  = J'*W;
Lt   = L';
Z_N  = sparse(sz.N,sz.N);
Z_DM = sparse(sz.D,sz.M);
for loop = 1:pp.max_iter
   g = L*m;
   r = sqrt(g.^2 + pp.beta);  R = spdiag(r);

   f = J*m - d;
   e = sqrt(f.^2 + pp.beta);  E = spdiag(e);

   % X*inv(E)*F and Y*inv(R)*G are diagonal; form them elementwise
   As2 = (I_M - spdiag(x.*f./e)) * J;
   As3 = (I_D - spdiag(y.*g./r)) * L;

   A = [ Z_N,  JtW,   Lt   ; ...
         As2,  -E,    Z_DM'; ...
         As3,  Z_DM,  -R   ];
   B = [ JtW*x + Lt*y; f - E*x; g - R*y ];

   DD = -(A \ B);

   dm = DD(1:sz.N);
   dx = x_update(x, DD(sz.N + (1:sz.M)) );
   dy = x_update(y, DD(sz.N + sz.M + (1:sz.D)) );

   m= m + dm;
   x= x + dx;
   y= y + dy;
   loop_display(loop)   % pass the iteration counter (was the imaginary unit i)
   debug([mean(abs([m,dm])), mean(abs([x,dx])), mean(abs([y,dy]))]);
   pp = manage_beta(pp);
end
0164
0165
function sM = spdiag(V)
% SPDIAG  Sparse square matrix carrying the vector V on its main diagonal.
%   sM = spdiag(V) is length(V)-by-length(V).
sM = spdiags(V, 0, length(V), length(V));
0169
function [s,x,y,sz]= initial_values( J, L, pp);
% INITIAL_VALUES  Zero starting points and problem sizes for the solvers.
%   [s,x,y,sz] = initial_values(J, L, pp)
%   sz.M, sz.N : rows / columns of J;  sz.D : rows of L
%   s : zeros(sz.N,1)   x : zeros(sz.M,1)   y : zeros(sz.D,1)
%   pp is accepted but not used.
[nRows, nCols] = size(J);
nPrior = size(L, 1);
sz = struct('M', nRows, 'N', nCols, 'D', nPrior);

s = zeros(nCols,  1);
x = zeros(nRows,  1);
y = zeros(nPrior, 1);
0176
0177
function dx = x_update( x, dx, frac)
% X_UPDATE  Limit a dual-variable step so the dual stays inside [-1, 1].
%   dx = x_update(x, dx)          uses frac = 0.9
%   dx = x_update(x, dx, frac)
% x    : current dual variable (vector); assumed |x| < 1 elementwise
% dx   : proposed Newton step for x (same size as x)
% frac : fraction (0 < frac < 1) of the distance to the nearest bound
%        that the step is allowed to cover
% Returns dx scaled by one common factor t = min(1, frac*tmax), where tmax
% is the largest factor that keeps x + tmax*dx inside [-1, 1].
% Capping t at 1 keeps the full Newton step whenever it is feasible (it
% is never lengthened). frac < 1 keeps x strictly inside the bounds, so
% later steps are not blocked by a component sitting exactly on +/-1.
if nargin < 3
   frac = 0.9;
end
if isempty(dx)
   return
end

% A zero component would divide by zero below; nudge it
dx(dx==0) = eps;

% Distance from x to the bound (+1 or -1) in the direction of each step
clr = sign(dx) - x;

% Factor at which each component reaches its bound
fac = clr./dx;

t  = min(1, frac*min(abs(fac)));
dx = dx*t;
0188
0189
function debug(vals)
% DEBUG  Per-iteration diagnostics hook; currently a no-op.
%   VALS (the mean magnitudes the solvers pass in) is accepted and ignored.
return
0192
function pp = manage_beta(pp);
% MANAGE_BETA  Shrink the smoothing parameter for the next iteration.
%   pp.beta <- pp.beta * pp.beta_reduce, but never below pp.beta_minimum.
shrunk = pp.beta * pp.beta_reduce;
if shrunk < pp.beta_minimum
   shrunk = pp.beta_minimum;
end
pp.beta = shrunk;
0198
function pp= process_parameters(imdl);
% PROCESS_PARAMETERS  Collect solver options from the inverse model.
%   Missing fields fall back to defaults:
%     imdl.parameters.max_iterations           -> pp.max_iter     [10]
%     imdl.parameters.min_change               -> pp.min_change   [0]
%     imdl.inv_solve_diff_pdipm.beta           -> pp.beta         [1e-6]
%     imdl.inv_solve_diff_pdipm.beta_reduce    -> pp.beta_reduce  [0.2]
%     imdl.inv_solve_diff_pdipm.beta_minimum   -> pp.beta_minimum [1e-16]
%     imdl.inv_solve_diff_pdipm.norm_data      -> pp.norm_data    [2]
%     imdl.inv_solve_diff_pdipm.norm_image     -> pp.norm_image   [2]
%   Errors if norm_data or norm_image is not 1 or 2.
% Explicit field lookups replace try/catch, so unrelated errors are not
% silently turned into defaults.
pp.max_iter     = pdipm_get_option(imdl, {'parameters','max_iterations'}, 10);
pp.min_change   = pdipm_get_option(imdl, {'parameters','min_change'}, 0);
pp.beta         = pdipm_get_option(imdl, {'inv_solve_diff_pdipm','beta'}, 1e-6);
pp.beta_reduce  = pdipm_get_option(imdl, {'inv_solve_diff_pdipm','beta_reduce'}, 0.2);
pp.beta_minimum = pdipm_get_option(imdl, {'inv_solve_diff_pdipm','beta_minimum'}, 1e-16);
pp.norm_data    = pdipm_get_option(imdl, {'inv_solve_diff_pdipm','norm_data'}, 2);
pp.norm_image   = pdipm_get_option(imdl, {'inv_solve_diff_pdipm','norm_image'}, 2);

% Only 1- and 2-norms have a solver; reject anything else up front
if ~(isscalar(pp.norm_data) && any(pp.norm_data == [1 2]))
   error('inv_solve_diff_pdipm: norm_data must be 1 or 2');
end
if ~(isscalar(pp.norm_image) && any(pp.norm_image == [1 2]))
   error('inv_solve_diff_pdipm: norm_image must be 1 or 2');
end

function val = pdipm_get_option(s, path, default)
% PDIPM_GET_OPTION  Value of s.(path{1}).(path{2})... if every level
% exists as a struct field, otherwise DEFAULT.
val = default;
for k = 1:numel(path)
   if ~isstruct(s) || ~isfield(s, path{k})
      return
   end
   s = s.(path{k});
end
val = s;
0222
function loop_display(iter)
% LOOP_DISPLAY  Progress marker: writes a single '+' to standard output.
%   The argument is ignored.
fprintf(1, '+');