// Compile with
// g++ -O3 -ffast-math -march=pentium4 lr.cpp libacml.a  -o lr -lgfortran


// Non-owning view of a dense matrix of doubles stored column by column.
// The view neither allocates nor frees `data`; whoever created the buffer
// must keep it alive while the view is in use.  Sub-matrices are expressed
// by pointing `data` at their first element and reusing the parent's `lda`.
struct Matrix {
  double* data;  // address of element (0,0)
  int n;         // number of rows
  int m;         // number of columns
  int lda;       // leading dimension: element distance between two columns

  Matrix(double* data_, int n_, int m_, int lda_)
  : data(data_), n(n_), m(m_), lda(lda_) {}

  // Mutable access to the element in row i, column j (no bounds checking).
  double& operator()(int i, int j) {
    return data[j*lda+i];
  }

  // Read-only access to the element in row i, column j, so that a view can
  // also be used through a const reference (no bounds checking).
  double operator()(int i, int j) const {
    return data[j*lda+i];
  }
};

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>


// Unblocked matrix multiply-accumulate: C <- beta*C + alpha*(A*B).
// Shapes must be compatible (A: C.n x A.m, B: A.m x C.m); this is checked
// with asserts only.  Every entry of C is produced from one dot product
// that is accumulated in a local before being combined with the old value.
void dgemm(double alpha, Matrix& A, Matrix& B, double beta, Matrix& C) {
  assert(C.n==A.n);
  assert(C.m==B.m);
  assert(A.m==B.n);

  // Walk C column by column, which is the contiguous direction in memory.
  for (int col = 0; col < C.m; ++col) {
    for (int row = 0; row < C.n; ++row) {
      double dot = 0;
      for (int p = 0; p < A.m; ++p)
        dot += A(row,p) * B(p,col);

      double& out = C(row,col);
      out = beta*out + alpha*dot;
    }
  }
}

// Edge length of the square tiles used by the blocked kernels.  The
// recursion cut-offs of the blocked routines are expressed in terms of it.
const int N = 32;

// Fixed-size kernel: C <- C + A*B where all three operands are exactly
// N x N.  All loop bounds are the compile-time constant N (presumably so
// the compiler can unroll/vectorise the loops).
void dgemmFix(Matrix& A, Matrix& B, Matrix& C) {

  // The operands must agree with each other ...
  assert(C.n==A.n);
  assert(C.m==B.m);
  assert(A.m==B.n);
  // ... and with the fixed tile size.
  assert(C.n==N);
  assert(C.m==N);
  assert(A.m==N);
  assert(A.lda==N);

  for (int col = 0; col < N; ++col) {
    for (int row = 0; row < N; ++row) {
      double dot = 0;
      for (int p = 0; p < N; ++p)
        dot += A(row,p) * B(p,col);
      C(row,col) += dot;
    }
  }
}
// Tiled matrix multiply-accumulate: C <- beta*C + alpha*(A*B).
//
// If A and B fit into a single N x N tile the call is forwarded to dgemm().
// Otherwise C is processed in tiles of at most N x N.  For every tile of C
// the matching panels of A and B are copied into small contiguous buffers
// (leading dimension N), their product is accumulated in a third buffer,
// and finally the tile of C is updated with the same formula dgemm() uses.
//
// The output shape is taken from C (C.n x C.m) and the inner (summation)
// dimension from A.m; A is expected to be C.n x A.m and B to be A.m x C.m.
void blockDGEMM(double alpha, Matrix& A, Matrix& B, double beta, Matrix& C) {

  if (A.n<=N && A.m<=N && B.m<=N) {
    dgemm(alpha,A,B,beta,C);
    return;
  }

  std::vector<double> tmpDataA(N*N), tmpDataB(N*N), tmpDataC(N*N);

  for (int j=0; j<C.m; j+=N)
    for (int i=0; i<C.n; i+=N) {
      // Tiles on the lower/right border may be smaller than N x N.
      const int n = std::min(N,C.n-i);
      const int m = std::min(N,C.m-j);

      Matrix tmpC(&tmpDataC[0],n,m,N);

      // tmpC accumulates the plain product A*B for this tile, so it has to
      // start from zero; alpha and beta are applied when writing back.
      for (int sj=0; sj<m; ++sj)
        for (int si=0; si<n; ++si)
          tmpC(si,sj) = 0.0;

      // Sweep over the inner dimension, i.e. the columns of A / rows of B.
      for (int k=0; k<A.m; k+=N) {
        const int o = std::min(N,A.m-k);

        Matrix tmpA(&tmpDataA[0],n,o,N), tmpB(&tmpDataB[0],o,m,N);

        // Copy the n x o panel of A and the o x m panel of B to temporaries
        for (int sk=0; sk<o; ++sk)
          for (int si=0; si<n; ++si)
            tmpA(si,sk) = A(i+si,k+sk);
        for (int sj=0; sj<m; ++sj)
          for (int sk=0; sk<o; ++sk)
            tmpB(sk,sj) = B(k+sk,j+sj);

        // tmpC = tmpC + tmpA*tmpB; full tiles use the fixed-size kernel
        if (n==N && m==N && o==N)
          dgemmFix(tmpA,tmpB,tmpC);
        else
          dgemm(1.0,tmpA,tmpB,1.0,tmpC);
      }

      // C tile <- beta * C tile + alpha * (accumulated product)
      for (int sj=0; sj<m; ++sj)
        for (int si=0; si<n; ++si)
          C(i+si,j+sj) = beta*C(i+si,j+sj) + alpha*tmpC(si,sj);
    }
}

  

  



