Story
I'm trying to become more familiar with template metaprogramming and practice more of it. I need to use modular arithmetic to solve certain problems and so I decided to write a general purpose template for it.
Code
Include files
// Individual
#include <iostream>
#include <chrono>
#include <vector>
#include <exception>
#include <string>
#include <type_traits>
// PCH
#include "bits/stdc++.h"
Modular Arithmetic Template
namespace competitive_programming::utility::sfinae {
// Detection traits. Each primary template derives from std::false_type; the
// partial specialization (std::true_type) is selected only when every
// expression inside std::void_t is well-formed for T.
//
// Multi-expression checks cast every operand to void(...) before joining them
// with commas. That forces the built-in comma operator, so a type that
// overloads (or deletes) operator, cannot change the detection result.

/// True when `std::cin >> t` compiles for an lvalue `t` of type T.
template <typename T, typename = void>
struct is_istream_streamable
    : std::false_type
{ };
template <typename T>
struct is_istream_streamable <T, std::void_t <decltype(std::cin >> std::declval <T&> ())>>
    : std::true_type
{ };
template <typename T>
constexpr bool is_istream_streamable_v = is_istream_streamable <T>::value;

/// True when `std::cerr << t` compiles for an rvalue `t` of type T.
template <typename T, typename = void>
struct is_ostream_streamable
    : std::false_type
{ };
template <typename T>
struct is_ostream_streamable <T, std::void_t <decltype(std::cerr << std::declval <T> ())>>
    : std::true_type
{ };
template <typename T>
constexpr bool is_ostream_streamable_v = is_ostream_streamable <T>::value;

/// True when `.begin()` can be called on an rvalue of type T.
template <typename T, typename = void>
struct has_begin_iterator
    : std::false_type
{ };
template <typename T>
struct has_begin_iterator <T, std::void_t <decltype(std::declval <T> ().begin())>>
    : std::true_type
{ };
template <typename T>
constexpr bool has_begin_iterator_v = has_begin_iterator <T>::value;

/// True when `.end()` can be called on an rvalue of type T.
template <typename T, typename = void>
struct has_end_iterator
    : std::false_type
{ };
template <typename T>
struct has_end_iterator <T, std::void_t <decltype(std::declval <T> ().end())>>
    : std::true_type
{ };
template <typename T>
constexpr bool has_end_iterator_v = has_end_iterator <T>::value;

/// True when T has both `.begin()` and `.end()`.
template <typename T>
constexpr bool has_begin_end_iterator_v = has_begin_iterator_v <T> && has_end_iterator_v <T>;

/// True when T supports every compound assignment, increment/decrement,
/// binary arithmetic/bitwise/shift operator and unary +/- listed below.
template <typename T, typename = void>
struct is_mathematical
    : std::false_type
{ };
template <typename T>
struct is_mathematical <T, std::void_t <decltype(
    void(std::declval <T&> () += std::declval <T> ()),
    void(std::declval <T&> () -= std::declval <T> ()),
    void(std::declval <T&> () *= std::declval <T> ()),
    void(std::declval <T&> () /= std::declval <T> ()),
    void(std::declval <T&> () %= std::declval <T> ()),
    void(std::declval <T&> () &= std::declval <T> ()),
    void(std::declval <T&> () |= std::declval <T> ()),
    void(std::declval <T&> () ^= std::declval <T> ()),
    void(std::declval <T&> () <<= std::declval <T> ()),
    void(std::declval <T&> () >>= std::declval <T> ()),
    void(++std::declval <T&> ()),
    void(--std::declval <T&> ()),
    void(std::declval <T&> ()++),
    void(std::declval <T&> ()--),
    void(std::declval <T> () + std::declval <T> ()),
    void(std::declval <T> () - std::declval <T> ()),
    void(std::declval <T> () * std::declval <T> ()),
    void(std::declval <T> () / std::declval <T> ()),
    void(std::declval <T> () % std::declval <T> ()),
    void(std::declval <T> () & std::declval <T> ()),
    void(std::declval <T> () | std::declval <T> ()),
    void(std::declval <T> () ^ std::declval <T> ()),
    void(std::declval <T> () << std::declval <T> ()),
    void(std::declval <T> () >> std::declval <T> ()),
    void(+std::declval <T> ()),
    void(-std::declval <T> ())
)>> : std::true_type
{ };

/// True when T supports ==, !=, <=, >=, <, >, assignment from an rvalue T and
/// static_cast<bool>.
template <typename T, typename = void>
struct is_comparable
    : std::false_type
{ };
template <typename T>
struct is_comparable <T, std::void_t <decltype(
    void(std::declval <T> () == std::declval <T> ()),
    void(std::declval <T> () != std::declval <T> ()),
    void(std::declval <T> () <= std::declval <T> ()),
    void(std::declval <T> () >= std::declval <T> ()),
    void(std::declval <T> () < std::declval <T> ()),
    void(std::declval <T> () > std::declval <T> ()),
    void(std::declval <T&> () = std::declval <T> ()),
    void(static_cast <bool> (std::declval <T> ()))
)>> : std::true_type
{ };

/// True when T satisfies both is_mathematical and is_comparable.
template <typename T>
struct is_integer_type {
    static constexpr bool value = is_mathematical <T>::value &&
                                  is_comparable <T>::value;
};
template <typename T>
constexpr bool is_integer_type_v = is_integer_type <T>::value;

/// Type used for intermediate arithmetic on T: std::int64_t for built-in
/// integral types, T itself otherwise.
/// NOTE(review): 64-bit (and wider or unsigned 64-bit) integral types are also
/// mapped to std::int64_t, which does not widen them — confirm this is intended.
template <typename T, typename = void>
struct common_bigger_integer {
    using type = T;
};
template <typename T>
struct common_bigger_integer <T, std::enable_if_t <std::is_integral_v <T>, void>> {
    using type = std::int64_t;
};
template <typename T>
using common_bigger_integer_t = typename common_bigger_integer <T>::type;
}
namespace competitive_programming::utility::math {
namespace sfinae = competitive_programming::utility::sfinae;

/// Residue class modulo the compile-time constant `Modulo`.
///
/// Every constructor and arithmetic operator keeps the stored value in
/// [0, Modulo). Arithmetic is carried out in
/// sfinae::common_bigger_integer_t<Integer> (int64_t for built-in integral
/// types), so the product of two residues below 2^31 does not overflow.
/// NOTE(review): when Integer is itself 64-bit, Common_Integer is also int64_t
/// and operator*= can overflow for moduli above roughly 3e9.
///
/// @tparam Integer type of the modulus; must satisfy sfinae::is_integer_type_v.
/// @tparam Modulo  the modulus; must be positive.
template <typename Integer, Integer Modulo>
class modular {
    static_assert(sfinae::is_integer_type_v <Integer>, "class modular requires integer type");
    static_assert(Modulo > 0, "Modulo must be positive");
    using Common_Integer = std::decay_t <sfinae::common_bigger_integer_t <Integer>>;
private:
    Common_Integer integer = Common_Integer();
    // The modulus is a per-type constant. A static constexpr member costs no
    // per-object storage and, unlike a non-static const member, keeps the
    // implicit copy/move operations usable, so none are hand-written here.
    static constexpr Common_Integer modulo = Modulo;

