On the topic of tail call optimization, I found two RFCs: 271 and 1888.
Until this gets implemented, I wanted to build something inspired by Scala's Trampoline functionality.
This implementation is far from full-fledged tail call optimization, since it simply replaces recursive function calls with a loop, but it is useful for keeping the stack depth low. My current state of work is as follows:
```rust
enum Trampoline<A, R> {
    Continue(A, R),
    End(R),
}

impl<A, R> Trampoline<A, R> {
    fn start(function: &dyn Fn(A, R) -> Trampoline<A, R>, mut arg: A, mut sum: R) -> R
    {
        loop {
            match function(arg, sum) {
                Trampoline::Continue(r_arg, r_sum) => {
                    arg = r_arg;
                    sum = r_sum;
                },
                Trampoline::End(result) => return result,
            }
        }
    }
}

fn recurse_trampolin(arg: i32, sum: i32) -> Trampoline<i32, i32> {
    if arg == 0 {
        Trampoline::End(sum)
    } else {
        Trampoline::Continue(arg - 1, sum + arg)
    }
}

fn recurse_normal(arg: i32, sum: i32) -> i32 {
    if arg == 0 {
        sum
    } else {
        recurse_normal(arg - 1, sum + arg)
    }
}

fn main() {
    println!("{}", recurse_normal(5, 0));
    println!("{}", Trampoline::start(&recurse_trampolin, 5, 0));
}
```
Apart from using references for compatibility with non-`Copy` types, can this construction be further optimized, or the code simplified or beautified?
Answer
Run rustfmt. It will automatically fix issues such as:
- Opening curly braces go on the same line as the function signature (unless the signature ends with a `where` clause, in which case the brace goes on its own line).
- There is no comma after a `match` arm whose body is a curly-brace block.
Your function name has a typo: `recurse_trampolin` should be `recurse_trampoline`.
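Instead of taking a trait object reference (`&Fn(...) -> ...`, spelled `&dyn Fn(...) -> ...` in current Rust), take a generic type parameter bounded by `Fn`. The compiler then generates a specialized copy of `start` for each concrete function type, so the call is dispatched statically and can be inlined instead of going through a vtable on every iteration.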
```rust
enum Trampoline<A, R> {
    Continue(A, R),
    End(R),
}

impl<A, R> Trampoline<A, R> {
    fn start<F>(function: F, mut arg: A, mut sum: R) -> R
    where
        F: Fn(A, R) -> Trampoline<A, R>,
    {
        loop {
            match function(arg, sum) {
                Trampoline::Continue(r_arg, r_sum) => {
                    arg = r_arg;
                    sum = r_sum;
                }
                Trampoline::End(result) => return result,
            }
        }
    }
}

fn recurse_trampoline(arg: i32, sum: i32) -> Trampoline<i32, i32> {
    if arg == 0 {
        Trampoline::End(sum)
    } else {
        Trampoline::Continue(arg - 1, sum + arg)
    }
}

fn recurse_normal(arg: i32, sum: i32) -> i32 {
    if arg == 0 {
        sum
    } else {
        recurse_normal(arg - 1, sum + arg)
    }
}

fn main() {
    println!("{}", recurse_normal(5, 0));
    println!("{}", Trampoline::start(recurse_trampoline, 5, 0));
}
```
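Regarding non-`Copy` types: because `start` takes `arg` and `sum` by value and moves them back in through `Continue`, they should already work without references. The sketch below repeats the generic `Trampoline` from above so it compiles on its own; `collect_countdown`, `sum_to`, and the depth of 1,000,000 are hypothetical illustrations, not part of the original code. It threads a `Vec<u32>` (not `Copy`) through one million steps while the call stack stays flat.

```rust
// Same enum and generic `start` as above, repeated so this block is self-contained.
enum Trampoline<A, R> {
    Continue(A, R),
    End(R),
}

impl<A, R> Trampoline<A, R> {
    fn start<F>(function: F, mut arg: A, mut sum: R) -> R
    where
        F: Fn(A, R) -> Trampoline<A, R>,
    {
        loop {
            match function(arg, sum) {
                Trampoline::Continue(r_arg, r_sum) => {
                    arg = r_arg;
                    sum = r_sum;
                }
                Trampoline::End(result) => return result,
            }
        }
    }
}

// Hypothetical step function: the accumulator is a `Vec`, which is not `Copy`.
// It is moved into each step and moved back out via `Continue`, so no
// references or clones are needed.
fn collect_countdown(n: u32, mut acc: Vec<u32>) -> Trampoline<u32, Vec<u32>> {
    if n == 0 {
        Trampoline::End(acc)
    } else {
        acc.push(n);
        Trampoline::Continue(n - 1, acc)
    }
}

// Same shape as `recurse_trampoline`, but using `i64` so a large depth
// does not overflow the running sum.
fn sum_to(arg: i64, sum: i64) -> Trampoline<i64, i64> {
    if arg == 0 {
        Trampoline::End(sum)
    } else {
        Trampoline::Continue(arg - 1, sum + arg)
    }
}

fn main() {
    // Each step returns to the loop in `start` before the next one runs,
    // so the stack depth stays constant regardless of `depth`.
    let depth: u32 = 1_000_000;
    let collected = Trampoline::start(collect_countdown, depth, Vec::new());
    println!("{} {:?} {:?}", collected.len(), collected.first(), collected.last());

    println!("{}", Trampoline::start(sum_to, 1_000_000, 0));
}
```

For comparison, calling `recurse_normal` with a similar depth would likely overflow the default main-thread stack in an unoptimized (debug) build; in a release build LLVM may already turn that tail call into a loop, so the trampoline's guarantee matters most where you cannot rely on the optimizer.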