// Unblocked LR (LU) factorisation of a square matrix in column-major
// storage, carried out in place.  No pivoting is performed (for
// simplicity), so every diagonal entry is used as a divisor as it is.
// Afterwards the strictly lower triangle holds the multipliers of L (its
// unit diagonal is not stored) and the upper triangle, diagonal included,
// holds R.

void lr(Matrix& A) {
  assert(A.n==A.m);

  const int dim = A.n;

  // p walks along the diagonal; the last diagonal entry needs no elimination
  for (int p = 0; p < dim-1; ++p) {
    // eliminate column p from every row below the diagonal
    for (int row = p+1; row < dim; ++row) {
      double& factor = A(row,p);
      factor /= A(p,p);

      // subtract factor * (pivot row) from the rest of this row
      for (int col = p+1; col < dim; ++col)
        A(row,col) -= factor * A(p,col);
    }
  }
}

// Forward substitution: overwrites B with the solution X of L*X = B.
// L is treated as lower triangular with an implicit unit diagonal; only
// entries strictly below the diagonal are read.
void trisolveL(Matrix& L, Matrix& B) {
  for (int row = 0; row < B.n; ++row) {
    for (int col = 0; col < B.m; ++col) {
      double& x = B(row,col);
      // remove the contribution of the rows that are already solved
      for (int p = 0; p < row; ++p)
        x -= L(row,p) * B(p,col);
    }
  }
}

// Recursive blocked version of trisolveL: overwrites B with the solution X
// of L*X = B.  Systems with at most 4*N rows go to the unblocked kernel;
// larger ones are split in half, with the coupling term handled by
// blockDGEMM.
void trisolveLBlock(Matrix& L, Matrix& B) {
  if (L.n <= 4*N) {
    trisolveL(L,B);
    return;
  }

  const int half = L.n/2;
  const int rest = L.n - half;

  // L = [ topLeft      0         ]     B = [ upperB ]
  //     [ bottomLeft   bottomRight ]       [ lowerB ]
  Matrix topLeft(L.data, half, half, L.lda);
  Matrix bottomLeft(L.data + half, rest, half, L.lda);
  Matrix bottomRight(L.data + half + half*L.lda, rest, L.m - half, L.lda);
  Matrix upperB(B.data, half, B.m, B.lda);
  Matrix lowerB(B.data + half, B.n - half, B.m, B.lda);

  // upper half: topLeft * X0 = upperB
  trisolveLBlock(topLeft, upperB);

  // move the known part to the right-hand side: lowerB -= bottomLeft * X0
  blockDGEMM(-1, bottomLeft, upperB, 1, lowerB);

  // lower half: bottomRight * X1 = lowerB
  trisolveLBlock(bottomRight, lowerB);
}


// Overwrites B with the solution X of X*R = B, where R is upper triangular
// with its diagonal stored explicitly.  The columns of X are computed from
// left to right (equivalent to forward substitution on R'*X' = B').
void trisolveR(Matrix& R, Matrix& B) {
  for (int col = 0; col < B.m; ++col) {
    for (int row = 0; row < B.n; ++row) {
      double& x = B(row,col);
      // subtract the contribution of the columns already solved
      for (int p = 0; p < col; ++p)
        x -= R(p,col) * B(row,p);
      x /= R(col,col);
    }
  }
}

