【矩阵运算C++实现】矩阵封装实现Matrix类
Matrix类封装了矩阵运算中几种常用的函数。
Matrix.h
#pragma once

#include <initializer_list>
#include <string>
#include <sstream>

// Dimensions of a 2-D matrix: `row` rows by `col` columns.
struct MatrixShape {
    int row;
    int col;

    // Total number of elements (row * col).
    int size() const { return row * col; }

    // Human-readable form, e.g. "(2,3)".
    std::string ToStr() const {
        std::stringstream ss;
        ss << "(" << row << "," << col << ")";
        return ss.str();
    }
};

// Dense matrix of T held in a single heap buffer `data_` of shape_.size()
// elements. Element (r, c) is addressed as data_[r * shape_.col + c] (see At()),
// i.e. row-major. Member definitions live in Matrix.cpp and are explicitly
// instantiated there for int, long, float and double only.
template<typename T>
class Matrix {
public:
    // Allocates a row x col matrix; elements are not initialized.
    Matrix(int row, int col);
    // Allocates a row x col matrix and copies row*col elements from `array`.
    Matrix(int row, int col, const T array[]);
    // Allocates a row x col matrix and fills it from a brace list (row-major).
    Matrix(int row, int col, const std::initializer_list<T> &array);
    Matrix(const Matrix &other);
    // Steals other's buffer; `other` is left with no data.
    Matrix(Matrix &&other) noexcept ;
    virtual ~Matrix();

    // Element-wise sum / difference; both operands must have the same shape.
    Matrix operator+(const Matrix &other);
    Matrix operator-(const Matrix &other);
    // Matrix product; requires this->col == other.row.
    Matrix operator*(const Matrix &other);
    // Not implemented yet
    Matrix operator/(const Matrix &other) = delete;

    Matrix &operator=(const Matrix &other);
    Matrix &operator=(Matrix &&other) noexcept;
    // Equal when shapes match and all elements compare equal.
    bool operator==(const Matrix &other);
    bool operator!=(const Matrix &other);

    // Unchecked access to the index-th element of the flat buffer.
    T &operator[](int index);
    const T &operator[](int index) const;
    // Unchecked read of element (r, c).
    const T &At(int r, int c) const;
    const MatrixShape &Shape() const;

    // Overwrite the elements from an array / brace list (row-major).
    void Fill(const T array[]);
    void Fill(const std::initializer_list<T> &array);
    // Changes the shape without moving elements; returns false if rejected.
    bool Reshape(int row, int col);
    // Swap the row and column counts (elements are not rearranged).
    Matrix<T> & ReverseSelf();
    Matrix<T> Reverse();
    // Fill with the identity pattern / with zeros.
    void Eye();
    void Zeros();

    // Element-wise product; a 1x1 `other` is treated as a scalar.
    Matrix Dot(const Matrix &other);
    // Multiply / divide every element by a scalar.
    Matrix Dot(T value);
    Matrix Divide(T value);
    // Returns the single element of a 1x1 matrix.
    T Extract();

private:
    MatrixShape shape_;
    T *data_;
};
Matrix.cpp
#include "Matrix.h"

#include <cassert>
#include <cstring>

// Explicit instantiations: the member definitions live only in this .cpp file,
// so other translation units can use Matrix with exactly these element types.
// NOTE(review): per [temp.explicit], an explicit instantiation definition only
// covers members already defined at that point, and these lines precede every
// member definition below. Common compilers still emit the members, but moving
// these four lines to the end of the file is the portable arrangement.
template class Matrix<int>;
template class Matrix<long>;
template class Matrix<float>;
template class Matrix<double>;

template<typename T>
// Allocates storage for row * col elements; the elements are left
// default-initialized, so call Zeros()/Fill() before reading them.
Matrix<T>::Matrix(int row, int col)
    : shape_({row, col}), data_(new T[row * col]) {
}