    // Reduces `integer` into [0, modulo) for a value of any sign or magnitude.
    // After the guarded %, |integer| < modulo, so one += fixes a negative value.
    void normalize () {
        if (integer >= modulo || integer <= -modulo)
            integer %= modulo;
        if (integer < 0)
            integer += modulo;
    }
public:
    /// Constructs from any value that static_casts to Common_Integer, reduced
    /// modulo `Modulo`. Default-constructs to 0.
    template <typename T = Common_Integer>
    modular (const T& integer = T())
        : integer (static_cast <Common_Integer> (integer))
    { normalize(); }

    /// The modulus.
    constexpr Common_Integer mod () const { return modulo; }
    /// The normalized value in [0, mod()).
    constexpr Common_Integer operator () () const { return integer; }
    /// Mutable access to the raw value (used by operator>>). Writing through
    /// this reference bypasses normalization.
    constexpr Common_Integer& operator () () { return integer; }
    /// Explicit conversion of the normalized value to T.
    template <typename T>
    constexpr explicit operator T() const { return static_cast <T> (integer); }

    // Both operands are already in [0, modulo), so a single conditional
    // correction is enough for + and -.
    modular& operator += (const modular& other) {
        integer += other.integer;
        if (integer >= modulo)
            integer -= modulo;
        return *this;
    }
    modular& operator -= (const modular& other) {
        integer -= other.integer;
        if (integer < 0)
            integer += modulo;
        return *this;
    }
    modular& operator *= (const modular& other) {
        integer *= other.integer;
        normalize();
        return *this;
    }
    /// Multiplies by the modular inverse of `other`.
    /// Requires gcd(other, modulo) == 1.
    modular& operator /= (const modular& other)
    { return *this *= other.extended_euclidean_inverse(); }
    modular& operator ++ () { return *this += 1; }
    modular& operator -- () { return *this -= 1; }
    // Postfix forms modify *this, so they must not be const-qualified.
    modular operator ++ (int) { modular result (*this); ++*this; return result; }
    modular operator -- (int) { modular result (*this); --*this; return result; }
    modular operator + () const { return *this; }
    modular operator - () const { return modular(-integer); }
    friend modular operator + (modular self, const modular& other) { return self += other; }
    friend modular operator - (modular self, const modular& other) { return self -= other; }
    friend modular operator * (modular self, const modular& other) { return self *= other; }
    friend modular operator / (modular self, const modular& other) { return self /= other; }
    friend bool operator == (const modular& left, const modular& right) { return left() == right(); }
    friend bool operator != (const modular& left, const modular& right) { return left() != right(); }
    friend bool operator <= (const modular& left, const modular& right) { return left() <= right(); }
    friend bool operator >= (const modular& left, const modular& right) { return left() >= right(); }
    friend bool operator < (const modular& left, const modular& right) { return left() < right(); }
    friend bool operator > (const modular& left, const modular& right) { return left() > right(); }

