The actual task is to replace the tanh_() call at line#799 with the SELU activation function in the new_gate computation of the GRU cell. The following code block is from the RNN.cpp file in the PyTorch GitHub repo.
```cpp
template <typename cell_params>
struct GRUCell : Cell<Tensor, cell_params> {
  using hidden_type = Tensor;
  hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    if (input.is_cuda() || input.is_xpu()) {
      TORCH_CHECK(!pre_compute_input);
      auto igates = params.matmul_ih(input);
      auto hgates = params.matmul_hh(hidden);
      auto result = at::_thnn_fused_gru_cell(
          igates, hgates, hidden, params.b_ih(), params.b_hh());
      // Slice off the workspace argument (it's needed only for AD).
      return std::move(std::get<0>(result));
    }
    const auto chunked_igates = pre_compute_input
        ? input.unsafe_chunk(3, 1)
        : params.linear_ih(input).unsafe_chunk(3, 1);
    auto chunked_hgates = params.linear_hh(hidden).unsafe_chunk(3, 1);
    const auto reset_gate =
        chunked_hgates[0].add_(chunked_igates[0]).sigmoid_();
    const auto input_gate =
        chunked_hgates[1].add_(chunked_igates[1]).sigmoid_();
    const auto new_gate =  // <-- the tanh_() below is what I want to replace with SELU
        chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate)).tanh_();
    return (hidden - new_gate).mul_(input_gate).add_(new_gate);
  }
};
```
The new_gate is a Tensor. How can we implement a custom function that applies the SELU activation function element-wise over this Tensor?
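To make the question concrete, here is a minimal sketch of what I imagine such a helper could look like, written against the ATen Tensor API with the standard SELU constants (`apply_selu` is just an illustrative name, not an existing PyTorch function):

```cpp
#include <ATen/ATen.h>

// Sketch: element-wise SELU built from ATen tensor ops, so no manual loop
// over elements is needed.
// SELU(x) = scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))
static at::Tensor apply_selu(const at::Tensor& t) {
  constexpr double alpha = 1.6732632423543772;  // standard SELU alpha
  constexpr double scale = 1.0507009873554805;  // standard SELU scale
  return (t.clamp_min(0) + (t.exp() - 1).mul(alpha).clamp_max(0)).mul(scale);
}
```

Calling `apply_selu(t)` on any Tensor `t` would then return a new Tensor with SELU applied element-wise, but I am not sure whether this is the idiomatic way to do it inside RNN.cpp.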
I replaced the tanh_() with selu_(), which is declared in the build/aten/src/ATen/ops/selu.h header generated after building PyTorch from source in develop mode, and I also included the related header files. But on rebuilding, compilation failed with an error along the lines of "Did you mean relu_()".
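Concretely, my edit to the new_gate line was roughly the following (a sketch of my change, not the upstream code), and this is the line that fails to compile:

```cpp
// My attempted edit (sketch): calling selu_() as a Tensor method in place of tanh_()
const auto new_gate =
    chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate)).selu_();  // fails with: "Did you mean relu_()"
```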
I also tried to implement my own selu() function, but I ran into problems with the Tensor datatype.
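For context, my standalone attempt was roughly along these lines (a hypothetical reconstruction, not my exact code): a plain scalar formula like this compiles on its own, but it cannot be called with an at::Tensor argument, which is the datatype problem I ran into:

```cpp
#include <cmath>

// Hypothetical reconstruction of my scalar SELU helper; it works for a single
// double, but it does not accept an at::Tensor, so it cannot simply replace
// tanh_() in the GRU cell.
double my_selu(double x) {
  constexpr double alpha = 1.6732632423543772;
  constexpr double scale = 1.0507009873554805;
  return scale * (x > 0.0 ? x : alpha * (std::exp(x) - 1.0));
}
```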