// Recursive blocked version of trisolveR: overwrites B with the solution X
// of X*R = B.  Systems with at most 4*N rows in R go to the unblocked
// kernel.  Larger ones are partitioned as
//
//   R = [ R00  R01 ]      B = [ B0  B1 ]
//       [  0   R11 ]
//
// with bs = R.n/2 leading rows/columns, and solved block column by block
// column.
void trisolveRBlock(Matrix& R, Matrix& B) {
  if (R.n<=4*N) {
    trisolveR(R,B);
    return;
  }

  int bs = R.n/2;

  Matrix R00(R.data,bs,bs,R.lda);
  // R01 spans all remaining columns, which is one more than bs when the
  // size is odd; it has to match the width of B1 and R11.
  Matrix R01(R.data+bs*R.lda,bs,R.m-bs,R.lda);
  Matrix R11(R.data+bs+bs*R.lda,R.n-bs,R.m-bs,R.lda);
  Matrix B0(B.data,B.n,bs,B.lda);
  Matrix B1(B.data+bs*B.lda,B.n,B.m-bs,B.lda);

  // solve X0*R00 = B0
  trisolveRBlock(R00,B0);

  // compute B1 <- B1 - X0*R01
  blockDGEMM(-1,B0,R01,1,B1);

  // solve X1*R11 = B1
  trisolveRBlock(R11,B1);
}

// Recursive blocked LR factorisation, in place and without pivoting.
// Matrices with at most N rows are factorised by lr(); larger ones are
// split into a 2x2 block structure whose blocks are factorised, solved
// and updated in turn.
void blockLR(Matrix& A) {
  if (A.n <= N) {
    lr(A);
    return;
  }

  const int half = A.n/2;
  const int lowerRows = A.n - half;
  const int rightCols = A.m - half;

  // A = [ topLeft     topRight    ]
  //     [ bottomLeft  bottomRight ]
  Matrix topLeft(A.data, half, half, A.lda);
  Matrix topRight(A.data + half*A.lda, half, rightCols, A.lda);
  Matrix bottomLeft(A.data + half, lowerRows, half, A.lda);
  Matrix bottomRight(A.data + half*A.lda + half, lowerRows, rightCols, A.lda);

  // topLeft <- factors L00, R00 with L00*R00 = topLeft
  blockLR(topLeft);

  // topRight <- R01, the solution of L00*R01 = topRight
  trisolveLBlock(topLeft, topRight);

  // bottomLeft <- L10, the solution of L10*R00 = bottomLeft
  trisolveRBlock(topLeft, bottomLeft);

  // Schur complement: bottomRight <- bottomRight - L10*R01, then factorise it
  blockDGEMM(-1, bottomLeft, topRight, 1, bottomRight);
  blockLR(bottomRight);
}


// Solves (L*R)*x = b in place, given the packed factors produced by the LR
// routines: the strict lower triangle of LR is L (unit diagonal implied)
// and the upper triangle including the diagonal is R.  b may hold several
// right-hand sides, one per column; it is overwritten with the solution.
void trisolve(Matrix& LR, Matrix& b) {
  assert(LR.n==LR.m);
  assert(LR.n==b.n);

  const int dim = LR.n;

  // Step 1, forward substitution: L*z = b, z stored in b
  for (int row = 0; row < dim; ++row) {
    for (int p = 0; p < row; ++p) {
      for (int rhs = 0; rhs < b.m; ++rhs)
        b(row,rhs) -= LR(row,p) * b(p,rhs);
    }
  }

  // Step 2, back substitution: R*x = z, bottom row first
  for (int row = dim; row-- > 0; ) {
    for (int rhs = 0; rhs < b.m; ++rhs) {
      for (int p = row+1; p < dim; ++p)
        b(row,rhs) -= LR(row,p) * b(p,rhs);
      b(row,rhs) /= LR(row,row);
    }
  }
}


#include <boost/timer.hpp>
#include <cmath>
#include <vector>