    /// Inverse via Fermat's little theorem: a^(p - 2) mod p.
    /// Assumes modulo is prime and the value is non-zero.
    /// With LOST_IN_SPACE defined, throws std::runtime_error if the result is
    /// not an inverse.
    modular fermat_inverse () const {
        modular inverse = *this;
        inverse.binary_exponentiate(modulo - 2);
#ifdef LOST_IN_SPACE
        if (*this * inverse != 1)
            throw std::runtime_error("integer and modulo are not co-prime");
#endif
        return inverse;
    }

    /// Inverse via Euler's theorem: a^(phi(m) - 1) mod m.
    /// Requires gcd(value, modulo) == 1. The modulus does not need to be prime.
    /// phi(m) is found by trial division up to sqrt(m).
    /// With LOST_IN_SPACE defined, throws std::runtime_error if the result is
    /// not an inverse.
    modular euler_inverse () const {
        Common_Integer rest = modulo;
        Common_Integer phi = modulo;
        for (Common_Integer p = 2; p * p <= rest; ++p) {
            if (rest % p == 0) {
                while (rest % p == 0)
                    rest /= p;
                // Exact integer form of phi *= (1 - 1/p). A floating-point
                // product can land just below the true totient and then
                // truncate to the wrong integer.
                phi -= phi / p;
            }
        }
        if (rest > 1)
            phi -= phi / rest;
        modular inverse = *this;
        inverse.binary_exponentiate(phi - 1);
#ifdef LOST_IN_SPACE
        if (*this * inverse != 1)
            throw std::runtime_error("integer and modulo are not co-prime");
#endif
        return inverse;
    }

    /// Inverse via the extended Euclidean algorithm.
    /// Requires gcd(value, modulo) == 1.
    /// With LOST_IN_SPACE defined, throws std::runtime_error when the gcd is
    /// not 1.
    modular extended_euclidean_inverse () const {
        Common_Integer u = 0, v = 1;
        Common_Integer a = integer, m = modulo;
        while (a != 0) {
            Common_Integer t = m / a;
            m -= t * a;
            u -= t * v;
            std::swap(a, m);
            std::swap(u, v);
        }
#ifdef LOST_IN_SPACE
        if (m != 1)
            throw std::runtime_error("integer and modulo are not co-prime");
#endif
        // u may be negative; the converting constructor normalizes it.
        return u;
    }

