Wednesday, November 16, 2011

GAMP: Generalized Approximate Message Passing in GraphLab

About a week ago I was invited by Phil Schniter to visit Ohio State University. I heard about a lot of interesting research activities, which I will probably report on in this blog in the future. Specifically, Phil is working on inference in linear models closely related to the ones in my own work.

I heard from Phil about the GAMP method, an extension of Montanari's AMP (approximate message passing). GAMP is supposed to overcome some of the difficulties of AMP, namely the Gaussian matrix assumption. Here is an excerpt from the GAMP wiki:
Generalized Approximate Message Passing (GAMP) is an approximate, but computationally efficient method for estimation problems with linear mixing. In the linear mixing problem an unknown vector, x, with independent components, is first passed through linear transform z=Ax and then observed through a general probabilistic, componentwise measurement channel to yield a measurement vector y. The problem is to estimate x and z from y and A. This problem arises in a range of applications including compressed sensing. Optimal solutions to linear mixing estimation problems are, in general, computationally intractable as the complexity of most brute force algorithms grows exponentially in the dimension of the vector x. GAMP approximately performs the estimation through a Gaussian approximation of loopy belief propagation that reduces the vector-valued estimation problem to a sequence of scalar estimation problems on the components of the vectors x and z.
This project is intended to develop the theory and applications of GAMP. We also provide open source MATLAB software for others who would like to experiment with the algorithm.
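
To make the "componentwise measurement channel" concrete: the observations do not have to be linear measurements plus Gaussian noise. Here is a tiny toy example of the generative model (my own illustration with arbitrary sizes, not code from the GAMP project), where the channel is a 1-bit quantizer, one example of a non-Gaussian channel that GAMP is designed to handle:

#include <cmath>
#include <random>
#include <vector>

// toy generative model: z = A*x is observed through a componentwise
// channel p(y_i|z_i); here a 1-bit quantizer y_i = sign(z_i + w_i)
int main() {
  const int M = 200, N = 500;
  std::mt19937 rng(0);
  std::normal_distribution<double> gauss(0.0, 1.0);
  std::bernoulli_distribution active(0.05);

  std::vector<double> x(N);                 // sparse Bernoulli-Gaussian signal
  for (int j = 0; j < N; ++j)
    x[j] = active(rng) ? gauss(rng) : 0.0;

  std::vector<double> A((size_t)M * N);     // dense linear transform (row-major)
  for (double &a : A)
    a = gauss(rng) / std::sqrt((double)M);

  std::vector<double> z(M, 0.0), y(M);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j)             // linear mixing z = A*x
      z[i] += A[(size_t)i * N + j] * x[j];
    y[i] = (z[i] + 0.1 * gauss(rng) >= 0.0) ? 1.0 : -1.0;  // measurement channel
  }
  return 0;  // y and A would be handed to GAMP to estimate x and z
}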
Below is a simplified version of the GAMP MATLAB code I got from Phil. It is a highly simplified instance of the GAMPmatlab code, which includes a library of signal priors and matrix constructions, as well as stepsize control (to improve stability when the matrix has highly correlated columns). Automatic tuning of the GAMPmatlab parameters is provided by the EM-BG-GAMP code.

% this m-file uses GAMP to estimate the matrix X (of dimension NxT) from the noisy matrix observation Y=A*X+W of dimension MxT.  here, A has independent Gaussian entries with matrix mean and variance (A_mean,A_var), W has iid Gaussian entries with scalar mean and variance (0,psi), and X has iid Bernoulli-Gaussian entries with scalar activity rate "lam" and with non-zero elements having scalar mean and variance (theta,phi).

 % declare signal parameters
 N = 1000;      % # weights     [dflt=1000]
 M = 400;   % # observations    [dflt=400]
 T = 10;   % # dimensions per observation   [dflt=10]
 lam = 0.01*ones(N,T);  % BG: prior probability of a non-zero weight [dflt=0.01]
 theta = 0*ones(N,T);  % BG: prior mean of non-zero weights  [dflt=0]
 phi = 1*ones(N,T);   % BG: prior variance of non-zero weights [dflt=1]
 SNRdB = 15;   % SNR in dB      [dflt=15]

 % declare algorithmic parameters
 iterMax = 50;   % maximum number of AMP iterations  [dflt=20]

 % generate true Bernoulli-Gaussian signal (i.e., "weights")
 B_true = (rand(N,T)>1-lam);   % sparsity pattern
 X_true = B_true.*(theta+sqrt(phi).*randn(N,T));% sparse weights

 % generate measurement (i.e., "feature") matrix
 A_mean = randn(M,N)/sqrt(M);       % element-wise mean 
 A_mean2 = A_mean.^2;
 A_var = abs(A_mean).^2;   % element-wise variance
 A_true = A_mean + sqrt(A_var).*randn(M,N); % realization

 % generate noisy observations
 psi = norm(A_true*X_true,'fro')^2/(M*T)*10^(-SNRdB/10);   % additive noise variance 
 Y = A_true*X_true + sqrt(psi).*randn(M,T);  % AWGN-corrupted observation

 % initialize GAMP
 X_mean = lam.*theta;  % initialize posterior means
 X_var = lam.*phi;  % initialize posterior variances
 S_mean = zeros(M,T);  % initialize scaled filtered gradient
 NMSEdB_ = NaN*ones(1,iterMax); % for debugging: initialize NMSE-versus-iteration

 % iterate GAMP
 for iter=1:iterMax,

   % computations at observation nodes
   Z_var = A_mean2*X_var;
   Z_mean = A_mean*X_mean;
   P_var = Z_var + A_var*(X_var + X_mean.^2);
   P_mean = Z_mean - Z_var.*S_mean;
   S_var = 1./(P_var+psi);
   S_mean = S_var.*(Y-P_mean);

   % computations at variable nodes
   R_var = 1./(A_mean2'*S_var + eps);
   R_mean = X_mean + R_var.*(A_mean'*S_mean);
   gam = (R_var.*theta + phi.*R_mean)./(R_var+phi);
   nu = phi.*R_var./(R_var+phi);
   pi = 1./(1+(1-lam)./lam .*sqrt((R_var+phi)./R_var) .* exp( ...
    -(R_mean.^2)./(2*R_var) + ((R_mean-theta).^2)./(2*(phi+R_var)) ));
   X_mean = pi.*gam;
   X_var = pi.*(nu+gam.^2) - (pi.*gam).^2;

   % for debugging: record normalized mean-squared error 
   NMSEdB_(iter) = 20*log10(norm(X_true-X_mean,'fro')/norm(X_true,'fro'));
 end;% iter

 % for debugging: plot posterior means, variances, and NMSE-versus-iteration
 subplot(311)
   errorbar(X_mean(:,1),sqrt(X_var(:,1)),'.');
   hold on; plot(X_true(:,1),'k.'); hold off;
   axis([0,size(X_true,1),-1.2*max(abs(X_true(:,1))),1.2*max(abs(X_true(:,1)))])
   ylabel('mean')
   grid on;
   title('GAMP estimates')
 subplot(312)
   plot(10*log10(X_var(:,1)),'.');
   axe=axis; axis([0,N,min(axe(3),-1),0]);
   ylabel('variance [dB]');
   grid on;
 subplot(313)
   plot(NMSEdB_,'.-')
   ylabel('NMSE [dB]')
   grid on;
   title(['NMSE=',num2str(NMSEdB_(iter),3),'dB '])
 drawnow;

Today I have implemented GAMP in GraphLab. The parallel implementation is pretty straightforward. The main overhead is the two matrix products: the product with A (the Z_var/Z_mean lines in the MATLAB code above) and the product with A' (the R_var/R_mean lines). Each of the variable and observation nodes is a GraphLab node. There are two update functions, one for variable nodes and one for observation nodes; a snapshot of the first is shown below, followed by a sketch of the second. The code can be found in the GraphLab linear solvers library.

Here is a snapshot from the implementation of the variable node update function:
void gamp__update_function(gl_types_gamp::iscope &scope,
         gl_types_gamp::icallback &scheduler) {

  /* GET current vertex data */
  vertex_data_gamp& vdata = scope.vertex_data();
  unsigned int id = scope.vertex();
  gamp_struct var(vdata, id);

  
  /*
   * computations at variable nodes
   */
   memset(var.R_var, 0, config.D*sizeof(sdouble));
   memset(var.R_mean, 0, config.D*sizeof(sdouble));
   for (unsigned int k=ps.m; k< ps.m+ps.n; k++){
      vertex_data_gamp  & other= g_gamp->vertex_data(k);
      gamp_struct obs(other,k);

      for (int i=0; i< config.D; i++){
        var.R_var[i] += pow(var.AT_i[k-ps.m],2) * obs.S_var[i];
        var.R_mean[i] += var.AT_i[k-ps.m] * obs.S_mean[i];
      }
  }

  // componentwise Bernoulli-Gaussian posterior update, mirroring the
  // R_var/R_mean/gam/nu/pi/X_mean/X_var lines of the MATLAB code above
  for (int i=0; i< config.D; i++){
     var.R_var[i] = 1.0 / (var.R_var[i] + ps.gamp_eps);
     var.R_mean[i] = var.X_mean[i] + var.R_var[i] * var.R_mean[i];
     var.gam[i] = (var.R_var[i]*ps.gamp_theta + ps.gamp_phi*var.R_mean[i]) / (var.R_var[i] + ps.gamp_phi);
     var.nu[i] = ps.gamp_phi*var.R_var[i] / (var.R_var[i] + ps.gamp_phi);
     var.pi[i] = 1.0 / (1.0 + (1.0 - ps.gamp_lam) / ps.gamp_lam * sqrt((var.R_var[i]+ps.gamp_phi)/var.R_var[i]) * exp( \
      -(powf(var.R_mean[i],2)) / (2*var.R_var[i]) + (powf(var.R_mean[i] - ps.gamp_theta,2)) / (2*(ps.gamp_phi+var.R_var[i])) ));
     var.X_mean[i] = var.pi[i] * var.gam[i];
     var.X_var[i] = var.pi[i] * (var.nu[i]+ pow(var.gam[i],2)) - pow(var.pi[i] * var.gam[i],2);
  }
}
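
For completeness, here is a sketch of the matching observation node computations, mirroring the Z_var/Z_mean/P_var/P_mean/S_var/S_mean lines of the MATLAB code. This is a simplified illustration rather than the actual library code: the field names A_i and Y and the parameter ps.gamp_psi are placeholders, and A is treated as perfectly known (so the A_var term of P_var drops out):

void gamp_obs_update_function(gl_types_gamp::iscope &scope,
          gl_types_gamp::icallback &scheduler) {

  /* GET current vertex data */
  vertex_data_gamp& vdata = scope.vertex_data();
  unsigned int id = scope.vertex();
  gamp_struct obs(vdata, id);

  /*
   * computations at observation nodes
   */
  for (int i=0; i< config.D; i++){
    sdouble Z_var = 0, Z_mean = 0;
    // accumulate Z_var = (A.^2)*X_var and Z_mean = A*X_mean over the
    // variable nodes (assumed to occupy the vertex ids below ps.m,
    // mirroring the indexing of the variable node update above)
    for (unsigned int k=0; k< ps.m; k++){
      vertex_data_gamp& other = g_gamp->vertex_data(k);
      gamp_struct var(other, k);
      Z_var += pow(obs.A_i[k], 2) * var.X_var[i];
      Z_mean += obs.A_i[k] * var.X_mean[i];
    }
    sdouble P_var = Z_var;                 // A_var term omitted: A known exactly
    sdouble P_mean = Z_mean - Z_var * obs.S_mean[i];
    obs.S_var[i] = 1.0 / (P_var + ps.gamp_psi);
    obs.S_mean[i] = obs.S_var[i] * (obs.Y[i] - P_mean);
  }
}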
The code is slightly longer than the MATLAB version, and looks somewhat uglier, since the vectorized operations have to be explicitly expanded into loops. I am also not yet happy with the performance, since MATLAB is highly optimized for dense matrix products. I still need to think about how to improve this, maybe by blocking and using LAPACK/BLAS...
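
One concrete option (a sketch of the idea, not code that is in the library yet): collect the per-node vectors into contiguous row-major arrays and hand the dense products to an optimized BLAS. For example, Z_mean = A_mean*X_mean becomes a single DGEMM call:

#include <cblas.h>

// Z = A*X as one blocked, cache-friendly BLAS call:
// A is MxN, X is NxT, Z is MxT, all row-major double arrays
void dense_forward_product(const double *A, const double *X, double *Z,
                           int M, int N, int T) {
  cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
              M, T, N,      // Z is MxT, inner dimension N
              1.0, A, N,    // lda = N
              X, T,         // ldb = T
              0.0, Z, T);   // ldc = T
}

The product with A' is the same call with CblasTrans as the second argument.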

2 comments:

  1. Danny,

    Outstanding. Phil was probably too modest to tell you that his latest solver ( http://www2.ece.ohio-state.edu/~vilaj/EMBGAMP/EMBGAMP.html ) blows away all the competition: http://media.aau.dk/null_space_pursuits/2011/11/phase-transitions-in-higher-dimension.html

    Cheers,

    Igor.

  2. Igor - I am amazed!! I did not finish writing the blog entry and you are already aware of it.. :-)
    I will soon add some more implementation details..
