I implemented a ripple carry adder in Rust. The function takes and outputs strings of 1s and 0s.
How can I improve this according to better Rust coding practices and API design, and make it more robust?
struct SumCarry {
    sum: char,
    carry: char
}

fn add_bits(left: char, right: char, carry_in: char) -> SumCarry {
    match (left, right, carry_in) {
        ('0', '0', '0') => SumCarry { sum: '0', carry: '0' },
        ('0', '1', '0') => SumCarry { sum: '1', carry: '0' },
        ('1', '0', '0') => SumCarry { sum: '1', carry: '0' },
        ('1', '1', '0') => SumCarry { sum: '0', carry: '1' },
        ('0', '0', '1') => SumCarry { sum: '1', carry: '0' },
        ('0', '1', '1') => SumCarry { sum: '0', carry: '1' },
        ('1', '0', '1') => SumCarry { sum: '0', carry: '1' },
        ('1', '1', '1') => SumCarry { sum: '1', carry: '1' },
        _ => panic!()
    }
}

fn zip_longest<IntoIter, Item>(left: IntoIter, right: IntoIter, default: Item) -> impl Iterator<Item = (Item, Item)>
where
    Item: Clone,
    IntoIter: IntoIterator<Item = Item>
{
    let mut left_iter = left.into_iter();
    let mut right_iter = right.into_iter();
    std::iter::from_fn(move || {
        match (left_iter.next(), right_iter.next()) {
            (Some(l), Some(r)) => Some((l, r)),
            (Some(l), None) => Some((l, default.clone())),
            (None, Some(r)) => Some((default.clone(), r)),
            (None, None) => None
        }
    })
}

pub fn ripple_carry_adder(left: &str, right: &str) -> String {
    let mut carry_in = '0';
    let mut result = String::new();
    for (l, r) in zip_longest(left.chars().rev(), right.chars().rev(), '0') {
        let SumCarry { sum, carry: carry_out } = add_bits(l, r, carry_in);
        result.insert(0, sum);
        carry_in = carry_out;
    }
    if carry_in == '1' {
        result.insert(0, carry_in);
    }
    result
}

pub fn main() {
    println!("{}", ripple_carry_adder("101", "1"));
}
2 Answers
Because you're only working with ASCII 0 and 1, you could use bytes instead of characters everywhere. That saves 3 bytes every time a char would be there, but arguably it makes the literals a lot less readable:
struct SumCarry {
    sum: u8,
    carry: u8,
}
Several patterns in this match have the same outcome; you can join them with |.
fn add_bits(left: u8, right: u8, carry_in: u8) -> SumCarry {
    match (left, right, carry_in) {
        (b'0', b'0', b'0') => SumCarry { sum: b'0', carry: b'0', },
        (b'0', b'1', b'0') | (b'1', b'0', b'0') | (b'0', b'0', b'1') => SumCarry {
            sum: b'1',
            carry: b'0',
        },
        (b'1', b'1', b'0') | (b'0', b'1', b'1') | (b'1', b'0', b'1') => SumCarry {
            sum: b'0',
            carry: b'1',
        },
        (b'1', b'1', b'1') => SumCarry { sum: b'1', carry: b'1', },
        _ => panic!(),
    }
}
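Building on Peter Cordes' comment under this answer, the table can also be replaced with integer arithmetic on the digit values. This is a sketch rather than part of the answer: add_ascii_bits is a hypothetical name, and it returns an Option instead of panicking so that a caller can report invalid digits (one way to make the code more robust, as the question asks).

```rust
/// Sketch: full adder over ASCII digits using integer arithmetic instead of a
/// lookup table. Returns None if any byte is not b'0' or b'1'.
fn add_ascii_bits(left: u8, right: u8, carry_in: u8) -> Option<(u8, u8)> {
    // Map an ASCII digit to its numeric value (0 or 1), rejecting anything else.
    let digit = |b: u8| match b {
        b'0' | b'1' => Some(b - b'0'),
        _ => None,
    };
    // The total of three one-bit values is in 0..=3, so it fits in two bits.
    let total = digit(left)? + digit(right)? + digit(carry_in)?;
    // Bit 0 is the sum bit and bit 1 is the carry-out; convert both back to ASCII.
    Some((b'0' + (total & 1), b'0' + (total >> 1)))
}
```

The returned pair is (sum, carry_out), so it could replace the SumCarry struct with a tuple, or be wrapped to fill one in.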
Since zip_longest is already very specialized to your situation, you might as well fix the item type to char (or u8) and drop the default parameter:
fn zip_longest<IntoIter>(left: IntoIter, right: IntoIter) -> impl Iterator<Item = (u8, u8)>
where
    IntoIter: IntoIterator<Item = u8>
{
    const DEFAULT: u8 = b'0';
    let mut left_iter = left.into_iter();
    let mut right_iter = right.into_iter();
    std::iter::from_fn(move || {
        match (left_iter.next(), right_iter.next()) {
            (Some(l), Some(r)) => Some((l, r)),
            (Some(l), None) => Some((l, DEFAULT)),
            (None, Some(r)) => Some((DEFAULT, r)),
            (None, None) => None
        }
    })
}
Alternatively you could make it a lot more general by simply not binding both iterators to be of the same type:
fn zip_longest<L, R, Item>(left: L, right: R, default: Item) -> impl Iterator<Item = (Item, Item)>
where
    Item: Clone,
    L: IntoIterator<Item = Item>,
    R: IntoIterator<Item = Item>,
{
    let mut left_iter = left.into_iter();
    let mut right_iter = right.into_iter();
    std::iter::from_fn(move || {
        match (left_iter.next(), right_iter.next()) {
            (Some(l), Some(r)) => Some((l, r)),
            (Some(l), None) => Some((l, default.clone())),
            (None, Some(r)) => Some((default.clone(), r)),
            (None, None) => None
        }
    })
}
Inserting at the beginning of a String means that all existing characters have to be copied over, which is O(n), so your ripple_carry_adder is accidentally O(n²). You should prefer pushing to the end and reversing the string once instead.
pub fn ripple_carry_adder(left: &str, right: &str) -> String {
    let mut carry_in = b'0';
    let mut result = String::new();
    for (l, r) in zip_longest(left.bytes().rev(), right.bytes().rev()) {
        let sum;
        SumCarry {
            sum,
            carry: carry_in,
        } = add_bits(l, r, carry_in);
        result.push(sum as char);
    }
    if carry_in == b'1' {
        result.push(carry_in as char);
    }
    result.chars().rev().collect()
}
- Thanks for the help. I actually tried to make a general zip_longest. What about it currently makes it specific? (user98809, Mar 23, 2023 at 22:48)
- It requires both left and right to be of the same type; you couldn't, for example, pass in one forward and one reversed iterator over the same collection. (cafce25, Mar 23, 2023 at 23:00)
- The match could be entirely replaced with the boolean expressions for the output bits: it's a full adder (en.wikipedia.org/wiki/Adder_(electronics)#Full_adder), so sum = left ^ right ^ carry_in; and carry_out = ((left ^ right) & carry_in) | (left & right);. Or more simply, use the low 2 bits of a u8 for integer addition and extract them: sum = a + b + carry_in; carry_out = sum >> 1; (bit #1), then sum &= 1; (bit #0). When doing extended / arbitrary-precision integer math, you can use fixed-width addition as a building block, @theonlygusti. Normally 32- or 64-bit chunks! (Peter Cordes, Mar 24, 2023 at 0:35)
- In languages like C, where you don't have convenient access to the carry-out of a full-width addition or the ability to feed it a carry-in, you might do like Python and use 30-bit chunks in 32-bit integers, so there's room for adding integers and having the carry-out in the uint32_t. But in languages like Rust that are better at this, you'd use let (sum, carry_out) = left.carrying_add(right, carry_in) (doc.rust-lang.org/std/primitive.u32.html#method.carrying_add; carrying_add is a new API). (Peter Cordes, Mar 24, 2023 at 0:43)
- If you're manually simulating ripple-carry adder internals, you can keep it simple and do each bit separately, or derive what the carry-in must have been for each bit from the sum and the two inputs, and use that to calculate the carry-out. So you can get 64 sum bits and 64 carry-out bits stored in two u64s, calculated with only one adc and a few boolean operations and maybe a shift, @theonlygusti. That's a very different design than what you went for, so IDK if I should post that as a code-review answer, especially without taking the time to actually write it in Rust. (Peter Cordes, Mar 24, 2023 at 0:47)
I’ll start with the really minor stuff.
Reconsider Names Like l
It’s easily confused with the numeral 1, or more rarely I, and if you don’t have l and r on the same screen, it can be very hard to guess what either one is an abbreviation for.
You Can Clean up zip_longest a Bit
The basic approach here is sound. However, what you actually want to do is:
- Take two generic iterators (so you can pass in reversed or mapped iterators)
- Whose item types have default values, from the std::default::Default trait
- Pad the shorter sequence with default values (without cloning, although cloning is harmless for built-in types)
- Zip to an iterator over pairs
- And hand along the result to the chain
That gives you a function that you might plausibly re-use in another app. It’s very close to what you wrote:
fn zip_pad<IterA, IterB>(mut left_it: IterA, mut right_it: IterB) ->
    impl Iterator<Item = (IterA::Item, IterB::Item)>
where
    IterA: Iterator,
    IterB: Iterator,
    IterA::Item: Default,
    IterB::Item: Default
{
    std::iter::from_fn(move || {
        match (left_it.next(), right_it.next()) {
            (None, None) => None,
            (Some(left), None) => Some((left, IterB::Item::default())),
            (None, Some(right)) => Some((IterA::Item::default(), right)),
            (Some(left), Some(right)) => Some((left, right))
        }
    })
}
The core logic is similar, but it drops IntoIterator, the user-supplied default value, and the separate Item type parameter (if that was on purpose, you might want a type alias), and it is more flexible about accepting two different types of iterator.
Update: You’ve explained your reasoning in the comments, and it makes sense. Sometimes you want to pad with different elements (zero digits, spaces, null bytes). In that case, I’d still clean up the associated types and traits a bit, but your approach better suits what you wanted and mine is more of a special case.
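A sketch of that combination, assuming you keep a caller-supplied padding value: it accepts two independently typed iterators, as cafce25 suggested, and clones the padding value only when one side has run out. zip_longest_with and pad are hypothetical names, chosen so the parameter is not confused with the Default trait.

```rust
/// Sketch: zips two iterators of possibly different types, padding the shorter
/// one with clones of `pad`.
fn zip_longest_with<L, R, T>(left: L, right: R, pad: T) -> impl Iterator<Item = (T, T)>
where
    L: IntoIterator<Item = T>,
    R: IntoIterator<Item = T>,
    T: Clone,
{
    let mut left = left.into_iter();
    let mut right = right.into_iter();
    std::iter::from_fn(move || match (left.next(), right.next()) {
        (None, None) => None,
        // At least one side still has an item; fill the missing side lazily.
        (l, r) => Some((
            l.unwrap_or_else(|| pad.clone()),
            r.unwrap_or_else(|| pad.clone()),
        )),
    })
}
```

Because L and R are separate type parameters, a call such as zip_longest_with("101".chars().rev(), "1".chars(), '0') type-checks, which the original single-IntoIter signature rejects.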
Now Some Significant Advice
Iterate Over the Bytes of the String, not the Codepoints
In Rust, a String is stored in UTF-8. Calling str::chars gives you an iterator of 32-bit Unicode codepoints. You can iterate forward or in reverse, but this is less efficient than iterating over a Vec, because UTF-8 is a variable-width encoding: the iterator must work out how many bytes represent each codepoint and decode them into a 32-bit char.
If you convert the string into an iterator of bytes, each element is a fixed width and the optimizer can go to town. In particular, LLVM will now perform tail-call elimination and inline all the functions.
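Iterating over bytes also makes it cheap to validate the input once, at the boundary, which addresses the question's request for robustness. In this sketch, parse_bits_lsb_first and InvalidDigit are hypothetical names: bad input is reported through a Result carrying the byte offset, rather than through panic!() or unimplemented!(), and the bits come out least significant first, ready for an addition loop (the bool representation anticipates the next suggestion).

```rust
/// Hypothetical error type: the byte offset and value of the first byte that
/// is not b'0' or b'1'.
#[derive(Debug, PartialEq)]
struct InvalidDigit {
    index: usize,
    byte: u8,
}

/// Converts a string of '0'/'1' into bits, least significant bit first.
fn parse_bits_lsb_first(s: &str) -> Result<Vec<bool>, InvalidDigit> {
    // Collecting into Result stops at the first (leftmost) invalid byte.
    let mut bits = s
        .bytes()
        .enumerate()
        .map(|(index, byte)| match byte {
            b'0' => Ok(false),
            b'1' => Ok(true),
            _ => Err(InvalidDigit { index, byte }),
        })
        .collect::<Result<Vec<bool>, InvalidDigit>>()?;
    // The string is written most significant digit first; flip it in place.
    bits.reverse();
    Ok(bits)
}
```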
Use bool for Internal Computation
The main convenience of this for you is that your matches can be exhaustive without needing to throw in a catch-all clause (which spoils the very nice feature of warning you when you write a non-exhaustive pattern). This also saves you from needing to do any runtime checks for valid char values after the initial conversion, and the compiler can pack bool values more efficiently than char, which, remember, is not a single byte in Rust (nor is a string stored as an array of char).
But it generates better code, too.
Here are some partial listings of what you get when running zip_longest on char iterators. An excerpt of the code to check for None versus Some(l) and Some(r):
.LBB9_17:
mov rbx, r15
cmp ecx, 1114112
jne .LBB9_23
jmp .LBB9_46
.LBB9_11:
movzx edx, byte ptr [rbx - 1]
test dl, dl
js .LBB9_14
dec rbx
cmp ecx, 1114112
jne .LBB9_22
.LBB9_13:
cmp edx, 1114112
jne .LBB9_30
jmp .LBB9_46
There are at least six comparisons of a register to 1114112 in the full listing. That is 0x110000, the first number that is not a valid Unicode codepoint, which the compiler uses as the binary representation of None for Option<char>.
This then calls add_bits. An excerpt of the code for the match expression in that function:
.LBB9_24:
cmp ecx, 48
je .LBB9_30
cmp ecx, 49
jne .LBB9_45
cmp edx, 49
je .LBB9_40
cmp edx, 48
jne .LBB9_49
mov cl, 49
cmp eax, 48
jne .LBB9_33
jmp .LBB9_39
.LBB9_30:
cmp edx, 48
je .LBB9_35
cmp edx, 49
jne .LBB9_49
mov cl, 49
cmp eax, 48
je .LBB9_39
.LBB9_33:
cmp eax, 49
jne .LBB9_49
mov cl, 48
mov ebp, 49
jmp .LBB9_43
.LBB9_35:
mov cl, 48
cmp eax, 48
je .LBB9_39
cmp eax, 49
jne .LBB9_49
mov cl, 49
mov ebp, 48
jmp .LBB9_43
.LBB9_39:
mov ebp, eax
ASCII 48 is '0' and ASCII 49 is '1', so we can see that the program is doing a series of nested conditional branches. Unlike the previous set of conditional branches, which switch to returning None only once, these depend on the values of the input digits, so they are not predictable. That means they can incur hefty branch-misprediction penalties on a modern CPU.
Calling zip_pad on bool iterators instead (Rust has no overloading; this is a separate instantiation of the same generic function) changes the generated code for one path through the same functions, adding two digits, to:
.LBB6_12:
dec qword ptr [rsp]
cmp al, 2
jne .LBB6_7
.LBB6_7:
test al, al
setne bl
test cl, cl
setne al
mov r14d, eax
and r14b, bl
xor bl, al
In this case, al holds an Option<bool>, whose binary representation is 0 for Some(false), 1 for Some(true) or 2 for None. Let a be the Boolean value initially in al, and c the value initially in cl. We see that, when a is not None, the program jumps to a branch that sets the next digit to a ⊕ c and the carry-out to a ∧ c.
The generated code for the more complex cases is:
.LBB6_35:
dec rdi
cmp al, 2
jne .LBB6_26
test r14b, r14b
setne bl
jmp .LBB6_37
.LBB6_25:
xor edx, edx
mov rdi, rsi
.LBB6_26:
test r14b, r14b
setne bl
test dl, dl
setne r12b
test al, al
je .LBB6_37
mov qword ptr [rsp], rdi
test dl, dl
setne al
test r14b, r14b
setne r12b
or dl, r14b
setne r14b
xor r12b, al
xor r12b, 1
mov r15b, 3
lea rbx, [rbp + 1]
cmp rbx, qword ptr [rsp + 8]
jne .LBB6_41
jmp .LBB6_42
.LBB6_37:
mov qword ptr [rsp], rdi
mov r14d, r12d
and r14b, bl
xor r12b, bl
mov r15b, 3
lea rbx, [rbp + 1]
cmp rbx, qword ptr [rsp + 8]
jne .LBB6_41
Not only are these shorter instruction sequences; the digit arithmetic itself is inlined and done with branchless bitwise register instructions (setne, and, or, xor). You can also see that iterating over the byte iterator in reverse order, unlike the str::chars() iterator, is as simple as decrementing a pointer by one byte.
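For comparison, the same full adder can be written directly with bitwise operators on bool, using the formulas from Peter Cordes' comment on the first answer. This is a sketch: full_adder is a hypothetical drop-in for the match-based add_bits (same argument and return types), and it is not a claim about exactly what the compiler generates from the match.

```rust
/// Sketch: full adder on bools with the textbook boolean expressions.
/// Returns (sum, carry_out).
const fn full_adder(a: bool, b: bool, carry_in: bool) -> (bool, bool) {
    // a XOR b is the half-adder sum of the two input bits.
    let half = a ^ b;
    // sum = a XOR b XOR c; carry_out = (a AND b) OR (c AND (a XOR b)).
    (half ^ carry_in, (a & b) | (carry_in & half))
}
```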
You Can Generate the (Reversed) Output as an Iterator
It’s more efficient to do this and .collect() the results into a Vec than to insert within a loop: each insertion at the front must shift every element already there, and every insertion must also check the available capacity and resize if necessary. When you .collect() an iterator, it can pass along a size hint that lets the implementation create the Vec with approximately the right size buffer. (An iterator built with std::iter::from_fn, as here, reports only the trivial hint, so the Vec still grows as needed; growth is amortized O(1) per element, unlike insertion at the front.)
The code is very similar to the very nice code you wrote for zip_longest, which comes out fine when you call it with .bytes().rev().map(to_bool) instead of .chars().rev():
let mut reverse_zip = zip_pad(a_rev, b_rev).peekable();
let mut carry = false;
let mut sum: Vec<u8> = std::iter::from_fn(move || {
    match (reverse_zip.peek(), carry) {
        (None, false) => None,
        (None, true) => {
            carry = false;
            Some(true)
        }
        (Some(&(a, b)), carry_in) => {
            let (digit, carry_out) = add_bits(a, b, carry_in);
            carry = carry_out;
            reverse_zip.next();
            Some(digit)
        }
    }
}).map(to_utf8)
    .collect();
There are a couple of wrinkles here. If there is a carry out of the final digit, we need to return one more digit of output than there were pairs of input digits. If not for that, I could much more simply have used Iterator::scan. Trying to cheat by not returning None on the last item of input causes an internal panic.
To work around that, I make the zipped iterator .peekable() and then .peek() ahead to decide what to do. If the input is exhausted but there is still a carry, I return the leading 1, clear the carry, and don’t advance the input iterator. That causes the next call to hit the terminating (None, false) case, safely.
Although it’s at odds with code that is otherwise close to functional programming, this kind of closure in Rust requires mutable state: here, the carry flag, which the move closure captures and updates.
You’ll also observe that this generates the output bytes in the same order as reverse_zip, that is, backwards.
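If the peekable workaround feels heavy, a plain loop is one alternative. This sketch (ripple_carry_adder_loop is a hypothetical name) keeps the original behavior of panicking on digits other than '0' and '1', pushes digits least significant first, and reserves capacity up front with Vec::with_capacity, since an iterator built with std::iter::from_fn supplies no useful size hint.

```rust
/// Sketch: ripple carry addition of two '0'/'1' strings with an explicit loop.
/// Panics on any other digit, like the original.
fn ripple_carry_adder_loop(left: &str, right: &str) -> String {
    // A non-capturing closure is Copy, so it can be passed to both maps.
    let to_bit = |b: u8| match b {
        b'0' => false,
        b'1' => true,
        _ => panic!("invalid digit {:?}", b as char),
    };
    let mut a = left.bytes().rev().map(to_bit);
    let mut b = right.bytes().rev().map(to_bit);
    // Room for every input digit plus one possible final carry.
    let mut out: Vec<u8> = Vec::with_capacity(left.len().max(right.len()) + 1);
    let mut carry = false;
    loop {
        // Pad the shorter input with zeros; stop when both are exhausted.
        let (x, y) = match (a.next(), b.next()) {
            (None, None) => break,
            (x, y) => (x.unwrap_or(false), y.unwrap_or(false)),
        };
        // Full adder: push the sum bit (using the old carry), then update the carry.
        out.push(b'0' + (x ^ y ^ carry) as u8);
        carry = (x & y) | (carry & (x ^ y));
    }
    if carry {
        out.push(b'1');
    }
    // Digits were produced least significant first.
    out.reverse();
    String::from_utf8(out).expect("only ASCII '0' and '1' were pushed")
}
```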
Add from Right to Left, then Reverse, to Avoid Array-Shifting
I see that @cafce25 beat me to this one, but it’s important advice. Whenever you insert a digit at the front of a String, the entire string has to be shifted one position to the right, so adding N digits takes O(N²) time in all. You want to get that down to O(N). Not only that, but a char in Rust is not an alias for a byte as in C; it is a Unicode codepoint that must be converted to and from UTF-8 whenever it is extracted from or inserted into a String.
Remove Unnecessary Copies
Once you have generated the digits of output as a sequence of bool, you can convert them to ASCII digits by mapping a helper function:
const fn to_utf8(d: bool) -> u8 {
    match d {
        false => '0' as u8,
        true => '1' as u8
    }
}
This compiles to next-to-no overhead: an inline add r8, 48 instruction (since 48 is '0', 49 is '1', 0 is false and 1 is true). The code above collects all of these converted ASCII digits into a mutable Vec<u8>.
.
Why mutable? The slice method sum.reverse() reverses the bytes of output in place, with no additional memory allocation or copying. When you enable vector instructions, the implementation can use them to optimize this step.
Since that gets us the correct UTF-8 output string, we can call String::from_utf8 to move the buffer directly from the Vec container to a String container, without needing to make another copy. Since this returns a Result, we need to unwrap the return value; it’s good practice to use .expect with a debug message for this. (We could in theory use the unsafe function String::from_utf8_unchecked instead.)
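Those three steps (map to ASCII, reverse in place, move into a String) can be packaged as one helper. This is a sketch: bits_to_string is a hypothetical name, and it uses b'0' + d as u8 as an arithmetic equivalent of to_utf8.

```rust
/// Sketch: turns bits produced least significant first into a string of
/// '0'/'1' digits, most significant first.
fn bits_to_string(bits_lsb_first: Vec<bool>) -> String {
    // b'0' + 0 == b'0' and b'0' + 1 == b'1', mirroring `to_utf8`.
    let mut bytes: Vec<u8> = bits_lsb_first.into_iter().map(|d| b'0' + d as u8).collect();
    // Reverse in place; no new allocation.
    bytes.reverse();
    // Moves the buffer into the String; this check cannot fail for ASCII digits.
    String::from_utf8(bytes).expect("'0' and '1' are valid UTF-8")
}
```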
Putting it All Together
fn zip_pad<IterA, IterB>(mut left_it: IterA, mut right_it: IterB) ->
    impl Iterator<Item = (IterA::Item, IterB::Item)>
where
    IterA: Iterator,
    IterB: Iterator,
    IterA::Item: Default,
    IterB::Item: Default
{
    std::iter::from_fn(move || {
        match (left_it.next(), right_it.next()) {
            (None, None) => None,
            (Some(left), None) => Some((left, IterB::Item::default())),
            (None, Some(right)) => Some((IterA::Item::default(), right)),
            (Some(left), Some(right)) => Some((left, right))
        }
    })
}

fn ripple_carry_adder(left_input: &str, right_input: &str) -> String {
    let a_rev = left_input.bytes()
        .rev()
        .map(to_bool);
    let b_rev = right_input.bytes()
        .rev()
        .map(to_bool);
    let mut reverse_zip = zip_pad(a_rev, b_rev).peekable();
    let mut carry = false;
    let mut sum: Vec<u8> = std::iter::from_fn(move || {
        match (reverse_zip.peek(), carry) {
            (None, false) => None,
            (None, true) => {
                carry = false;
                Some(true)
            }
            (Some(&(a, b)), carry_in) => {
                let (digit, carry_out) = add_bits(a, b, carry_in);
                carry = carry_out;
                reverse_zip.next();
                Some(digit)
            }
        }
    }).map(to_utf8)
        .collect();

    /* The output string was generated in reverse order. Reverse it in place,
     * without copying.
     */
    sum.reverse();

    /* Since the bytes of the vector are valid UTF-8, move them to a string,
     * without copying.
     */
    return String::from_utf8(sum).expect("The output should have contained only two possible values.");

    const fn to_bool(c: u8) -> bool {
        match c as char {
            '0' => false,
            '1' => true,
            _ => unimplemented!()
        }
    }

    const fn add_bits(a: bool, b: bool, c: bool) -> (bool, bool)
    {
        match (a, b, c) {
            (false, false, false) => (false, false),
            (false, false, true) => (true, false),
            (false, true, false) => (true, false),
            (false, true, true) => (false, true),
            (true, false, false) => (true, false),
            (true, false, true) => (false, true),
            (true, true, false) => (false, true),
            (true, true, true) => (true, true)
        }
    }

    const fn to_utf8(d: bool) -> u8 {
        match d {
            false => '0' as u8,
            true => '1' as u8
        }
    }
}
And a few tests:
pub fn main() {
    let test1 = ripple_carry_adder("101", "1");
    assert!(test1 == "110", "Expected 101 + 1 = 110. Got {}", test1);
    println!("Test 1 passed.");

    let test2 = ripple_carry_adder("1011", "1");
    assert!(test2 == "1100", "Expected 1011 + 1 = 1100. Got {}", test2);
    println!("Test 2 passed.");

    let test3 = ripple_carry_adder("101010", "10101");
    assert!(test3 == "111111", "Expected 101010 + 10101 = 111111. Got {}", test3);
    println!("Test 3 passed.");

    let test4 = ripple_carry_adder("100", "100");
    assert!(test4 == "1000", "Expected 100 + 100 = 1000. Got {}", test4);
    println!("Test 4 passed.");
}
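Hand-picked cases are easy to get wrong, so one option is to compare the adder against native integer addition for every pair of small operands. check_adder is a hypothetical test helper, not part of the answer; it assumes the adder returns results without leading zeros, since format!("{:b}", n) produces none.

```rust
/// Hypothetical test helper: checks any `&str -> String` binary adder against
/// native u64 addition for every pair of operands below `limit`.
fn check_adder<F>(adder: F, limit: u64) -> Result<(), String>
where
    F: Fn(&str, &str) -> String,
{
    for a in 0..limit {
        for b in 0..limit {
            // Binary strings without leading zeros ("0" for zero).
            let (left, right) = (format!("{:b}", a), format!("{:b}", b));
            let expected = format!("{:b}", a + b);
            let got = adder(&left, &right);
            if got != expected {
                return Err(format!("{left} + {right}: expected {expected}, got {got}"));
            }
        }
    }
    Ok(())
}
```

With the answer's code in scope, check_adder(ripple_carry_adder, 64).unwrap() would exercise all 4,096 operand pairs below 64.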
- I had a custom default parameter on zip_longest because I don't think the default char is '0', and I feel like it would be a pain to make a custom type and implement Default for every zip_longest call? Is there a way to take iterators while also specifying that the underlying values won't be mutated? (user98809, Mar 24, 2023 at 14:15)
- @theonlygusti That’s a valid use case! Here, I just switched to a bool representation where padding with the default value sufficed. (Davislor, Mar 24, 2023 at 20:16)
- @theonlygusti Are you thinking of the Cloned iterator adapter (what .cloned() returns), for iterators that clone instead of moving or mutating their data? I don’t think Rust makes it possible to stop a type from also implementing some other trait that enables interior mutability. You could also maybe put it in a const fn, so the compiler will flag some things that could have externally visible side effects. (Davislor, Mar 24, 2023 at 20:49)
- @theonlygusti Now that I understand your reasoning (which is valid; sometimes you want to pad with zero digits, sometimes spaces, sometimes null bytes), I don’t think there’s a significantly better way than your original approach. You do, in this case, want to .clone() the user-supplied padding object. (And maybe call it something like that, instead of default, because the two meanings of "default" confused me.) For a primitive type like char, that’s zero-cost. Just pass it by immutable reference? (Davislor, Mar 24, 2023 at 20:56)
- @theonlygusti Generally, iterators themselves need to be mut in order to call next on them (because .next() updates the iterator). If you call .iter(), the iterators return non-owning references, which sounds like what you want. If you call .iter_mut(), you get iterators over mutable references, and .into_iter() transfers ownership of the elements. There are also, as I mentioned, iterators that clone instead. (Davislor, Mar 24, 2023 at 22:01)