    /// Replaces *this with (*this)^power by square-and-multiply and returns a
    /// copy of the result. A non-positive power yields 1.
    modular binary_exponentiate (Common_Integer power) {
        auto base = *this;
        *this = 1;
        while (power > 0) {
            if (power & 1)
                *this *= base;
            base *= base;
            power >>= 1;
        }
        return *this;
    }

    /// Returns *this unchanged. Stored values are never negative; this is kept
    /// for interface compatibility.
    modular abs () const { return *this; }
    /// Decimal representation of the normalized value.
    std::string to_string () const { return std::to_string(integer); }
    friend auto operator << (std::ostream& stream, const modular& m)
        -> std::enable_if_t <sfinae::is_ostream_streamable_v <Common_Integer>, std::ostream&>
    { return stream << m(); }
    /// Reads a raw value and normalizes it.
    friend auto operator >> (std::istream& stream, modular& m)
        -> std::enable_if_t <sfinae::is_istream_streamable_v <Common_Integer>, std::istream&>
    { return stream >> m(), m.normalize(), stream; }
};
}
template <typename Integer, Integer Modulo = Integer()>
using modular = competitive_programming::utility::math::modular <Integer, Modulo>;
using mod998244353 = modular <int, 998244353>;
using mod1000000007 = modular <int, 1000000007>;
Usage
// Example 1
int main () {
    // Reads n and prints F(n) / 2^n modulo 998244353, where F(0) = 0, F(1) = 1.
    int64_t n = 0;
    if (!(std::cin >> n) || n < 0)
        return 1;  // missing, malformed or negative input: nothing sensible to compute
    mod998244353 y = 2;
    y.binary_exponentiate(n);  // updates y in place: y = 2^n
    std::vector <mod998244353> fibo (static_cast <std::size_t> (n) + 1);
    fibo [0] = 0;
    if (n >= 1)  // with n == 0 the vector has only one slot
        fibo [1] = 1;
    for (int64_t i = 2; i <= n; ++i)
        fibo [i] = fibo [i - 1] + fibo [i - 2];
    // 998244353 is prime, so Fermat's little theorem yields the inverse of 2^n.
    std::cout << fibo [n] * y.fermat_inverse() << '\n';
    return 0;
}
// Example 2
namespace competitive_programming::utility {
// Scope-based stopwatch. When LOST_IN_SPACE is defined, the destructor writes
// the steady-clock time elapsed since construction to std::cerr; otherwise
// the class does nothing.
class timer {
private:
    // Class-local alias instead of a namespace-wide `using namespace
    // std::chrono`, which leaked every chrono name into this namespace.
    using clock_type = std::chrono::steady_clock;
    std::chrono::time_point <clock_type> begin, end;
public:
#ifdef LOST_IN_SPACE
    timer () : begin (clock_type::now()), end () { }
    ~timer () {
        end = clock_type::now();
        std::cerr << "\n\nDuration: " << std::chrono::duration <double> (end - begin).count() << "s\n";
    }
#else
    timer () : begin (), end () { }
    ~timer () { }
#endif
    // A copy would report the same interval a second time, so disallow it.
    timer (const timer&) = delete;
    timer& operator = (const timer&) = delete;
};
}
int main () {
mod1000000007 m;
std::cin >> m; // 42
{
competitive_programming::utility::timer t;
std::cout << m.extended_euclidean_inverse() << '\n';
// 23809524
// Duration: 0.0000641200s
}
{
competitive_programming::utility::timer t;
std::cout << m.fermat_inverse() << '\n';
// 23809524
// Duration: 0.0000082080s
}
{
competitive_programming::utility::timer t;
std::cout << m.euler_inverse() << '\n';
// 23809524
// Duration: 0.0007950880s
}
return 0;
}
Question and Others
Comments about code style and related guidelines are welcome as usual. I'd like to know whether the math-related part of my code is optimally written. It would be great to have comments on the `is_mathematical` and `is_comparable` parts of the code; I believe that with C++20 concepts they could be written much less verbosely. What other features could be added to the code (that I may have missed)?
1 Answer 1
Don't #include <bits/stdc++.h>
You should avoid `#include <bits/stdc++.h>`, as there are several issues with it, the most important being that it is just not standard C++. You are putting your classes in a namespace `competitive_programming`, so I can see where you are coming from, but please unlearn this.
Only write concepts that make sense
I see you wrote both `has_begin_iterator` and `has_end_iterator`. Does it ever make sense for a class to have only one of those, and do you ever expect to write a template where you only need one of them? I think it makes more sense to have a single `is_container` or `is_iterable` concept.
`is_mathematical` and `is_comparable` check too much
I think everyone would consider a `float` to be a "mathematical" type. However, since you also check for `%` and the bitwise/shift operators, this template will fail for `float`, `double`, `std::complex` and possibly other types as well. In fact, it also fails for `modular` itself! You might want to reduce the number of operations you check for.
Note that the word "mathematical" also doesn't imply a particular set of operations. In mathematics, you can have types you can do math with that, for example, only support addition and subtraction. The word closest to what you mean is "arithmetic", so I suggest that you rename it `is_arithmetic` and do not check for the logical and bitwise operations.
There is a similar issue with `is_comparable`. You check for the presence of all comparison operators, but there are types where only some of them are valid. Consider for example `std::complex<float>`: it's an arithmetic type, and you can certainly check whether two complex numbers are equal or not. However, you cannot check whether one number compares greater than another, because that operation does not make sense for complex numbers. So here it might make sense to split it into an `is_equality_comparable` and perhaps something like `is_less_greater_than_comparable`. Note that most algorithms require only a single comparison operator to be available, either `==` or `<`, and you don't want to restrict the types you accept unnecessarily.
Unnecessarily templated constructor
Why is the constructor for class `modular` a template? That way it will accept all types, even ones that don't make sense. Just make a non-templated constructor that accepts a `Common_Integer` parameter:
template <typename Integer, Integer Modulo>
class modular {
...
modular (const Common_Integer &integer = {})
: integer (integer)
{ normalize(); }
...
};
Don't make a templated cast operator
You should not make a cast operator that is a template accepting all possible types. You already have `operator()` for getting the `integer` out, and C++ will implicitly handle valid conversions from there. I would rather write:
mod998244353 y = 2;
float x = y();
Than:
mod998244353 y = 2;
auto x = static_cast<float>(y);
Don't write unnecessary member functions
Why does class `modular` have a member function `abs()` that just returns itself? Since `integer` is always normalized, you know the value will never be negative, so this seems redundant. Including an `abs()` function suggests to a user that the class supports negative values, when it doesn't.
Be consistent using member functions from other member functions
I see this code:
void normalize () {
if (integer >= mod() || -integer <= mod())
...
}
Why directly access the member variable `integer`, but use the member function `mod()` to get the value of the member variable `modulo`? This makes it look like something special is going on, when it's not. Here I would just access both member variables directly:
if (integer >= modulo || -integer <= modulo)
...
There is one place where you did use a member function to get the value of `integer`:
std::string to_string () const { return std::to_string((*this)()); }
That looks very weird, again I would suggest directly accessing the member variable:
std::string to_string () const { return std::to_string(integer); }
-
Thank you for the comments and answer! I'll make sure to apply the recommended changes and keep the style guidelines in mind in future as well. — Arrow, commented Nov 20, 2020 at 21:47:32 UTC
Explore related questions
See similar questions with these tags.
`std::is_integral` to check whether a type is mathematical or not. However, I'd like the template to be more general-purpose and to allow user-defined integer types too. Say that someone has an arbitrary-precision big integer implementation and would like to use a really big modulo (bigger than `std::numeric_limits<int>::max()`); my code would allow you to do that.
`Integer` is a signed type?
`-1 % -10` == `9 % -10`? If so, all's well and good. (That was a real question, not a criticism.)