I have implemented exponentiation with integer base and non-negative integer exponent for practicing purposes.
Of course there is some upper limit what numbers can be exponentiated within the int datatype, but I wasn't concerned about this at this point. But any suggestions how to handle it are welcome.
There are four variantions of the power function:
// full runtime version
int ipow(int, unsigned int);
// full compile time version
template<int, unsigned int> int ipow();
// only base is known at compile time
template<int> int ipow_base(unsigned int);
// only exponent is known at compile time
template<unsigned int> int ipow_exp(int);
I am interested in a general review, if I'm doing something wrong, could I have made something more explicit, etc...
My only restriction is that I am not (yet) intersted in post C++17 features.
#include <cstddef>
#include <stdexcept>
template<const unsigned int exponent>
constexpr int ipow_exp(int base)
{
if (exponent == 0) return 1;
return base * ipow_exp<exponent-1>(base);
}
template<>
constexpr int ipow_exp<1>(int base)
{
return base;
}
template<>
constexpr int ipow_exp<0>(int base)
{
if (base == 0) throw std::logic_error("0^0 is undefined.");
return 1;
}
template<const int base>
constexpr int ipow_base(unsigned int exponent)
{
if (exponent == 0) return 1;
return base * ipow_base<base>(exponent-1);
}
template<>
constexpr int ipow_base<2>(unsigned int exponent)
{
return 1 << exponent;
}
template<>
constexpr int ipow_base<1>(unsigned int exponent)
{
return 1;
}
template<>
constexpr int ipow_base<0>(unsigned int exponent)
{
if (exponent == 0) throw std::logic_error("0^0 is undefined.");
return 0;
}
template<const int base, const unsigned int exponent>
constexpr int ipow()
{
static_assert(exponent != 0 || base != 0, "0^0 is undefined.");
if (exponent == 1 || base == 0 || base == 1) {
return base;
}
return ipow_base<base>(exponent);
}
int ipow(int base, unsigned int exponent)
{
if (exponent == 0) {
if (base == 0) throw std::logic_error("0^0 is undefined.");
return 1;
}
int result = base;
while (--exponent > 0) {
result *= base;
}
return result;
}
#include <cassert>
int main()
{
int tmp;
bool thrown = false;
// full runtime version
assert(ipow(0,1) == 0);
assert(ipow(0,2) == 0);
assert(ipow(0,3) == 0);
assert(ipow(0,50) == 0);
assert(ipow(1,0) == 1);
assert(ipow(1,1) == 1);
assert(ipow(1,2) == 1);
assert(ipow(1,50) == 1);
assert(ipow(2,0) == 1);
assert(ipow(2,1) == 2);
assert(ipow(2,2) == 4);
assert(ipow(2,3) == 8);
assert(ipow(2,10) == 1024);
assert(ipow(3,0) == 1);
assert(ipow(3,1) == 3);
assert(ipow(3,2) == 9);
assert(ipow(3,3) == 27);
assert(ipow(3,4) == 81);
assert(ipow(5,0) == 1);
assert(ipow(5,1) == 5);
assert(ipow(5,2) == 25);
assert(ipow(5,3) == 125);
assert(ipow(5,4) == 625);
assert(ipow(-1,0) == 1);
assert(ipow(-1,1) == -1);
assert(ipow(-1,2) == 1);
assert(ipow(-1,3) == -1);
assert(ipow(-1,4) == 1);
assert(ipow(-1,31) == -1);
assert(ipow(-1,32) == 1);
assert(ipow(-2,0) == 1);
assert(ipow(-2,1) == -2);
assert(ipow(-2,2) == 4);
assert(ipow(-2,9) == -512);
assert(ipow(-2,10) == 1024);
assert(ipow(-5,0) == 1);
assert(ipow(-5,1) == -5);
assert(ipow(-5,2) == 25);
assert(ipow(-5,3) == -125);
assert(ipow(-5,4) == 625);
thrown = false;
try {
ipow(0,0);
} catch (std::logic_error e){
thrown = true;
}
assert(thrown);
// compile time exponent version
tmp = ipow_exp<1>(0); assert(tmp == 0);
tmp = ipow_exp<2>(0); assert(tmp == 0);
tmp = ipow_exp<3>(0); assert(tmp == 0);
tmp = ipow_exp<50>(0); assert(tmp == 0);
tmp = ipow_exp<0>(1); assert(tmp == 1);
tmp = ipow_exp<1>(1); assert(tmp == 1);
tmp = ipow_exp<2>(1); assert(tmp == 1);
tmp = ipow_exp<50>(1); assert(tmp == 1);
tmp = ipow_exp<0>(2); assert(tmp == 1);
tmp = ipow_exp<1>(2); assert(tmp == 2);
tmp = ipow_exp<2>(2); assert(tmp == 4);
tmp = ipow_exp<3>(2); assert(tmp == 8);
tmp = ipow_exp<10>(2); assert(tmp == 1024);
tmp = ipow_exp<0>(3); assert(tmp == 1);
tmp = ipow_exp<1>(3); assert(tmp == 3);
tmp = ipow_exp<2>(3); assert(tmp == 9);
tmp = ipow_exp<3>(3); assert(tmp == 27);
tmp = ipow_exp<4>(3); assert(tmp == 81);
tmp = ipow_exp<0>(5); assert(tmp == 1);
tmp = ipow_exp<1>(5); assert(tmp == 5);
tmp = ipow_exp<2>(5); assert(tmp == 25);
tmp = ipow_exp<3>(5); assert(tmp == 125);
tmp = ipow_exp<4>(5); assert(tmp == 625);
tmp = ipow_exp<0>(-1); assert(tmp == 1);
tmp = ipow_exp<1>(-1); assert(tmp == -1);
tmp = ipow_exp<2>(-1); assert(tmp == 1);
tmp = ipow_exp<3>(-1); assert(tmp == -1);
tmp = ipow_exp<4>(-1); assert(tmp == 1);
tmp = ipow_exp<31>(-1); assert(tmp == -1);
tmp = ipow_exp<32>(-1); assert(tmp == 1);
tmp = ipow_exp<0>(-2); assert(tmp == 1);
tmp = ipow_exp<1>(-2); assert(tmp == -2);
tmp = ipow_exp<2>(-2); assert(tmp == 4);
tmp = ipow_exp<9>(-2); assert(tmp == -512);
tmp = ipow_exp<10>(-2); assert(tmp == 1024);
tmp = ipow_exp<0>(-5); assert(tmp == 1);
tmp = ipow_exp<1>(-5); assert(tmp == -5);
tmp = ipow_exp<2>(-5); assert(tmp == 25);
tmp = ipow_exp<3>(-5); assert(tmp == -125);
tmp = ipow_exp<4>(-5); assert(tmp == 625);
thrown = false;
try {
ipow_exp<0>(0);
} catch (std::logic_error e){
thrown = true;
}
assert(thrown);
// compile time base version
tmp = ipow_base<0>(1); assert(tmp == 0);
tmp = ipow_base<0>(2); assert(tmp == 0);
tmp = ipow_base<0>(3); assert(tmp == 0);
tmp = ipow_base<0>(50); assert(tmp == 0);
tmp = ipow_base<1>(0); assert(tmp == 1);
tmp = ipow_base<1>(1); assert(tmp == 1);
tmp = ipow_base<1>(2); assert(tmp == 1);
tmp = ipow_base<1>(50); assert(tmp == 1);
tmp = ipow_base<2>(0); assert(tmp == 1);
tmp = ipow_base<2>(1); assert(tmp == 2);
tmp = ipow_base<2>(2); assert(tmp == 4);
tmp = ipow_base<2>(3); assert(tmp == 8);
tmp = ipow_base<2>(10); assert(tmp == 1024);
tmp = ipow_base<3>(0); assert(tmp == 1);
tmp = ipow_base<3>(1); assert(tmp == 3);
tmp = ipow_base<3>(2); assert(tmp == 9);
tmp = ipow_base<3>(3); assert(tmp == 27);
tmp = ipow_base<3>(4); assert(tmp == 81);
tmp = ipow_base<5>(0); assert(tmp == 1);
tmp = ipow_base<5>(1); assert(tmp == 5);
tmp = ipow_base<5>(2); assert(tmp == 25);
tmp = ipow_base<5>(3); assert(tmp == 125);
tmp = ipow_base<5>(4); assert(tmp == 625);
tmp = ipow_base<-1>(0); assert(tmp == 1);
tmp = ipow_base<-1>(1); assert(tmp == -1);
tmp = ipow_base<-1>(2); assert(tmp == 1);
tmp = ipow_base<-1>(3); assert(tmp == -1);
tmp = ipow_base<-1>(4); assert(tmp == 1);
tmp = ipow_base<-1>(31); assert(tmp == -1);
tmp = ipow_base<-1>(32); assert(tmp == 1);
tmp = ipow_base<-2>(0); assert(tmp == 1);
tmp = ipow_base<-2>(1); assert(tmp == -2);
tmp = ipow_base<-2>(2); assert(tmp == 4);
tmp = ipow_base<-2>(9); assert(tmp == -512);
tmp = ipow_base<-2>(10); assert(tmp == 1024);
tmp = ipow_base<-5>(0); assert(tmp == 1);
tmp = ipow_base<-5>(1); assert(tmp == -5);
tmp = ipow_base<-5>(2); assert(tmp == 25);
tmp = ipow_base<-5>(3); assert(tmp == -125);
tmp = ipow_base<-5>(4); assert(tmp == 625);
thrown = false;
try {
ipow_base<0>(0);
} catch (std::logic_error e){
thrown = true;
}
assert(thrown);
// full compile time version
tmp = ipow<0,1>(); assert(tmp == 0);
tmp = ipow<0,2>(); assert(tmp == 0);
tmp = ipow<0,3>(); assert(tmp == 0);
tmp = ipow<0,50>(); assert(tmp == 0);
tmp = ipow<1,0>(); assert(tmp == 1);
tmp = ipow<1,1>(); assert(tmp == 1);
tmp = ipow<1,2>(); assert(tmp == 1);
tmp = ipow<1,50>(); assert(tmp == 1);
tmp = ipow<2,0>(); assert(tmp == 1);
tmp = ipow<2,1>(); assert(tmp == 2);
tmp = ipow<2,2>(); assert(tmp == 4);
tmp = ipow<2,3>(); assert(tmp == 8);
tmp = ipow<2,10>(); assert(tmp == 1024);
tmp = ipow<3,0>(); assert(tmp == 1);
tmp = ipow<3,1>(); assert(tmp == 3);
tmp = ipow<3,2>(); assert(tmp == 9);
tmp = ipow<3,3>(); assert(tmp == 27);
tmp = ipow<3,4>(); assert(tmp == 81);
tmp = ipow<5,0>(); assert(tmp == 1);
tmp = ipow<5,1>(); assert(tmp == 5);
tmp = ipow<5,2>(); assert(tmp == 25);
tmp = ipow<5,3>(); assert(tmp == 125);
tmp = ipow<5,4>(); assert(tmp == 625);
tmp = ipow<-1,0>(); assert(tmp == 1);
tmp = ipow<-1,1>(); assert(tmp == -1);
tmp = ipow<-1,2>(); assert(tmp == 1);
tmp = ipow<-1,3>(); assert(tmp == -1);
tmp = ipow<-1,4>(); assert(tmp == 1);
tmp = ipow<-1,31>(); assert(tmp == -1);
tmp = ipow<-1,32>(); assert(tmp == 1);
tmp = ipow<-2,0>(); assert(tmp == 1);
tmp = ipow<-2,1>(); assert(tmp == -2);
tmp = ipow<-2,2>(); assert(tmp == 4);
tmp = ipow<-2,9>(); assert(tmp == -512);
tmp = ipow<-2,10>(); assert(tmp == 1024);
tmp = ipow<-5,0>(); assert(tmp == 1);
tmp = ipow<-5,1>(); assert(tmp == -5);
tmp = ipow<-5,2>(); assert(tmp == 25);
tmp = ipow<-5,3>(); assert(tmp == -125);
tmp = ipow<-5,4>(); assert(tmp == 625);
#ifdef TEST_COMPILE_ERRORS
ipow<0,0>();
#endif
return 0;
}
$ g++ --version
g++ (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
2 Answers 2
Avoid large recursion
The unspecialized ipow_base()
may recurse exponent
times before multiplying. Just defer to the general case here:
template<const int base>
constexpr int ipow_base(unsigned int exponent)
{
return ipow(base, exponent);
}
Use binary exponentiation for efficiency with larger exponents
These functions (other than the specializations) scale linearly with the exponent value, but could scale logarithmically like this:
template<const unsigned int exponent>
constexpr int ipow_exp(int base)
{
return (exponent & 1 ? base : 1) * ipow_exp<exponent/2>(base*base);
}
constexpr int ipow(int base, unsigned int exponent)
{
if (!exponent && !base) {
throw std::logic_error("0^0 is undefined.");
}
if (base == 2) {
return 1 << exponent;
}
int result = 1;
int term = base;
while (exponent) {
if (exponent & 1) {
result *= term;
}
term *= term;
exponent /= 2;
}
return result;
}
Extend to other integer types
Users would probably like to be able to use any std::is_integral
type for base
(e.g. unsigned long
), so that ought to be a template type.
Simplify tests for throwing
We don't need the thrown
variable here:
thrown = false; try { ipow(0,0); } catch (std::logic_error e){ thrown = true; } assert(thrown);
Just assert in the try
block:
try {
ipow(0,0);
assert(false);
} catch (std::logic_error& e) {
// expected
}
Better still, use one of the many available test frameworks rather than simple assert()
. That would help in several ways, such as detecting multiple failures per run, and showing actual and expected values for comparison.
You might be able to use std::pow()
in constexpr
expressions in C++11
Since this post was not tagged "reinventing-the-wheel", I want to point out that some compilers (notably GCC) will compile the below C++11 code:
#include <cmath>
constexpr int ipow(int a, int b) {
return std::pow(a, b);
}
int main(int argc, char *argv[]) {
static_assert(ipow(-5, 3) == -125);
return ipow(argc, 2);
}
One drawback is that std::pow()
converts integer arguments to double
, which at run-time may or may not result in slower computation than using int
. Also, while for int
there is no loss of precision, if you would want to use int64_t
, there is a potential loss of precision.
The other drawback, as pointed out by Oliver Schonrock, is that not all compilers allow constexpr
use of std::pow()
. As explained in this post,
constexpr
math functions were only allowed in C++11 but not in C++14. But there are libraries that provide constexpr
math functions, see for example Sprout's pow()
implementation.
Zero to the power zero is one*
With most programming languages, one usually finds that pow(0, 0) == 1
. You should ensure your solution also returns one in that case, to ensure consistency, regardless of your personal feelings about zero to the power zero.
As a bonus, by having a well-defined result for ipow(0, 0)
, it no longer throws exceptions, and you can get rid of some of the specializations.
Catch exceptions by const reference
Make it a habit to catch exceptions by const reference. Apart from being a little bit faster (although this of course is the least of your worries when exceptions are being thrown), it ensures you don't lose information when the exception thrown is of a derived class. See this StackOverflow question for more information.
-
\$\begingroup\$ "following code is valid C++11". Are you sure? It compiles for me under gcc-9.2. but clang-9 throws "static_assert expression is not an integral constant expression". This doesn't mention
constexpr
: en.cppreference.com/w/cpp/numeric/math/pow . This caught my eye because I recently personally implemented a limited compile timepow
exactly because it only works in gcc as a non-standard extension? \$\endgroup\$Oliver Schönrock– Oliver Schönrock2020年01月06日 03:33:58 +00:00Commented Jan 6, 2020 at 3:33 -
1\$\begingroup\$ Summary of all 3 main compilers. Latest stable of each. godbolt.org/z/07b7Ww Only gcc seems to support it. Also tried trunk clang, which doesn't work either. \$\endgroup\$Oliver Schönrock– Oliver Schönrock2020年01月06日 07:42:09 +00:00Commented Jan 6, 2020 at 7:42
-
\$\begingroup\$ Woops, it's indeed not guaranteed to be
constexpr
. Apparently it was only allowed to beconstexpr
in C++11. Thanks for the godbolt link! \$\endgroup\$G. Sliepen– G. Sliepen2020年01月06日 10:25:40 +00:00Commented Jan 6, 2020 at 10:25 -
\$\begingroup\$ Thanks for the answer! Although I must say 0^0=1 gives me headache :) \$\endgroup\$slepic– slepic2020年01月06日 11:12:30 +00:00Commented Jan 6, 2020 at 11:12
-
\$\begingroup\$ Heya, sorry I removed the accepted answer. You made some valid points. But I find the other answer a bit more useful so I decided to accept that one instead... \$\endgroup\$slepic– slepic2020年01月06日 17:28:14 +00:00Commented Jan 6, 2020 at 17:28