I have a task
Given natural number \$n\$. Print all triplets of numbers \$(x, y, z)\$ where \$x^2+y^2+z^2 = n\$.
And so far I came to this:
#include <array>
#include <algorithm>
#include <cmath>
#include <unordered_set>
struct ArrayHash {
// Custom hash function for std::array<unsigned int, 3>
std::size_t operator()(const std::array<unsigned int, 3>& arr) const {
std::size_t hashValue = 0;
for (const auto& elem : arr) {
hashValue ^= std::hash<unsigned int>()(elem) + 0x9e3779b9 + (hashValue << 6) + (hashValue >> 2);
}
return hashValue;
}
};
std::unordered_set<std::array<unsigned int, 3>, ArrayHash> findUniqueTriplets(const unsigned int &n) {
const int N_ROOTED = std::sqrt(n);
std::unordered_set<std::array<unsigned int, 3>, ArrayHash> unique_triplets;
for (unsigned int one = 1; one <= N_ROOTED; ++one) {
const unsigned int one_squared = one * one;
for (unsigned int two = 0; two <= N_ROOTED; ++two) {
const unsigned int sum_squares = one_squared + two * two;
if (sum_squares > n) {
break;
}
const unsigned int remaining = n - sum_squares;
const unsigned int three = std::sqrt(remaining);
if (three * three == remaining) {
std::array<unsigned int, 3> triplet = {one, two, three};
std::sort(triplet.begin(), triplet.end());
(void)unique_triplets.insert(triplet);
}
}
}
return unique_triplets;
}
Do you have recommendations about the code, for example:
- X looks weird
- better to use another type of variable for X
- there is a method with less difficulty (\$O(log(n))\$ instead of \$O(n)\$) for this part of the code
- etc.
The code needs to be fast & readable
3 Answers
This is not going to be a deep review, but I wanted to focus on a particular aspect of the design.
Whenever I see people saying they want highly performant code, but then I see things like std::unordered_set being used... it gets my hackles up. std::unordered_set is extremely expensive, compared to actually performant alternatives. Oh yes, it’s way more performant than most of the other options (like std::set or most other node-based containers) in most cases. And yes, assuming certain use patterns and large enough amounts of data, std::unordered_set could be the most performant choice. BUT! If you really want performance, you should be avoiding those use patterns like the plague.
Or to put it another way: if you really want maximum performance, you should probably look for ways to avoid needing std::unordered_set.
And in this case, it’s actually pretty easy to avoid it. All you’re using it to do is guarantee that each triple in the solution set is unique. But there is a much easier way to do that.
All you have to do is only check tuples in lexicographical order. You can go up or down, doesn’t matter; I’ll illustrate by going down.
1. First start with:
   - \$x^2={x_0}^2=n\$ and \$x=x_0=\lfloor\sqrt{n}\rfloor\$ (that is, the integer square root, not just the square root); and
   - \$y^2={y_0}^2=n-{x_0}^2\$ and \$y=y_0=\lfloor\sqrt{n-{x_0}^2}\rfloor\$.
   which means \$z^2={z_0}^2=n-({x_0}^2+{y_0}^2)\$ and \$z=z_0=\lfloor\sqrt{n-({x_0}^2+{y_0}^2)}\rfloor\$. If that’s a solution, record it. If not, move on.
2. Decrement \$y\$, checking \$z\$ each time to see if you’ve got a triple. Stop when \$z\$ is greater than \$y\$.
3. Decrement \$x\$ and calculate \$y\$. If \$x\$ is less than \$y\$ you’re done. Otherwise, go to 2.
I’m glossing over some tricky details, but the idea is you get results like:
If n = 129:
(11,2,2)
(10,5,2)
(8,8,1)
(8,7,4)
That is, the results are generated in (reverse) order.
Because the results are generated in a predictable order, it is trivial to avoid duplicated work, and results that are the "same" but for the order of the elements. That means no need for std::unordered_set... no need for std::sort()... no wasted work.
Not only all of that, but because you don’t need to generate all the results before you know what the final results are—because you don’t need to keep track of which results you’ve seen to avoid duplicates—you can produce results on the fly. I presume that the ultimate goal is to print the results; well, there is no need to rack up all the results in a container and then print; you can print-as-you-go.
You can either implement this as an algorithm, or as a range generator. The latter would be much more powerful, but more complicated. So I’ll illustrate with an algorithm:
template<std::ranges::output_range<std::array<int, 3> const&> R>
constexpr auto find_sum_of_squares_triplets_for(R&& out, int n) -> std::ranges::borrowed_iterator_t<R>
{
    // implement your algorithm...
    if (is_triplet(x, y, z))
        // write std::array{x, y, z} to out
}
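For what it's worth, here is one way the skeleton above could be filled in. It is only a sketch, not the one true implementation: the names isqrt, Triple, sum_of_squares_triplets and collect_sum_of_squares_triplets are made up for illustration, it takes a plain output iterator instead of an output range so that it also compiles before C++20, and it assumes zero is allowed as an element. It also caps y at x and stops once 3x² < n (the lower bound mentioned in the comments), which are two of the details the steps above gloss over.

```cpp
#include <algorithm>
#include <array>
#include <cmath>
#include <iterator>
#include <vector>

using Triple = std::array<unsigned int, 3>;  // hypothetical alias

// Integer square root: the largest r with r * r <= v.
// std::sqrt works in floating point, so the result is corrected afterwards.
inline long long isqrt(long long v) {
    long long r = static_cast<long long>(std::sqrt(static_cast<double>(v)));
    while (r * r > v) --r;
    while ((r + 1) * (r + 1) <= v) ++r;
    return r;
}

// Writes every (x, y, z) with x >= y >= z >= 0 and x*x + y*y + z*z == n
// to `out`, in descending lexicographical order. No set, no sort.
template <typename Out>
Out sum_of_squares_triplets(unsigned int n, Out out) {
    const long long total = n;
    // x is the largest element, so 3 * x * x >= n must hold.
    for (long long x = isqrt(total); x >= 0 && 3 * x * x >= total; --x) {
        const long long rest_x = total - x * x;  // y*y + z*z must equal this
        // y may not exceed x, and y*y may not exceed what is left.
        long long y = std::min(x, isqrt(rest_x));
        // y >= z means y*y must be at least half of rest_x.
        for (; y >= 0 && 2 * y * y >= rest_x; --y) {
            const long long rest_y = rest_x - y * y;  // z*z must equal this
            const long long z = isqrt(rest_y);
            if (z * z == rest_y) {
                *out++ = Triple{{static_cast<unsigned int>(x),
                                 static_cast<unsigned int>(y),
                                 static_cast<unsigned int>(z)}};
            }
        }
    }
    return out;
}

// Convenience wrapper for callers who do want a container.
inline std::vector<Triple> collect_sum_of_squares_triplets(unsigned int n) {
    std::vector<Triple> result;
    sum_of_squares_triplets(n, std::back_inserter(result));
    return result;
}
```

For n = 129 this produces the four triples listed earlier, in the same order. To print as you go, pass something like a std::ostream_iterator with a suitable operator<< instead of a std::back_inserter.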
So, in summary:
- Search your triplets in lexicographical order, so you can avoid duplicate results without sorting or resorting to expensive constructs like std::unordered_set.
- Output your results directly; don’t store them in a container and return the container. (If someone wants the results in a container, they can use std::back_inserter or std::ranges::to() to put them in one.)
Also, I’m pretty sure there is a constant-time equation that can check whether there are any results at all. You might want to do that right off the bat to see if there is any point running the loop.
- Additionally, there is a lower bound for x of ceil(sqrt(n/3.0)). – Davislor, Oct 26, 2024 at 11:28
- The mathematical result is that there are no solutions precisely when \$n\$ can be written as a power of \$4\$ times a number of the form \$8k-1\$. – Greg Martin, Oct 27, 2024 at 0:24
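The criterion in the last comment (Legendre's three-square theorem) can be turned into an up-front check along these lines. This is a sketch under a couple of assumptions: the function name is made up, zero is allowed as an element, and strictly speaking the loop is logarithmic in n rather than constant time, though for a 32-bit n it runs at most 16 times.

```cpp
// Returns true if n can be written as x*x + y*y + z*z with x, y, z >= 0.
// By Legendre's three-square theorem this fails exactly when
// n = 4^a * (8b + 7) for some non-negative integers a and b.
inline bool is_sum_of_three_squares(unsigned int n) {
    if (n == 0) {
        return true;  // 0 = 0 + 0 + 0
    }
    while (n % 4 == 0) {
        n /= 4;  // strip the power of 4
    }
    return n % 8 != 7;
}
```

If this returns false there is no point entering the search loop at all.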
indi has already addressed algorithmic issues, but here are a few style notes that jump out at me.
There are a number of repeated type names that could be factored out with using.
std::array<unsigned int, 3>
std::unordered_set<std::array<unsigned int, 3>, ArrayHash>
The actual names you choose are up to you, but something like the following would work.
struct ArrayHash;
using uint_arr3 = std::array<unsigned int, 3>;
using set_t = std::unordered_set<uint_arr3, ArrayHash>;
Furthermore, your code uses inconsistent indentation. Be nice to your brain. You spend more time reading code (including your own!) than writing it.
Addressing these issues:
#include <array>
#include <algorithm>
#include <cmath>
#include <unordered_set>

struct ArrayHash;
using uint_arr3 = std::array<unsigned int, 3>;
using set_t = std::unordered_set<uint_arr3, ArrayHash>;

struct ArrayHash {
    // Custom hash function for std::array<unsigned int, 3>
    std::size_t operator()(const uint_arr3& arr) const {
        std::size_t hashValue = 0;
        for (const auto& elem : arr) {
            hashValue ^= std::hash<unsigned int>()(elem) + 0x9e3779b9 + (hashValue << 6) + (hashValue >> 2);
        }
        return hashValue;
    }
};

set_t findUniqueTriplets(const unsigned int &n) {
    const int N_ROOTED = std::sqrt(n);
    set_t unique_triplets;
    for (unsigned int one = 1; one <= N_ROOTED; ++one) {
        const unsigned int one_squared = one * one;
        for (unsigned int two = 0; two <= N_ROOTED; ++two) {
            const unsigned int sum_squares = one_squared + two * two;
            if (sum_squares > n) {
                break;
            }
            const unsigned int remaining = n - sum_squares;
            const unsigned int three = std::sqrt(remaining);
            if (three * three == remaining) {
                uint_arr3 triplet = {one, two, three};
                std::sort(triplet.begin(), triplet.end());
                (void)unique_triplets.insert(triplet);
            }
        }
    }
    return unique_triplets;
}
Algorithm
For a somewhat more practical look at some of what indi was talking about, let's take your basic approach and change just a couple of things.
- Obviously I haven't broken this out into a function.
- I've started y in each inner loop at the value of x. For instance, if we've considered x = 4, y = 8; then there's no need to consider x = 8, y = 4. The remainder will be the same and we'll end up with a duplicate result.
- There are two guards inside that inner loop. The first is straight from your code. If there's no remainder, the loop is done. There's no point checking further. The second prevents duplicates by determining that the root would be less than the second element of the triple and continuing to the next loop iteration, or doing that if the remainder isn't a perfect square.
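#include <cmath>
#include <iostream>

int main() {
    constexpr unsigned int n = 1000;

    for (unsigned int x = 0; x < std::sqrt(n); ++x) {
        for (unsigned int y = x; y < std::sqrt(n); ++y) {
            // Check before subtracting: unsigned arithmetic wraps around
            // instead of going negative, so the guard has to come first.
            if (x * x + y * y >= n) break;
            unsigned int remainder = n - x * x - y * y;
            unsigned int root = std::sqrt(remainder);
            // Skip non-squares and anything that would repeat an earlier triple.
            if (root < y || root * root != remainder) continue;
            std::cout << x << ", " << y << ", " << root << '\n';
        }
    }
}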
You don't need to print to standard out as I've done, but the inside of the loop could, for instance, emplace back that triple into a std::vector<std::array<unsigned int, 3>>.
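As a rough sketch of that last suggestion, the same loop can be wrapped in a function that collects the triples instead of printing them. The names Triple and collect_triplets are invented here, and the loop bounds are tightened slightly compared to the code above (x stops once 3x² exceeds n, and the inner loop stops as soon as the third element would drop below y). It uses push_back rather than emplace_back because std::array is an aggregate, which emplace_back can only construct from separate arguments since C++20.

```cpp
#include <array>
#include <cmath>
#include <vector>

using Triple = std::array<unsigned int, 3>;  // hypothetical alias

// Collects every (x, y, z) with x <= y <= z and x*x + y*y + z*z == n,
// in ascending order of x, then y.
std::vector<Triple> collect_triplets(unsigned int n) {
    std::vector<Triple> result;
    // 64-bit arithmetic so the squares cannot overflow for large n.
    for (unsigned long long x = 0; 3 * x * x <= n; ++x) {
        for (unsigned long long y = x; ; ++y) {
            const unsigned long long partial = x * x + y * y;
            if (partial > n) break;
            const unsigned long long remainder = n - partial;
            if (remainder < y * y) break;  // z would be smaller than y
            const unsigned long long root = static_cast<unsigned long long>(
                std::sqrt(static_cast<double>(remainder)));
            if (root * root != remainder) continue;  // not a perfect square
            result.push_back(Triple{{static_cast<unsigned int>(x),
                                     static_cast<unsigned int>(y),
                                     static_cast<unsigned int>(root)}});
        }
    }
    return result;
}
```

Because every triple is produced exactly once and already in sorted form, the vector needs no deduplication afterwards.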
Adding to indi's and Chris's answers:
Create a new type for triplets
The algorithm looks very cluttered because you are using a raw std::array<unsigned int, 3> to store the triplets. This requires you to create a hash function which you have to pass explicitly to std::unordered_set, and you have to explicitly sort them. Consider instead creating a struct Triplet which takes care of all that:
struct Triplet {
    std::array<unsigned int, 3> values;

    Triplet(unsigned int a, unsigned int b, unsigned int c): values{{a, b, c}} {
        std::ranges::sort(values);
    }
};

template<>
struct std::hash<Triplet> {
    std::size_t operator()(const Triplet& triplet) const noexcept {
        return /* hash of a, b and c */;
    }
};
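To make the placeholder concrete, one possible way to complete this is sketched below. Two things here are assumptions on top of the outline above: the hash simply reuses the combining formula from the question's ArrayHash, and an operator== is added because std::unordered_set needs equality as well as a hash. std::sort stands in for std::ranges::sort so the sketch also compiles before C++20.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <unordered_set>

struct Triplet {
    std::array<unsigned int, 3> values;

    // Sorting on construction makes (1, 2, 3) and (3, 2, 1) the same value.
    Triplet(unsigned int a, unsigned int b, unsigned int c) : values{{a, b, c}} {
        std::sort(values.begin(), values.end());
    }

    // Not part of the outline above: required by std::unordered_set.
    friend bool operator==(const Triplet& lhs, const Triplet& rhs) {
        return lhs.values == rhs.values;
    }
};

namespace std {
template <>
struct hash<Triplet> {
    std::size_t operator()(const Triplet& triplet) const noexcept {
        // Same combining step as ArrayHash in the question.
        std::size_t seed = 0;
        for (unsigned int v : triplet.values) {
            seed ^= std::hash<unsigned int>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        }
        return seed;
    }
};
}  // namespace std
```

With this in place, std::unordered_set<Triplet> works without naming a hash type, and emplacing (1, 2, 3) followed by (3, 2, 1) leaves a single element in the set.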
Now you can write your code like so:
std::unordered_set<Triplet> findUniqueTriplets(unsigned int n) {
    ...
    unique_triplets.emplace(one, two, three);
    ...
}
There is a lot more you can do when you create your own type. For example, Chris mentioned checking triplets in lexicographic order. You could add an operator++() to Triplet that will update the triplet to the next one in lexicographic order. You could then just use a simple loop to search through them.
Another option would be to create an iterator for triplets that only returns triplets whose sum of squares is equal to some value. Combined with C++23's ranges, you could then make the main algorithm look like:
std::unordered_set<Triplet> findUniqueTriplets(unsigned int n) {
    return triplets_with_sum_square(n) | std::ranges::to<std::unordered_set>();
}
- [...] x <= y <= z, as you will just find duplicates. Thus your second loop can be optimized as for (unsigned int two = one; two <= N_ROOTED; ++two)
- [...] one_squared = one * one and in my mind there's a circuit recording the fact that one_squared == one. Wrongly, of course. Also, we're well past the point where the performance benefit is noticeable, but there is generally a benefit to building a series of squares additively rather than using multiplication.
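The point about building squares additively rests on the identity (k + 1)² = k² + (2k + 1), so each square is the previous one plus the next odd number. A small sketch, with a made-up function name, and as the comment says, not something likely to make a measurable difference here:

```cpp
#include <cstddef>
#include <vector>

// Returns 0*0, 1*1, ..., limit*limit using only additions:
// each square is the previous square plus the next odd number.
std::vector<unsigned long long> squares_up_to(unsigned int limit) {
    std::vector<unsigned long long> squares;
    squares.reserve(static_cast<std::size_t>(limit) + 1);
    unsigned long long square = 0;  // k * k
    unsigned long long odd = 1;     // 2 * k + 1
    for (unsigned long long k = 0; k <= limit; ++k) {
        squares.push_back(square);
        square += odd;
        odd += 2;
    }
    return squares;
}
```

The same running pair of values can be kept directly in the search loops instead of recomputing one * one and two * two on every iteration.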