template<typename T>
// Allocates a row x col matrix and byte-copies the first row * col elements
// of `array` into it. The caller must supply at least that many elements.
Matrix<T>::Matrix(int row, int col, const T array[])
    : shape_({row, col}), data_(new T[row * col]) {
    const auto bytes = sizeof(T) * row * col;
    memcpy(data_, array, bytes);
}

template<typename T>
// Builds a row x col matrix from a brace list in row-major order.
// At most row * col values are taken from the list (extras are ignored).
// If the list is shorter, the remaining elements stay value-initialized
// (zero for arithmetic T) rather than being read from past the list's end.
Matrix<T>::Matrix(int row, int col, const std::initializer_list<T> &array)
    : shape_({row, col}), data_(new T[row * col]()) {
    const int count = shape_.size();
    int i = 0;
    for (const T &value : array) {
        if (i >= count) break;
        data_[i++] = value;
    }
}

template<typename T>
// Deep copy: allocates a buffer of the same shape and copies every element
// of `other` into it (previously the buffer was allocated but never filled).
Matrix<T>::Matrix(const Matrix<T> &other)
    : shape_(other.shape_), data_(new T[other.shape_.size()]) {
    // A moved-from matrix has data_ == nullptr; there is nothing to copy then.
    if (other.data_ != nullptr) {
        memcpy(data_, other.data_, sizeof(T) * shape_.size());
    }
}

template<typename T>
// Takes ownership of other's buffer. `other` keeps its shape but is left
// holding a null data pointer.
Matrix<T>::Matrix(Matrix<T> &&other) noexcept
    : shape_(other.shape_), data_(other.data_) {
    other.data_ = nullptr;
}

template<typename T>
// Releases the element buffer; delete[] on a null pointer (moved-from
// matrix) is a no-op.
Matrix<T>::~Matrix() {
    delete[] data_;
    data_ = nullptr;
}

template<typename T>
// Element-wise sum of two matrices; asserts that both shapes are equal.
Matrix<T> Matrix<T>::operator+(const Matrix<T> &other) {
    assert(this->shape_.col == other.shape_.col && this->shape_.row == other.shape_.row);
    Matrix<T> sum(shape_.row, other.shape_.col);
    for (int idx = 0, total = shape_.size(); idx < total; ++idx) {
        sum.data_[idx] = data_[idx] + other.data_[idx];
    }
    return sum;
}

template<typename T>
// Element-wise difference (*this - other); asserts that both shapes are equal.
Matrix<T> Matrix<T>::operator-(const Matrix<T> &other) {
    assert(this->shape_.col == other.shape_.col && this->shape_.row == other.shape_.row);
    Matrix<T> diff(shape_.row, other.shape_.col);
    const T *lhs = data_;
    const T *rhs = other.data_;
    T *out = diff.data_;
    for (int left = shape_.size(); left > 0; --left) {
        *out++ = *lhs++ - *rhs++;
    }
    return diff;
}

