I am trying to implement an optimal and fast running matrix in C++. I need some review of the code and ideas on how to improve the code quality if it shall be.
class Matrix {
// Some static assertions and useful types
static_assert(std::is_arithmetic_v<T>, "Matrix template parameter type must be arithmetic");
using DataType = std::vector<T>;
// Default constructors
public:
Matrix() : mCols(0), mRows(0), mData(0) { };
Matrix(const Matrix &other) = default;
Matrix(Matrix &&other) noexcept = default;
Matrix &operator=(const Matrix &other) = default;
Matrix &operator=(Matrix &&other) noexcept = default;
// Parameterized constructors
public:
Matrix(std::size_t cols, std::size_t rows) : mCols(cols), mRows(rows) {
mData.resize(cols * rows);
}
Matrix(std::size_t cols, std::size_t rows, const DataType &data) {
if (rows * cols != data.size()) {
throw std::invalid_argument("Invalid matrix mData");
}
mCols = cols;
mRows = rows;
mData = data;
}
Matrix(std::size_t cols, std::size_t rows, DataType &&data) {
if (rows * cols != data.size()) {
throw std::invalid_argument("Invalid matrix mData");
}
mCols = cols;
mRows = rows;
mData = std::move(data);
}
// Setters
public:
void set(std::size_t rows, std::size_t cols, const DataType &data) {
if (rows * cols != data.size()) {
throw std::invalid_argument("Invalid vector data");
}
mRows = rows;
mCols = cols;
mData.resize(rows * cols);
std::copy(data.begin(), data.end(), mData.begin());
}
void set(std::size_t rows, std::size_t cols, DataType &&data) {
if (rows * cols != data.size()) {
throw std::invalid_argument("Invalid vector data");
}
mRows = rows;
mCols = cols;
mData.resize(rows * cols);
std::move(data.begin(), data.end(), mData.begin());
}
void set(std::size_t rows, std::size_t cols) {
mData.resize(rows * cols);
}
void set(const DataType &data) {
if (mRows * mCols != data.size()) {
throw std::invalid_argument("Invalid vector data");
}
std::copy(data.begin(), data.end(), mData.begin());
}
void set(DataType &&data) {
if (mRows * mCols != data.size()) {
throw std::invalid_argument("Invalid vector data");
}
std::move(data.begin(), data.end(), mData.begin());
}
// Getters
public:
auto cols() const {
return mCols;
}
auto rows() const {
return mRows;
}
const DataType &data() const {
return mData;
}
// Operator overloads
public:
Matrix operator+(const Matrix &rhs) const {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may add matrices with same dimensions only");
}
Matrix result(mCols, mRows);
std::transform(mData.begin(), mData.end(), rhs.mData.begin(), result.mData.begin(), std::plus<>());
return result;
}
Matrix operator+(Matrix &&rhs) const {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may add matrices with same dimensions only");
}
Matrix result(mCols, mRows);
std::transform(mData.begin(), mData.end(), std::make_move_iterator(rhs.mData.begin()), result.mData.begin(), std::plus<>());
return result;
}
Matrix operator-(const Matrix &rhs) const {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may sub matrices with same dimensions only");
}
Matrix result(mCols, mRows);
std::transform(mData.begin(), mData.end(), rhs.mData.begin(), result.mData.begin(), std::minus<>());
return result;
}
Matrix operator-(Matrix &&rhs) const {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may sub matrices with same dimensions only");
}
Matrix result(mCols, mRows);
std::transform(mData.begin(), mData.end(), std::make_move_iterator(rhs.mData.begin()), result.begin(), std::minus<>());
return result;
}
Matrix &operator+=(const Matrix &rhs) {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may add matrices with same dimensions only");
}
std::transform(mData.begin(), mData.end(), rhs.mData.begin(), mData.begin(), std::plus<>());
return *this;
}
Matrix &operator+=(Matrix &&rhs) {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may add matrices with same dimensions only");
}
std::transform(mData.begin(), mData.end(), std::make_move_iterator(rhs.mData.begin()), mData.begin(), std::plus<>());
return *this;
}
Matrix &operator-=(const Matrix &rhs) {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may add matrices with same dimensions only");
}
std::transform(mData.begin(), mData.end(), rhs.mData.begin(), mData.begin(), std::minus<>());
return *this;
}
Matrix &operator-=(Matrix &&rhs) {
if (mCols != rhs.mCols || mRows != rhs.mRows) {
throw std::invalid_argument("Invalid vector mData, you may add matrices with same dimensions only");
}
std::transform(mData.begin(), mData.end(), std::make_move_iterator(rhs.mData.begin()), mData.begin(), std::minus<>());
return *this;
}
bool operator==(const Matrix &rhs) const {
return mCols == rhs.mCols && mRows == rhs.mRows && std::equal(mData.begin(), mData.end(), rhs.mData.begin(), rhs.mData.end());
}
bool operator!=(const Matrix &rhs) const {
return !(*this == rhs);
}
friend std::ostream &operator<<(std::ostream &os, const Matrix &matrix) {
for (int i = 0; i < matrix.mCols; ++i) {
for (int j = 0; j < matrix.mRows; ++j) {
os << matrix.mData[i * matrix.mCols + j] << " ";
}
os << "\n";
}
return os;
}
protected:
std::size_t mCols;
std::size_t mRows;
std::vector<T> mData;
};
I'll add determinant and matrix multiplication operators too, but that will come after reviewing this part.
3 Answers 3
for
Matrix() : mCols(0), mRows(0), mData(0) { };
semicolon should be removed as useless.
for
Matrix(std::size_t cols, std::size_t rows) : mCols(cols), mRows(rows) { mData.resize(cols * rows); }
You can use member list initialization too for
vector
:Matrix(std::size_t cols, std::size_t rows) : mCols(cols), mRows(rows), mData(cols * rows) {}
for
void set(std::size_t rows, std::size_t cols) { mData.resize(rows * cols); }
it is wrong, as changing row/columns change internal layout, so for example element of 2nd row, 2nd column won't be there anymore.
resize
seems to be a better name if you keep that function.for arithmetic operators
Matrix operator+(const Matrix &rhs) const
, you forget when lhs is a rvalue, so either add overload with this qualifierMatrix operator+(const Matrix &rhs) &&
or make all overload as (
friend
) free functions.You might implement
operator +
withoperator +=
.std::vector
already has anoperator ==
sobool operator==(const Matrix &rhs) const { return mCols == rhs.mCols && mRows == rhs.mRows && std::equal(mData.begin(), mData.end(), rhs.mData.begin(), rhs.mData.end()); }
can be simplified to
bool operator==(const Matrix &rhs) const { return mCols == rhs.mCols && mRows == rhs.mRows && mData == rhs.mData; }
-
\$\begingroup\$ All your comments are right, will work on them) Thanks) \$\endgroup\$Hrant Nurijanyan– Hrant Nurijanyan2022年02月17日 13:50:51 +00:00Commented Feb 17, 2022 at 13:50
-
\$\begingroup\$ If I make all overload as
friend
free functions should not I makefriend Matrix operator+(const Matrix& lhs, const Matrix& rhs);
friend Matrix operator+(Matrix&& lhs, const Matrix& rhs);
friend Matrix operator+(Matrix&& lhs, Matrix&& rhs);
friend Matrix operator+(const Matrix& lhs, Matrix&& rhs);
Overloads? Is there a good practice of doing this? I've read in Scott Meyer's book somewhere that perfect forwarding may help, is this the case? \$\endgroup\$Hrant Nurijanyan– Hrant Nurijanyan2022年02月17日 14:01:58 +00:00Commented Feb 17, 2022 at 14:01 -
\$\begingroup\$ If you want to handle each combinations, you need 4 overloads (as methods or as free functions). Forwarding reference won't help here as you need
Matrix
notT&&
. \$\endgroup\$Jarod42– Jarod422022年02月17日 15:53:36 +00:00Commented Feb 17, 2022 at 15:53 -
\$\begingroup\$ Since C++20, we can define
bool operator==(const Matrix &rhs) const = default;
(and we getoperator!=()
for free). It's unlucky the code is targeting C++17. \$\endgroup\$Toby Speight– Toby Speight2022年02月17日 17:03:34 +00:00Commented Feb 17, 2022 at 17:03
Thanks for using vector!
It is conceivable to implement a matrix class with std::unique_ptr<T[]>
instead of vector
, thereby shaving off 16 bytes (on common platform), and too many make the mistake.
The host of utilities provided by vector (such as resize
, and copies) more than make up for the excess 16 bytes, so only in extreme circumstances should it be considered to roll your own.
Data members should be private.
First of all, this class has no virtual functions, therefore it is not meant to be inherited from, and thus protected
does not make sense.
Even if it were a base class, however, protected
should still be avoided. You cannot maintain invariants concerning protected
data-members: they break encapsulation.
R-value setters should move their argument.
Firstly, mData = std::move(data);
works even if T
itself is not moveable.
Secondly, it is vastly cheaper to execute mData = std::move(data);
(just copying 3 pointers) than it is to copy each and every element individually.
As such, setters taking in DataType&& data
should just move data
into mData
directly, rather than performing element-wise moves.
Further considerations
- Have you considered stronger typing? At the moment it's very easy to accidentally swap rows and columns, using distinct types for each would make this a compile time error.
- Have you considered indexing? It's painful to have to do
matrix.data().at(row * matrix.cols() + col)
, a providedat
(oroperator()
) would make this so much sweeter. - By exposing the data-type so much, you expose whether data is stored in row-major or column-major format. This also prevents nifty optimizations like zero-cost transpose. Is this really necessary?
- Instead of taking
const DataType&
consider taking instd::span<const T>
where suitable; it avoids building a vector when one doesn't have one handy.
-
\$\begingroup\$ Thanks for a good answer, I'll definitely take these into consideration \$\endgroup\$Hrant Nurijanyan– Hrant Nurijanyan2022年02月19日 13:29:42 +00:00Commented Feb 19, 2022 at 13:29
-
\$\begingroup\$ @Matthieu M. Your comment was so good advice. Have you ever implemented it? If yes, could I refer to it? Thank you very much. \$\endgroup\$AnhPC03– AnhPC032024年01月11日 04:48:06 +00:00Commented Jan 11, 2024 at 4:48
-
1\$\begingroup\$ @AnhPC03: I don't rightly remember, to be honest. In any case it would be proprietary code anyway so you wouldn't be allowed to read it. Just implement it yourself (it's simple enough) and ask for review on this site :) \$\endgroup\$Matthieu M.– Matthieu M.2024年01月11日 08:25:03 +00:00Commented Jan 11, 2024 at 8:25
Since all binary operators do not modify your Matrix, they should be implemented as non-member functions.
If (and only if) they need access to private members, implement them a friend. Here, this is clearly the case.
Note that you need to declare them as a friend inside the matrix class, but additionaly declare them for the compiler.
class Matrix{
Matrix& operator*=(Matrix const& rhs); // modidies matrix, so it's a member function
// friend declaration, i.e. "if a function like this is called, it has private access"
friend Matrix operator+(Matrix const& lhs, Matrix const& rhs); // not a member function
// scalar types
template<typename U>
friend Matrix operator*(Matrix const& lhs, U const& rhs);
template<typename U>
friend Matrix operator*(U const& lhs, Matrix const& rhs);
// matrix multiplication
friend Matrix operator*(Matrix const& lhs, Matrix const& rhs);
};
// declaration of the actual function
template<typename U>
Matrix operator*(Matrix const& lhs, U const& rhs){
auto copy(lhs);
copy *= rhs; // use this operator here
return copy;
}
template<typename U>
Matrix operator*(U const& lhs, Matrix const& rhs){
return rhs * lhs; // call the previous implementation above
}
This has the advantage, that all these operators are found in the same place in your code, and not half of them are member functions (when it's possible) and half of them are not (because double * Matrix can not be defined as a member function).
If you want your Matrix and the type to be templated, you should be aware about decltype
Explore related questions
See similar questions with these tags.
stl
algorithm
functionalities, but they can moved all to private), or they can be used only in terms ofmData.begin() , mData.end()
and not to be inlined in matrix class \$\endgroup\$