I'm implementing some mathematical functions right now for future use in more exciting things and wanted to know if I'm on the right track in my approach, as I don't have much C++ experience.
Here I'm offering my softmax function implementation.
#include <algorithm>
#include <cmath>
#include <iterator>
#include <functional>
#include <numeric>
#include <type_traits>
template <typename It>
void softmax (It beg, It end)
{
using VType = typename std::iterator_traits<It>::value_type;
static_assert(std::is_floating_point<VType>::value,
"Softmax function only applicable for floating types");
auto max_ele { *std::max_element(beg, end) };
std::transform(
beg,
end,
beg,
[&](VType x){ return std::exp(x - max_ele); });
VType exptot = std::accumulate(beg, end, 0.0);
std::transform(
beg,
end,
beg,
std::bind2nd(std::divides<VType>(), exptot));
}
Info on softmax
The softmax function is defined as
$$\text{softmax}(\mathbf{x})_{i} = \frac{\exp(x_{i})}{\sum_{j=1}^{n} \exp(x_{j})}$$
As mentioned in a comment, my reasoning for evaluating the softmax after subtracting the maximum of my vector is numerical stability. Consider the case where each element of my iterator is equal to some constant \$\alpha\$. If \$\alpha\$ is large and positive we might overflow, or if large and negative we might underflow. Subtracting \$\max_{i} x_{i}\$, we ensure that the largest argument taken by \$\exp\$ is \$0\$ (so no overflow) and that at least one denominator term protects us from underflow leading to division by zero.
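To make the stability argument concrete, here is a minimal sketch comparing a direct transcription of the definition with the max-shifted form. The names naive_softmax and shifted_softmax are hypothetical helpers, not part of the code under review, and the behaviour described assumes IEEE 754 doubles without flags such as -ffast-math.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Direct transcription of the definition: exp of each element over the sum.
// For large |x| the exponentials overflow to inf or underflow to 0.
inline std::vector<double> naive_softmax(std::vector<double> x)
{
    double total = 0.0;
    for (double& v : x) { v = std::exp(v); total += v; }
    for (double& v : x) { v /= total; }   // inf/inf or 0/0 yields NaN
    return x;
}

// Same computation after shifting every element by the maximum, so the
// largest argument to exp is 0 and the sum is at least 1.
inline std::vector<double> shifted_softmax(std::vector<double> x)
{
    if (x.empty()) { return x; }          // max_element on an empty range must not be dereferenced
    const double max_ele = *std::max_element(x.begin(), x.end());
    double total = 0.0;
    for (double& v : x) { v = std::exp(v - max_ele); total += v; }
    for (double& v : x) { v /= total; }
    return x;
}
```

With the input {1000, 1000, 1000} the naive version produces NaN, while the shifted version returns 1/3 for each element; the same happens for {-1000, -1000}.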
Thoughts while writing this
Is the static_assert a reasonable thing to check? The other thought I had was to use SFINAE, which I don't know too much about. Like
template <typename Condition>
using EnableIf = std::enable_if_t<Condition::value>;
template <typename It, EnableIf<std::is_floating_point<typename std::iterator_traits<It>::value_type>>...>
void softmax(...)
What I liked about this approach is that if I try to pass an iterator with a non-floating-point type somewhere, I get a nice immediate error in my Emacs session for no function overload. But I thought the static_assert might be better, because there really is no other overload I have in mind for softmax - passing a non-floating iterator is really just an error.
Are there any possible issues / considerations when using the value type of the iterator like I am doing here?
2 Answers
First thoughts
The code is very easy to follow, once I'd followed the link to the softmax definition (it may be worth including the Wikipedia link in the function's comment).
Interface
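This is a destructive operation; it might be desirable to provide a non-destructive option:
template <typename IterIn, typename IterOut>
void softmax(IterIn beg, IterIn end, IterOut dest);
template <typename It>  // in-place overload: a default argument cannot refer to beg
void softmax(It beg, It end) { softmax(beg, end, beg); }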
If you're feeling experimental, consider using concepts to constrain the iterator types.
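As a rough sketch of the non-destructive interface: a default argument cannot name another parameter, so this version uses two overloads instead. It assumes the input iterators allow multiple passes, and it recomputes the exponentials in the second pass so that dest is only ever written to (which lets it be something like std::back_inserter). The body is one possible arrangement, not the original poster's code.

```cpp
#include <algorithm>
#include <cmath>
#include <iterator>
#include <type_traits>

// Writes the softmax of [beg, end) to dest. dest may be equal to beg.
template <typename IterIn, typename IterOut>
void softmax(IterIn beg, IterIn end, IterOut dest)
{
    using VType = typename std::iterator_traits<IterIn>::value_type;
    static_assert(std::is_floating_point<VType>::value,
                  "Softmax function only applicable for floating types");
    if (beg == end) { return; }  // nothing to do, and max_element would return end

    VType const max_ele = *std::max_element(beg, end);

    // First pass: sum of shifted exponentials, without touching dest.
    VType exptot{};
    for (IterIn it = beg; it != end; ++it) { exptot += std::exp(*it - max_ele); }

    // Second pass: write the normalised values.
    std::transform(beg, end, dest,
                   [&](VType x) { return std::exp(x - max_ele) / exptot; });
}

// In-place convenience overload, matching the original two-argument call.
template <typename It>
void softmax(It beg, It end)
{
    softmax(beg, end, beg);
}
```

The trade-off here is calling std::exp twice per element in exchange for never reading from the output range; if the output is known to be readable, the values could be stored first and divided afterwards.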
Add const where possible
A little help to your readers:
auto const max_ele { *std::max_element(beg, end) };
VType const exptot = std::accumulate(beg, end, 0.0);
Bugfix - use the correct accumulate()
By passing the double value 0.0 as the third argument to std::accumulate(), we cause it to infer double for its type. That's not what we want if VType is long double or some user-defined floating type. We should instead use VType explicitly, using one of
VType const exptot = std::accumulate<IterIn, VType>(beg, end, 0.0);
VType const exptot = std::accumulate(beg, end, VType{});
Prefer the second of these, as C++20 says:
The number and order of deducible template parameters for algorithm declarations are unspecified, except where explicitly stated otherwise. [Note: Consequently, the algorithms may not be called with explicitly-specified template argument lists. — end note]
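To illustrate why the type of the initial value matters, here is a small sketch. It deliberately exaggerates the problem by using an int literal, so the effect is visible in the result rather than only in the last few bits; the helper names sum_literal_zero and sum_value_init are hypothetical.

```cpp
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
#include <vector>

// The result type of std::accumulate is the type of its third argument,
// not the element type of the range.
static_assert(std::is_same<
                  decltype(std::accumulate(std::declval<const long double*>(),
                                           std::declval<const long double*>(),
                                           0.0)),
                  double>::value,
              "accumulating long double with 0.0 yields double");

// Accumulates in int: every partial sum is truncated back to int.
template <typename It>
typename std::iterator_traits<It>::value_type sum_literal_zero(It beg, It end)
{
    return std::accumulate(beg, end, 0);
}

// Accumulates in the range's own value type.
template <typename It>
typename std::iterator_traits<It>::value_type sum_value_init(It beg, It end)
{
    using VType = typename std::iterator_traits<It>::value_type;
    return std::accumulate(beg, end, VType{});
}
```

For a std::vector<double> holding {0.5, 0.5, 0.5}, sum_literal_zero returns 0 while sum_value_init returns 1.5.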
Consider accumulating as you exponentiate
We can save one pass over the input, at some expense to simplicity, by accumulating as we go:
VType exptot = 0;
std::transform(
beg,
end,
beg,
[&](VType x){ auto ex = std::exp(x - max_ele); exptot += ex; return ex; });
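For completeness, a sketch of the whole function with the exponentiation and summation fused. It uses a plain loop rather than std::transform, since as far as I know std::transform does not promise in-order application of the operation, and a loop makes the side effect on the running total explicit. The name softmax_fused is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <iterator>
#include <type_traits>

template <typename It>
void softmax_fused(It beg, It end)
{
    using VType = typename std::iterator_traits<It>::value_type;
    static_assert(std::is_floating_point<VType>::value,
                  "Softmax function only applicable for floating types");
    if (beg == end) { return; }  // avoid dereferencing max_element's end result

    VType const max_ele = *std::max_element(beg, end);

    // One pass: exponentiate in place and accumulate the total.
    VType exptot{};
    for (It it = beg; it != end; ++it) {
        *it = std::exp(*it - max_ele);
        exptot += *it;
    }

    // Final pass: normalise.
    for (It it = beg; it != end; ++it) { *it /= exptot; }
}
```

This still makes three passes in total (maximum, exponentiate and sum, divide), one fewer than the original four.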
- I particularly appreciate the interface change recommendation - I was debating overloading the function with something non-destructive etc etc, completely forgot that this is how some STL algorithms give you the option for destructive or non-destructive. Thanks! – Eric Hansen, Oct 16, 2017 at 13:13
- Will that first bugfix suggestion really work? I would think not due to std::accumulate having two template parameters. – miradulo, Oct 17, 2017 at 1:57
- Good catch, @Mitch - I should be more careful! Now fixed. – Toby Speight, Oct 17, 2017 at 8:31
- std::accumulate<IterIn, VType> does not work since C++20. [algorithms.requirements]/15 (eel.is/c++draft/algorithms.requirements#15). – L. F., Jun 11, 2019 at 12:22
Algorithm
Mathematically, you could skip looking up the maximum element in the range, because it cancels out: $$y_{i} = \frac{e^{x_{i} - x_{\max}}}{\sum_{j} e^{x_{j} - x_{\max}}} = \frac{e^{x_{i}} / e^{x_{\max}}}{\frac{1}{e^{x_{\max}}} \sum_{j} e^{x_{j}}} = \frac{e^{x_{i}}}{\sum_{j} e^{x_{j}}}$$
Also, currently the results are stored in place, i.e. the original input data will be lost. This might not always be wanted, so maybe accept an iterator to write the results to?
static_assert vs. SFINAE
I personally like the SFINAE approach more in this case, because it's easier to introduce another overload if needed (e.g. for iterators over associative containers) and you get immediate error reporting. That said, if the decision is final that you won't ever need another overload, static_assert works fine.
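The practical difference between the two can be sketched as follows, assuming C++14 for std::enable_if_t. The names softmax_sfinae, softmax_asserted and sfinae_accepts are hypothetical, and the function bodies are left out since only the constraints matter here.

```cpp
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

template <typename It>
using ValueOf = typename std::iterator_traits<It>::value_type;

// SFINAE-constrained: for non-floating value types this template is
// removed from the overload set, leaving room for other overloads.
template <typename It,
          std::enable_if_t<std::is_floating_point<ValueOf<It>>::value, int> = 0>
void softmax_sfinae(It, It) { /* body omitted in this sketch */ }

// static_assert-constrained: always participates in overload resolution;
// the custom message appears only when the body is instantiated.
template <typename It>
void softmax_asserted(It, It)
{
    static_assert(std::is_floating_point<ValueOf<It>>::value,
                  "Softmax function only applicable for floating types");
}

// Detects at compile time whether softmax_sfinae(it, it) is a valid call.
template <typename It, typename = void>
struct sfinae_accepts : std::false_type {};

template <typename It>
struct sfinae_accepts<
    It,
    decltype(void(softmax_sfinae(std::declval<It>(), std::declval<It>())))>
    : std::true_type {};

static_assert(sfinae_accepts<std::vector<double>::iterator>::value,
              "floating-point iterators are accepted");
static_assert(!sfinae_accepts<std::vector<int>::iterator>::value,
              "integer iterators are rejected without a hard error");
```

An equivalent detector for softmax_asserted would report true for every iterator type, because the static_assert lives in the body and is not visible to overload resolution.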
iterator value_type
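Well, if the container is nicely conforming to standard library guidelines, you'll be fine with using std::iterator_traits<It>::value_type. For custom containers, this might not be the case, though - for those cases you could use decltype(*beg) instead. Note that dereferencing usually yields a reference type, so it needs to be wrapped in std::decay_t (or similar) before being passed to std::is_floating_point.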
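A small sketch of the two ways of obtaining the element type, assuming C++14 for std::decay_t. The aliases TraitsValue and DerefValue are hypothetical names introduced for illustration.

```cpp
#include <iterator>
#include <type_traits>
#include <utility>
#include <vector>

// Element type as reported by the iterator's traits.
template <typename It>
using TraitsValue = typename std::iterator_traits<It>::value_type;

// Element type recovered from the dereference expression, with the
// reference and cv-qualifiers stripped.
template <typename It>
using DerefValue = std::decay_t<decltype(*std::declval<It&>())>;

using VecIt = std::vector<double>::iterator;

// Dereferencing gives a reference, which is not itself a floating-point type.
static_assert(std::is_same<decltype(*std::declval<VecIt&>()), double&>::value,
              "dereferencing yields double&");
static_assert(!std::is_floating_point<decltype(*std::declval<VecIt&>())>::value,
              "double& does not satisfy is_floating_point");

// After decay, both approaches agree for well-behaved iterators.
static_assert(std::is_same<TraitsValue<VecIt>, double>::value, "");
static_assert(std::is_same<DerefValue<VecIt>, double>::value, "");
static_assert(std::is_same<DerefValue<const float*>, float>::value, "");
```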
- What's the impact on numeric precision of the transformation in your first section? We're now using much bigger numbers than we were, so I'd think there's an increased risk of overflow, at least. – Toby Speight, Oct 16, 2017 at 10:18
- @TobySpeight: It has the same risk of overflowing as your current code has of getting rounded to 0 because the numbers are getting too small to represent with the given precision. – hoffmale, Oct 16, 2017 at 10:21
- Just to be clear - it's Eric's current code, not mine! – Toby Speight, Oct 16, 2017 at 10:35
- @TobySpeight: Whoops, I just assumed you were the OP. Sorry ^^ – hoffmale, Oct 16, 2017 at 10:41
- Thank you very much for the feedback @hoffmale. I apologize for not making this initially clearer in my post, but I added a bit more info (just information, not code changes) to justify why I used the max like so - if I am incorrect in my reasoning though I would be interested. – Eric Hansen, Oct 16, 2017 at 13:08
beg == end? (I don't expect a problem with that, as the division won't happen when you have zero elements, but it may be relevant if you later refactor the code).