template<typename T>
Matrix<T> Matrix<T>::operator*(const Matrix<T> &other) {assert(this->shape_.col == other.shape_.row);Matrix<T> c(this->shape_.row, other.shape_.col);int m = this->shape_.row*other.shape_.col;for (int n=0; n<m; n++) {int ii = n/other.shape_.col; // column indexint jj = n%other.shape_.col; // row indexc.data_[n] = 0;for (int l = 0; l< this->shape_.col; l++) {c.data_[n] += this->data_[this->shape_.row*ii+l] *other.data_[other.shape_.col*l+jj];}}return c;
}template<typename T>
// Copy assignment: replaces this matrix's shape and contents with a deep copy
// of `other`. The shape is now copied too (previously it was left stale), and
// self-assignment is detected by address rather than by an O(n) value compare.
Matrix<T> &Matrix<T>::operator=(const Matrix<T> &other) {
    if (this == &other) return *this;
    const int size = other.shape_.size();
    // Allocate before releasing the old buffer so *this is untouched if new[] throws.
    T *fresh = new T[size];
    if (other.data_ != nullptr) {
        memcpy(fresh, other.data_, size * sizeof(T));
    }
    delete [] this->data_;
    this->data_ = fresh;
    this->shape_ = other.shape_;
    return *this;
}

template<typename T>
// Move assignment: frees this matrix's current buffer (it used to leak), then
// takes over other's shape and buffer, leaving `other` with a null data pointer.
// Self-move is a no-op instead of destroying the data.
Matrix<T> &Matrix<T>::operator=(Matrix<T> &&other) noexcept {
    if (this != &other) {
        delete[] this->data_;
        this->shape_ = other.shape_;
        this->data_ = other.data_;
        other.data_ = nullptr;
    }
    return *this;
}

template<typename T>
// Two matrices are equal when their shapes match and every element compares equal.
bool Matrix<T>::operator==(const Matrix &other) {
    const bool same_shape = shape_.row == other.shape_.row && shape_.col == other.shape_.col;
    if (!same_shape) return false;
    const int total = other.shape_.size();
    int n = 0;
    while (n < total && data_[n] == other.data_[n]) {
        ++n;
    }
    return n == total;
}

template<typename T>
// Logical negation of operator==.
bool Matrix<T>::operator!=(const Matrix<T> &other) {
    const bool equal = (*this == other);
    return !equal;
}

template<typename T>
// Mutable, unchecked access to element `index` of the flat element buffer.
T &Matrix<T>::operator[](int index) {
    return *(data_ + index);
}

template<typename T>
// Read-only, unchecked access to element `index` of the flat element buffer.
const T &Matrix<T>::operator[](int index) const {
    return *(data_ + index);
}

template<typename T>
// Read-only, unchecked access to element (r, c). Storage is row-major, so the
// element sits at offset r * col + c.
const T &Matrix<T>::At(int r, int c) const {
    const int offset = r * shape_.col + c;
    return data_[offset];
}

template<typename T>
// Returns a reference to the current dimensions.
const MatrixShape &Matrix<T>::Shape() const
{
    return shape_;
}

template<typename T>
// Overwrites all size() elements by byte-copying the first size() entries of
// `array` (row-major). The caller must supply at least that many elements.
void Matrix<T>::Fill(const T array[]) {
    const auto bytes = sizeof(T) * shape_.size();
    memcpy(data_, array, bytes);
}

template<typename T>
// Overwrites elements in row-major order from a brace list. At most size()
// values are copied; if the list is shorter, the remaining elements are left
// unchanged instead of being read from past the end of the list.
void Matrix<T>::Fill(const std::initializer_list<T> &array) {
    const int count = this->shape_.size();
    int i = 0;
    for (const T &value : array) {
        if (i >= count) break;
        data_[i++] = value;
    }
}

template<typename T>
// Reinterprets the existing buffer as row x col without moving any element.
// Returns true if the shape is unchanged, or if both dimensions are
// non-negative and row * col equals the current element count. Otherwise
// returns false and leaves the shape untouched.
bool Matrix<T>::Reshape(int row, int col) {
    if (this->shape_.row == row && this->shape_.col == col) return true;
    // Pairs such as (-2, -2) have a matching product but are not valid shapes.
    if (row < 0 || col < 0) return false;
    if (this->shape_.size() != row * col) return false;
    this->shape_.row = row;
    this->shape_.col = col;
    return true;
}

template<typename T>
// Swaps the row and column counts in place and returns *this. The element
// buffer is NOT rearranged, so this is equivalent to Reshape(col, row) rather
// than a mathematical transpose.
// NOTE(review): if a transpose was intended, data_ must be permuted as well.
Matrix<T> & Matrix<T>::ReverseSelf() {
    shape_ = MatrixShape{shape_.col, shape_.row};
    return *this;
}

template<typename T>
// Returns a copy-constructed matrix whose row and column counts are swapped
// via ReverseSelf(); *this is not modified.
Matrix<T> Matrix<T>::Reverse() {
    Matrix<T> result = *this;
    result.ReverseSelf();
    return result;
}

template<typename T>
// Overwrites the matrix with the identity pattern: 1 where the row index
// equals the column index, 0 elsewhere. Element (i, j) is addressed
// row-major as i * col + j; the previous i * row + j was wrong for
// non-square shapes and wrote out of bounds when row > col.
void Matrix<T>::Eye() {
    const int rows = this->shape_.row;
    const int cols = this->shape_.col;
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            this->data_[i * cols + j] = (i == j) ? static_cast<T>(1) : static_cast<T>(0);
        }
    }
}