double check(int n) {
  std::vector<double> data(n*n);
  Matrix A(&data[0],n,n,n);

  for (int j=0; j<n; ++j)
    for (int i=0; i<n; ++i)
      A(i,j) = 1.0/std::pow(2.0,std::abs(i-j));
    
  std::vector<double> save(data);

  boost::timer timer;
  
  int run = 0;
  for (; timer.elapsed()<20; ++run) {

    std::copy(save.begin(),save.end(),data.begin());
    
    lr(A);

/*    std::vector<double> rhs(n);
    std::fill(rhs.begin(),rhs.end(),1.0); 
    Matrix b;
    b.data = &rhs[0];
    b.n = n;
    b.m = 1;
    b.lda = n;
    
    trisolve(A,b);
    std::cout << "b = [";
    for (int i=0; i<n; ++i)
      std::cout << rhs[i] << ';';
    std::cout << "];\n";
    abort();
 */   
    // std::cout.precision(7);
    // for (int i=0; i<n; ++i) {
    //   for (int j=0; j<n; ++j) {
    //     std::cout.width(10);
    //     std::cout << A(i,j) << "  ";
    //   }
    //   std::cout << "\n";
    // }
    // abort();
  }

  return timer.elapsed()/run;
}

double checkBlock(int n) {
  std::vector<double> data(n*n);
  Matrix A(&data[0],n,n,n);

  for (int j=0; j<n; ++j)
    for (int i=0; i<n; ++i)
      A(i,j) = 1.0/std::pow(2.0,std::abs(i-j));

  std::vector<double> save(data);
    
  boost::timer timer;
  
  int run = 0;
  for (; timer.elapsed()<20; ++run) {
    
    std::copy(save.begin(),save.end(),data.begin());
    blockLR(A);

  }

  return timer.elapsed()/run;
}

double checkDGEMM(int n) {
  std::vector<double> dataA(n*n), dataB(n*n), dataC(n*n);
  Matrix A(&dataA[0],n,n,n), B(&dataB[0],n,n,n), C(&dataC[0],n,n,n);

  for (int j=0; j<n; ++j)
    for (int i=0; i<n; ++i) {
      A(i,j) = 1.0/std::pow(2.0,std::abs(i-j));
      B(i,j) = A(i,j);
    }
    
  boost::timer timer;
  
  int run = 0;
  for (; timer.elapsed()<20; ++run) {
    
    dgemm(1.0,A,B,0.0,C);
  }

  return timer.elapsed()/run;
}

double checkBlockDGEMM(int n) {
  std::vector<double> dataA(n*n), dataB(n*n), dataC(n*n);
  Matrix A(&dataA[0],n,n,n), B(&dataB[0],n,n,n), C(&dataC[0],n,n,n);

  for (int j=0; j<n; ++j)
    for (int i=0; i<n; ++i) {
      A(i,j) = 1.0/std::pow(2.0,std::abs(i-j));
      B(i,j) = A(i,j);
    }
    
  boost::timer timer;
  
  int run = 0;
  for (; timer.elapsed()<20; ++run) {

    
    blockDGEMM(1.0,A,B,0.0,C);
  }

  return timer.elapsed()/run;
}

#include "acml.h"

double checkACMLDGEMM(int n) {
  std::vector<double> dataA(n*n), dataB(n*n), dataC(n*n);
  Matrix A(&dataA[0],n,n,n), B(&dataB[0],n,n,n), C(&dataC[0],n,n,n);

  for (int j=0; j<n; ++j)
    for (int i=0; i<n; ++i) {
      A(i,j) = 1.0/std::pow(2.0,std::abs(i-j));
      B(i,j) = A(i,j);
    }
    
  boost::timer timer;
  
  int run = 0;
  for (; timer.elapsed()<20; ++run) {

    dgemm('N','N',C.n,C.m,A.m,1.0,A.data,A.lda,B.data,B.lda,0.0,C.data,C.lda);
  }

  return timer.elapsed()/run;
}



// Benchmark driver.  For matrix sizes 32, 42, 54, ... (each size is
// 1.21 times the previous one, truncated, plus 4) below 10000 it times
// blockLR via checkBlock() and prints one line per size: the size and the
// operation count divided by the measured time per factorisation.
//
// The other timing functions (check, checkDGEMM, checkBlockDGEMM,
// checkACMLDGEMM) can be substituted here to benchmark the other kernels;
// for the matrix products the operation count used was size^3.
int main(void) {
  int size = 32;
  while (size < 10000) {
    // operation count of an LR factorisation of a size x size matrix
    const double flops =
      2./3 * std::pow(static_cast<double>(size),3.0) - 0.5*size*size - size/6.0;

    std::cout << size << " " << flops/checkBlock(size) << '\n';

    // make each result visible immediately; a single size takes a while
    std::cout.flush();

    size = static_cast<int>(1.21*size)+4;
  }
  return 0;
}
