111 lines
2.8 KiB
C++
111 lines
2.8 KiB
C++
#pragma once
|
|
|
|
#include <stddef.h>

#include <ostream>
#include <stdexcept>
#include <vector>
|
|
|
|
// Forward declaration so the class template name is known before its
// definition further down; MatrixView holds a reference to a Matrix<T>
// rather than owning element storage.
template <class T>
class MatrixView;
|
|
|
|
/// Dense nrow x ncol matrix with periodic (wrap-around) indexing.
///
/// Elements are stored contiguously in column-major order: element (i, j)
/// lives at flat index i + j * nrow. Both the flat accessor operator()(i)
/// and the 2-D accessor operator()(i, j) accept any int index and reduce it
/// modulo the relevant extent, so e.g. -1 means "last".
///
/// Indexing a matrix whose relevant extent is 0 throws std::out_of_range.
template <typename T>
class Matrix {
public:
    /// Creates an nrow x ncol matrix whose elements are value-initialized.
    // The vector must be *sized* (not merely reserved): operator[] on
    // elements past size() is undefined behavior.
    Matrix(size_t nrow, size_t ncol)
        : _nrow{nrow}, _ncol{ncol}, _elem(nrow * ncol) {}

    /// Creates an nrow x ncol matrix with every element a copy of `value`.
    Matrix(size_t nrow, size_t ncol, const T& value)
        : _nrow{nrow}, _ncol{ncol}, _elem(nrow * ncol, value) {}

    /// Number of rows.
    size_t nrow() const { return _nrow; }
    /// Number of columns.
    size_t ncol() const { return _ncol; }
    /// Total number of elements (nrow * ncol).
    size_t size() const { return _nrow * _ncol; }

    /// Flat (column-major) access; i is wrapped modulo size().
    T& operator()(int i) { return _elem[index(i)]; }
    const T& operator()(int i) const { return _elem[index(i)]; }

    /// 2-D access; i is wrapped modulo nrow(), j modulo ncol().
    T& operator()(int i, int j) { return _elem[index(i, j)]; }
    const T& operator()(int i, int j) const { return _elem[index(i, j)]; }

private:
    size_t _nrow;
    size_t _ncol;
    std::vector<T> _elem;

    // Reduces i into [0, n) in O(1). C++ % truncates toward zero, so a
    // negative remainder is shifted up by n. An empty extent has no valid
    // index at all, so it is rejected instead of being looped over.
    static size_t wrap(int i, size_t n) {
        if (n == 0) {
            throw std::out_of_range("Matrix: index into empty dimension");
        }
        const long long extent = static_cast<long long>(n);
        long long r = static_cast<long long>(i) % extent;
        if (r < 0) {
            r += extent;
        }
        return static_cast<size_t>(r);
    }

    // Flat index wrapped over the whole matrix.
    size_t index(int i) const { return wrap(i, _nrow * _ncol); }

    // Column-major flat index of (i, j), each coordinate wrapped separately.
    size_t index(int i, int j) const {
        return wrap(i, _nrow) + wrap(j, _ncol) * _nrow;
    }
};
|
|
|
|
template <typename T>
|
|
std::ostream& operator<<(std::ostream& out, const Matrix<T>& mat) {
|
|
for (size_t i = 0; i < mat.nrow(); ++i) {
|
|
for (size_t j = 0; j < mat.ncol(); ++j) {
|
|
out << mat(i, j) << ' ';
|
|
}
|
|
out << '\n';
|
|
}
|
|
return out;
|
|
}
|
|
|
|
template <typename T>
|
|
class MatrixView {
|
|
public:
|
|
MatrixView(Matrix<T>& matrix, size_t index, size_t stride, size_t nelem) :
|
|
_matrix(matrix), _index{index}, _stride{stride},
|
|
_nelem{nelem} { };
|
|
|
|
const size_t size() const { return _nelem; };
|
|
|
|
T& operator()(int i) { return _matrix(_index + i * _stride); };
|
|
const T& operator()(int i) const { return _matrix(_index + i * _stride); };
|
|
|
|
protected:
|
|
Matrix<T>& _matrix;
|
|
size_t _index;
|
|
size_t _stride;
|
|
size_t _nelem;
|
|
};
|
|
|
|
template <typename T>
|
|
std::ostream& operator<<(std::ostream& out, const MatrixView<T>& view) {
|
|
for (size_t i = 0; i < view.size(); ++i) {
|
|
out << view(i) << ' ';
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
template <typename T>
|
|
class Row : public MatrixView<T> {
|
|
public:
|
|
Row(Matrix<T>& matrix, size_t index) :
|
|
MatrixView<T>(matrix, index * matrix.ncol(), 1, matrix.nrow()) { };
|
|
};
|
|
|
|
template <typename T>
|
|
class Col : public MatrixView<T> {
|
|
public:
|
|
Col(Matrix<T>& matrix, size_t index) :
|
|
MatrixView<T>(matrix, index, matrix.nrow(), matrix.ncol()) { };
|
|
};
|
|
|
|
template <typename T>
|
|
class Diag : public MatrixView<T> {
|
|
public:
|
|
Diag(Matrix<T>& matrix) :
|
|
MatrixView<T>(matrix, 0, matrix.nrow() + 1,
|
|
matrix.nrow() < matrix.ncol() ? matrix.nrow() : matrix.ncol()) { };
|
|
};
|