template<typename T>
// Sets every element to zero.
void Matrix<T>::Zeros() {
    for (T *p = data_, *end = data_ + shape_.size(); p != end; ++p) {
        *p = 0;
    }
}

template<typename T>
// Element-wise (Hadamard) product. A 1x1 `other` is treated as a scalar and
// forwarded to Dot(T); otherwise asserts that both shapes are equal.
Matrix<T> Matrix<T>::Dot(const Matrix<T> &other) {
    if (other.shape_.row == 1 && other.shape_.col == 1) {
        // Extract() is a non-const member and cannot be called on the const
        // reference `other` (a compile error once instantiated); read the
        // single element through the const operator[] instead.
        return this->Dot(other[0]);
    }
    assert(this->shape_.col == other.shape_.col && this->shape_.row == other.shape_.row);
    Matrix<T> c(this->shape_.row, this->shape_.col);
    const int size = this->shape_.size();
    for (int n = 0; n < size; n++) {
        c.data_[n] = this->data_[n] * other.data_[n];
    }
    return c;
}

template<typename T>
// Returns a new matrix of the same shape with every element multiplied by `value`.
Matrix<T> Matrix<T>::Dot(T value) {
    Matrix<T> scaled(shape_.row, shape_.col);
    const T *src = data_;
    T *dst = scaled.data_;
    for (int remaining = shape_.size(); remaining > 0; --remaining) {
        *dst++ = *src++ * value;
    }
    return scaled;
}

template<typename T>
// Returns a new matrix of the same shape with every element divided by
// `value` (integer division for integral T; no check for a zero divisor).
Matrix<T> Matrix<T>::Divide(T value) {
    Matrix<T> quotient(shape_.row, shape_.col);
    const int total = shape_.size();
    for (int k = 0; k != total; ++k) {
        quotient.data_[k] = data_[k] / value;
    }
    return quotient;
}

template<typename T>
// Returns the only element of a 1x1 matrix; asserts that the shape is 1x1.
T Matrix<T>::Extract() {
    assert(this->shape_.col == 1 && this->shape_.row == 1);
    return *data_;
}
测试
template<typename T>
void print_matrix(const Matrix<T> &m) {std::cout << "shape:" << m.Shape().ToStr() << std::endl;for (int i = 0; i < m.Shape().row; i++) {for (int j = 0; j < m.Shape().col; j++) {std::cout << m.At(i, j) << " ";}std::cout << std::endl;}
}int main(int argc, char* argv[]) {// a = { 1,2// 3,4// }// shape is (2,2)int a[4] = {1, 2, 3, 4};// b = { 0,1,2// 3,1,2// }// shape is (2,3)int b[6] = {0, 1, 2, 3, 1, 2};//matrix_multiply(a, 2, 2, b, 2, 3, c);Matrix<int> m1(2, 2, a);Matrix<int> m2(2, 3, b);Matrix<int> c = m1 * m2;print_matrix(c);Matrix<int> m3(2, 3, {1,0,0,1,0,0});c = c+m3;print_matrix(c);c = c-m3;print_matrix(c);c = c.Dot(2);print_matrix(c);c = c.Divide(2);print_matrix(c);c.ReverseSelf();print_matrix(c);Matrix<int> d(3,3);d.Zeros();print_matrix(d);d.Eye();print_matrix(d);return 0;
}