A couple of days ago I posted the first implementation of my data structure here, a cache-optimized tensor representation. I have worked the ideas from the responses into the implementation, and I would like to see how I could improve it further.
Code Explanation
I will provide a rough explanation of the operation of the code as there is quite a bit to it.
Tensor Class
The data is stored in a single std::vector in the Tensor class. An offsets member variable records how far to step through that vector to reach the next element along a given dimension.
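To make the layout concrete, here is a minimal sketch of how such an offsets table maps a multi-index to a position in the flat vector. The helpers make_offsets and flat_index are hypothetical names used only for illustration and are not part of the class below. The convention (offsets[0] == 1, offsets[DIM] == total element count) follows the resize function in the posted code.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical helper: builds the same table as Tensor::resize.
// offsets[0] == 1 and offsets[k] is the product of the last k extents.
template <std::size_t DIM>
std::array<std::size_t, DIM + 1> make_offsets(const std::array<std::size_t, DIM>& dims)
{
    std::array<std::size_t, DIM + 1> offsets{};
    offsets[0] = 1;
    for (std::size_t i = 1; i <= DIM; ++i) {
        offsets[i] = offsets[i - 1] * dims[DIM - i];
    }
    return offsets;
}

// Hypothetical helper: position of element (idx[0], idx[1], ...) in the flat buffer.
// The index along axis a is scaled by offsets[DIM - 1 - a].
template <std::size_t DIM>
std::size_t flat_index(const std::array<std::size_t, DIM + 1>& offsets,
                       const std::array<std::size_t, DIM>& idx)
{
    std::size_t pos = 0;
    for (std::size_t a = 0; a < DIM; ++a) {
        // Extent of axis a, recovered the same way Tensor::size does it.
        assert(idx[a] < offsets[DIM - a] / offsets[DIM - a - 1]);
        pos += idx[a] * offsets[DIM - 1 - a];
    }
    return pos;
}
```

For extents 2 x 3 x 4 the table is {1, 4, 12, 24}, so element (1, 2, 3) lands at 1*12 + 2*4 + 3*1 = 23, the last slot of the buffer.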
Tensor::operator[]
I have overloaded Tensor::operator[] with both const and non-const variants. The non-const variant uses a return type of decltype(auto), so that it returns a reference to the element (a double& for a tensor of double) when DIM == 1, and an object of type Tensor::View with one dimension fewer when DIM != 1. The const variant uses a return type of auto, so it returns the element by value (a double) when DIM == 1, and an object of type Tensor::ConstView with one dimension fewer otherwise.
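The difference between the two return types can be checked in isolation. The following is a minimal sketch with a hypothetical one-dimensional stand-in called Row, not the Tensor class itself. It only shows what decltype(auto) and auto deduce for the same return expression.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-in for the DIM == 1 case, used only to show return-type deduction.
struct Row
{
    std::vector<double> data;

    // decltype(auto) keeps the value category: data[i] is an lvalue of type
    // double, so the deduced return type is double&.
    decltype(auto) operator[](std::size_t i)
    {
        assert(i < data.size());
        return data[i];
    }

    // Plain auto drops the reference and the const, so this returns double by value.
    auto operator[](std::size_t i) const
    {
        assert(i < data.size());
        return data[i];
    }
};

static_assert(std::is_same_v<decltype(std::declval<Row&>()[0]), double&>);
static_assert(std::is_same_v<decltype(std::declval<const Row&>()[0]), double>);
```

Returning by value from the const overload is cheap for double, but it copies the element, which may matter if Ty is ever a larger type.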
Tensor::resize
The resize function in Tensor reallocates the memory buffer for the given extents of the tensor. It also fills in the offsets member variable, whose values determine how far to step through the vector for a given dimension.
Tensor::View class
The Tensor::View class has two member variables: data, which points into memory owned by the data member of some Tensor, and offsets, which points to the start of that Tensor's offsets member. The class implements a single operator[] overload, which returns either a double& or another Tensor::View with one dimension fewer.
Tensor::ConstView class
The Tensor::ConstView class is similar to Tensor::View, except that its two member variables are pointers to const. Its single operator[] overload returns either a double or a Tensor::ConstView with one dimension fewer.
My Concerns
My main concern is the amount of repetition in the code. This is not much of an issue at the moment because there are not many members, but each one I add will need to be implemented three times: once in the main Tensor class and once in each of the Tensor::View and Tensor::ConstView classes. I do not know a way around this, and I would really like to hear your suggestions.
I am also concerned about the potential for poor performance, as I need to pass two separate pointers into the constructor of each Tensor::View or Tensor::ConstView. I am not sure whether the performance penalty from this is substantial, so I am interested to see what others think. Thank you.
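For a rough sense of what constructing a view costs, the properties of a two-pointer object can be checked at compile time. This is a minimal sketch with a hypothetical stand-in type, TwoPointerView, that has the same two members as Tensor::View. It does not measure anything. Small trivially copyable objects like this are commonly passed in registers and their constructors inlined, but only profiling the real code can confirm that for a given compiler and target.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in with the same two members as Tensor::View.
struct TwoPointerView
{
    double* data;
    const std::size_t* offsets;
};

// Copying the view is a plain bitwise copy of its members.
static_assert(std::is_trivially_copyable_v<TwoPointerView>);

// No hidden overhead beyond the two pointers (expected on mainstream platforms).
static_assert(sizeof(TwoPointerView) == sizeof(double*) + sizeof(const std::size_t*),
              "view is larger than its two pointers");

// Passing the view by value copies just those two pointers.
inline double first_element(TwoPointerView v)
{
    assert(v.data != nullptr);
    return v.data[0];
}
```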
template<typename Ty, int DIM>
class Tensor
{
private:
    std::vector<Ty> data;
    std::vector<size_t> offsets;
public:
    Tensor() :
        offsets(DIM + 1)
    {}
    Tensor(const std::vector<size_t> &dims) :
        offsets(DIM + 1)
    {
        resize(&dims.data()[0]);
    }
    decltype(auto) operator[] (size_t i)
    {
        if constexpr (DIM == 1)
        {
            return data[i];
        }
        else
        {
            return typename Tensor<Ty, DIM-1>::View(&data.data()[i*offsets[DIM-1]], &offsets.data()[0]);
        }
    }
    auto operator[] (size_t i) const
    {
        if constexpr (DIM == 1)
        {
            return data[i];
        }
        else
        {
            return typename Tensor<Ty, DIM-1>::ConstView(&data.data()[i*offsets[DIM-1]], &offsets.data()[0]);
        }
    }
    void resize(const size_t *dims)
    {
        offsets[0] = 1;
        for (size_t i = 1; i <= DIM; ++i)
        {
            offsets[i] = offsets[i - 1] * dims[DIM - i];
        }
        data.resize(offsets[DIM]);
    }
    size_t size(int axis=0) const
    {
        return offsets[DIM-axis] / offsets[DIM-axis-1];
    }
    class View
    {
    private:
        Ty *data;
        size_t *offsets;
        View() = delete;
        View(Ty *data, size_t *offsets) :
            data(data), offsets(offsets)
        {}
    public:
        friend class Tensor<Ty, DIM+1>;
        decltype(auto) operator[] (size_t i)
        {
            if constexpr (DIM == 1)
            {
                return data[i];
            }
            else
            {
                return typename Tensor<Ty, DIM-1>::View(&data[i*offsets[DIM-1]], offsets);
            }
        }
        size_t size(int axis=0) const
        {
            return offsets[DIM-axis] / offsets[DIM-axis-1];
        }
    };
    class ConstView
    {
    private:
        const Ty *data;
        const size_t *offsets;
        ConstView() = delete;
        ConstView(const Ty *data, const size_t *offsets) :
            data(data), offsets(offsets)
        {}
    public:
        friend class Tensor<Ty, DIM+1>;
        auto operator[] (size_t i) const
        {
            if constexpr (DIM == 1)
            {
                return data[i];
            }
            else
            {
                return typename Tensor<Ty, DIM-1>::ConstView(&data[i*offsets[DIM-1]], offsets);
            }
        }
        size_t size(int axis=0) const
        {
            return offsets[DIM-axis] / offsets[DIM-axis-1];
        }
    };
};
1 Answer
We're missing some includes to make this work (at least <vector> for std::vector and <cstddef> for std::size_t) - never depend on the client code to have previously included the headers we need. And we're assuming using std::size_t; - avoid writing that in a header file, because there's no way to undo it at the end of our code. Just write the type name in full wherever we need it.
The other big omission is the unit tests. Without those, we have to infer how the class is expected to be used. And without writing the tests, you're likely to overlook some of the edge cases, such as what happens when DIM is zero or dims is empty.
I'm a bit concerned about delving into dims.data() just to index it - why not index dims directly? I don't see any good reason to write dims.data()[0] instead of dims[0]. And why does resize() want a pointer to an array, rather than accepting a range? What we have is dangerous, because we can easily pass a pointer to a smaller amount of memory than the function will read. The constructor does exactly that whenever dims holds fewer than DIM elements, since nothing checks dims.size().
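As one hedged illustration of a safer signature, resize() could accept a std::array whose length is part of its type, so a too-short argument fails to compile instead of reading out of bounds. CheckedTensor is a hypothetical, reduced class that keeps only the size bookkeeping. It uses std::size_t for DIM, where the posted code uses int, and adds a static_assert for the DIM == 0 edge case.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reduced tensor showing only the size bookkeeping.
template <typename Ty, std::size_t DIM>
class CheckedTensor
{
    static_assert(DIM > 0, "a tensor needs at least one dimension");

    std::vector<Ty> data;
    std::array<std::size_t, DIM + 1> offsets{};

public:
    // The array type carries its length, so passing too few extents cannot compile.
    explicit CheckedTensor(const std::array<std::size_t, DIM>& dims) { resize(dims); }

    void resize(const std::array<std::size_t, DIM>& dims)
    {
        offsets[0] = 1;
        for (std::size_t i = 1; i <= DIM; ++i) {
            offsets[i] = offsets[i - 1] * dims[DIM - i];
        }
        data.resize(offsets[DIM]);
    }

    std::size_t size(std::size_t axis = 0) const
    {
        assert(axis < DIM);
        // Assumes non-zero extents: a zero extent can turn this into a division by zero.
        return offsets[DIM - axis] / offsets[DIM - axis - 1];
    }

    // Total number of stored elements.
    std::size_t total() const { return data.size(); }
};
```

With C++20 available, std::span<const std::size_t, DIM> would be another option that accepts any contiguous range of exactly DIM extents.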
When you have access to C++23, you'll find that the const/non-const overload pairs can each be combined, using the new feature of deducing this. You're well placed to take advantage of that when the time is right.
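Until then, a similar effect is available in C++17 by forwarding both overloads to one static helper template that deduces the constness of the object. The sketch below is not deducing this itself, only the usual approximation of it. Buffer and at_impl are hypothetical names, and unlike the posted const overload this one returns const double& rather than a copy.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical 1-D container; only the accessor pattern matters here.
class Buffer
{
    std::vector<double> data;

    // One implementation for both overloads: Self is deduced as Buffer or
    // const Buffer, so self.data[i] yields double& or const double&.
    template <typename Self>
    static decltype(auto) at_impl(Self& self, std::size_t i)
    {
        assert(i < self.data.size());
        return self.data[i];
    }

public:
    explicit Buffer(std::size_t n) : data(n) {}

    // Thin wrappers; the logic lives in at_impl only.
    decltype(auto) operator[](std::size_t i) { return at_impl(*this, i); }
    decltype(auto) operator[](std::size_t i) const { return at_impl(*this, i); }
};

static_assert(std::is_same_v<decltype(std::declval<Buffer&>()[0]), double&>);
static_assert(std::is_same_v<decltype(std::declval<const Buffer&>()[0]), const double&>);
```

With a C++23 compiler, the helper and both wrappers should collapse into a single member template whose first parameter is an explicit object parameter, written as this Self&& self.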
I think it should be possible to combine the View and ConstView inner types with judicious use of type_traits magic to make the non-const [] operator disappear from the view when Ty is a const type.
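A simpler variant of that idea, sketched below, needs no overload removal at all. A single view template is parameterized on a possibly-const element type, and the constness of the pointee decides whether elements are writable, much as std::span does. This is a different technique from hiding an overload with type_traits. BasicView is a hypothetical name, its constructor is public only to keep the sketch short, and DIM is std::size_t rather than int.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical unified view: T is either Ty (writable) or const Ty (read-only).
template <typename T, std::size_t DIM>
class BasicView
{
    T* data;
    const std::size_t* offsets;

public:
    BasicView(T* data, const std::size_t* offsets) : data(data), offsets(offsets) {}

    // One overload serves both cases: the result is T&, so it is const double&
    // for a read-only view, or a sub-view with the same element type.
    decltype(auto) operator[](std::size_t i) const
    {
        if constexpr (DIM == 1) {
            return data[i];
        } else {
            return BasicView<T, DIM - 1>(data + i * offsets[DIM - 1], offsets);
        }
    }

    std::size_t size(std::size_t axis = 0) const
    {
        assert(axis < DIM);
        return offsets[DIM - axis] / offsets[DIM - axis - 1];
    }
};

static_assert(std::is_same_v<decltype(std::declval<BasicView<double, 1>>()[0]), double&>);
static_assert(std::is_same_v<decltype(std::declval<BasicView<const double, 1>>()[0]),
                             const double&>);
```

Inside Tensor, the two nested types could then become aliases along the lines of using View = BasicView<Ty, DIM> and using ConstView = BasicView<const Ty, DIM>, so each new member is written once.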
Finally, it looks like we could do with a deducible constructor (perhaps accepting a std::array of size N), then we could write a deduction guide for it, and not require the caller to keep track of N everywhere.
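A minimal sketch of what that could look like follows, under two assumptions that are not in the posted code. First, the constructor takes an extra fill value, because class template argument deduction cannot deduce Ty from the extents alone. Second, DIM is declared as std::size_t, because deduction from std::array needs the parameter type to match the array's size type. DeducedTensor is a hypothetical, reduced class.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical reduced tensor; DIM is std::size_t so it can be deduced from std::array.
template <typename Ty, std::size_t DIM>
class DeducedTensor
{
    std::vector<Ty> data;
    std::array<std::size_t, DIM> dims;

    // Product of all extents.
    static std::size_t element_count(const std::array<std::size_t, DIM>& dims)
    {
        std::size_t n = 1;
        for (std::size_t d : dims) {
            n *= d;
        }
        return n;
    }

public:
    // Assumed signature: the fill value is what lets Ty be deduced.
    DeducedTensor(const std::array<std::size_t, DIM>& dims, const Ty& fill)
        : data(element_count(dims), fill), dims(dims)
    {}

    std::size_t size(std::size_t axis = 0) const
    {
        assert(axis < DIM);
        return dims[axis];
    }

    std::size_t total() const { return data.size(); }
};

// Explicit guide, spelled out for clarity; the constructor above would also
// produce an equivalent implicit guide.
template <typename Ty, std::size_t N>
DeducedTensor(const std::array<std::size_t, N>&, const Ty&) -> DeducedTensor<Ty, N>;
```

The caller then states the rank once, in the type of the extents array, and can write DeducedTensor t(dims, 0.0) to get a DeducedTensor<double, 3> when dims is a std::array<std::size_t, 3>.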
Ty& at(std::initializer_list<size_t> indices) and const Ty& at(/*ditto*/) const.