29 #define HDF5_STATUS_CHECK(status) { \
31 std::cerr << __FILE__ << ":" << __LINE__ << \
32 ": Problem with writing to file. Status code=" \
33 << status << std::endl; \
57 mat.reset(
new double [n*m]);
67 mat.reset(
new double [n*m]);
68 std::memcpy(mat.get(), orig.
getpointer(), n*m*
sizeof(double));
79 mat = std::move(orig.mat);
86 mat.reset(
new double [n*m]);
87 std::memcpy(mat.get(), orig.
getpointer(), n*m*
sizeof(double));
97 for(
int i=0;i<n*m;i++)
109 daxpy_(&dim,&alpha,orig.mat.get(),&inc,mat.get(),&inc);
120 daxpy_(&dim,&alpha,orig.mat.get(),&inc,mat.get(),&inc);
182 assert(A.n == n && B.m == m);
184 dgemm_(&trans,&trans,&A.n,&B.m,&A.m,&alpha,A.mat.get(),&A.n,B.mat.get(),&B.n,&beta,mat.get(),&A.n);
199 int count_sing = std::min(n,m);
201 std::unique_ptr<double []> sing_vals(
new double[count_sing]);
204 int lwork = std::max( 3*count_sing + std::max(n,m), 5*count_sing);
205 std::unique_ptr<double []> work(
new double[lwork]);
207 std::unique_ptr<double []> vt(
new double[n*n]);
211 dgesvd_(&jobu,&jobvt,&n,&m,mat.get(),&n,sing_vals.get(),vt.get(),&n,0,&m,work.get(),&lwork,&info);
214 std::cerr <<
"svd failed. info = " << info << std::endl;
223 void matrix::mvprod(
const double *x,
double *y,
double beta)
const
230 dsymv_(&uplo,&n,&alpha,mat.get(),&n,x,&incx,&beta,y,&incx);
237 std::cout << i <<
"\t" << j <<
"\t" << (*
this)(i,j) << std::endl;
244 for(
int i=0;i<std::min(m,n);i++)
245 result += (*
this)(i,i);
258 std::vector<double> col(n);
260 std::memcpy(col.data(), &mat[idx*n],
sizeof(double)*n);
282 hid_t file_id, dataset_id, dataspace_id, attribute_id;
285 file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
288 hsize_t dimarray = n*m;
289 dataspace_id = H5Screate_simple(1, &dimarray, NULL);
291 dataset_id = H5Dcreate(file_id,
"matrix", H5T_IEEE_F64LE, dataspace_id, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
293 status = H5Dwrite(dataset_id, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, mat.get() );
296 status = H5Sclose(dataspace_id);
299 dataspace_id = H5Screate(H5S_SCALAR);
301 attribute_id = H5Acreate (dataset_id,
"n", H5T_STD_I32LE, dataspace_id, H5P_DEFAULT, H5P_DEFAULT);
302 status = H5Awrite (attribute_id, H5T_NATIVE_INT, &n );
305 status = H5Aclose(attribute_id);
308 attribute_id = H5Acreate (dataset_id,
"m", H5T_STD_I32LE, dataspace_id, H5P_DEFAULT, H5P_DEFAULT);
309 status = H5Awrite (attribute_id, H5T_NATIVE_INT, &m );
312 status = H5Aclose(attribute_id);
315 status = H5Sclose(dataspace_id);
318 status = H5Dclose(dataset_id);
321 status = H5Fclose(file_id);
327 hid_t file_id, dataset_id, attribute_id;
331 file_id = H5Fopen(filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
334 dataset_id = H5Dopen(file_id,
"matrix", H5P_DEFAULT);
337 attribute_id = H5Aopen(dataset_id,
"n", H5P_DEFAULT);
340 status = H5Aread(attribute_id, H5T_NATIVE_INT, &n_);
343 status = H5Aclose(attribute_id);
346 attribute_id = H5Aopen(dataset_id,
"m", H5P_DEFAULT);
349 status = H5Aread(attribute_id, H5T_NATIVE_INT, &m_);
352 status = H5Aclose(attribute_id);
356 std::cerr <<
"Matrix size not compatable with file: " << n_ <<
"x" << m_ << std::endl;
359 status = H5Dread(dataset_id, H5T_NATIVE_DOUBLE, H5S_ALL, H5S_ALL, H5P_DEFAULT, mat.get());
363 status = H5Dclose(dataset_id);
366 status = H5Fclose(file_id);
void SaveToFile(std::string filename) const
matrix & operator=(const matrix &orig)
void ReadFromFile(std::string filename) const
matrix & prod(matrix const &A, matrix const &B)
void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info)
void dsymv_(char *uplo, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy)
const double * GetColumnRaw(unsigned int idx) const
std::unique_ptr< double[]> svd()
#define HDF5_STATUS_CHECK(status)
void daxpy_(int *n, double *alpha, double *x, int *incx, double *y, int *incy)
std::vector< double > GetColumn(unsigned int) const
double operator()(int x, int y) const
double & operator[](int x)
double * getpointer() const
matrix & operator-=(const matrix &orig)
matrix & operator+=(const matrix &orig)
void mvprod(const double *x, double *y, double beta) const
void dgemm_(char *transA, char *transB, int *m, int *n, int *k, double *alpha, double *A, int *lda, double *B, int *ldb, double *beta, double *C, int *ldc)