I am new to Rust but well versed in Go. I have written the following program, which initiates a connection to `remote-addr` and starts a thread that listens for incoming connections at `local-addr`. It forwards the packets received from the remote connection to all incoming connections. The program works very well, but I would like a code review.

    use std::collections::HashMap;
    use std::env;
    use std::io::{Error, Write};
    use std::sync::mpsc::{self, Sender};
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::thread;
    use std::{io::Read, net::TcpListener, net::TcpStream};

    fn handle_stream(stream: &mut TcpStream, tx: &Sender<Vec<u8>>) -> Result<usize, Error> {
        let mut buf = [0; 1024];
        let n = stream.read(&mut buf[..])?;
        tx.send(buf.to_vec()).unwrap_or_else(|error| {
            panic!("Sending error: {:?}", error.to_string());
        });
        Ok(n)
    }

    fn process_args(cargs: Vec<String>) -> HashMap<String, String> {
        let mut args = HashMap::new();
        for e in cargs {
            let arg: Vec<&str> = e.as_str().split("=").collect();
            match arg[0] {
                "--remote-addr" => {
                    args.insert(String::from("remote-addr"), String::from(arg[1]));
                }
                "--local-addr" => {
                    args.insert(String::from("local-addr"), String::from(arg[1]));
                }
                _ => (),
            }
        }
        args
    }

    fn main() -> std::io::Result<()> {
        let cargs: Vec<String> = env::args().collect();
        let args = process_args(cargs);
        if *&args.is_empty()
            || *&args.get("remote-addr").is_none()
            || *&args.get("local-addr").is_none()
        {
            panic!("expected args: remote and local addr")
        }
        let (tx, rx) = mpsc::channel();
        let i_stream: Arc<Mutex<Vec<Result<TcpStream, Error>>>> = Arc::new(Mutex::new(Vec::new()));
        let mut handles = vec![];
        let i_stream_c1 = Arc::clone(&i_stream);
        let handle = thread::spawn(move || loop {
            let received: Vec<u8> = rx.recv().unwrap();
            let mut streams_to_delete: Vec<usize> = Vec::new();
            let mut incoming_stream = i_stream_c1.lock().unwrap();
            for (idx, s) in incoming_stream.iter().enumerate() {
                let mut x = s.as_ref().unwrap();
                _ = x.write(received.as_slice()).unwrap_or_else(|_| {
                    streams_to_delete.push(idx);
                    0
                });
            }
            for idx in streams_to_delete.iter() {
                incoming_stream.remove(*idx).unwrap();
            }
        });
        handles.push(handle);
        let i_stream_c2 = Arc::clone(&i_stream);
        let args1 = args.clone();
        let handle = thread::spawn(move || {
            let listener = TcpListener::bind(args1.get("local-addr").unwrap()).unwrap();
            println!(
                "ready to accept connections at {:?}",
                args1.get("local-addr").unwrap()
            );
            for stream in listener.incoming() {
                println!(
                    "new connection from {:?}",
                    stream.as_ref().unwrap().peer_addr().unwrap()
                );
                let mut incoming_stream = i_stream_c2.lock().unwrap();
                incoming_stream.push(stream);
            }
        });
        handles.push(handle);
        let mut stream = TcpStream::connect(args.get("remote-addr").unwrap())?;
        println!("connecting to {:?}", args.get("remote-addr").unwrap());
        loop {
            let mut break_loop = false;
            handle_stream(&mut stream, &tx).unwrap_or_else(|_| {
                break_loop = true;
                0
            });
            if break_loop {
                break;
            }
        }
        for handle in handles {
            handle.join().unwrap();
        }
        Ok(())
    }

Error handling has not been given much thought here, hence you will see a lot of `unwrap`s. That was intentional, as this is not production-ready code.
To run: ./tcp_forwarder --remote-addr=0.0.0.0:2100 --local-addr=0.0.0.0:2100
2 Answers
- I would use scoped threads instead of `thread::spawn()`. This allows you to avoid the `Arc` and cloning `args`.
- For `args`, I would use a struct instead of a `HashMap`.
- If you insist on using a map for the args, checking `args.is_empty()` is redundant: if it is empty, `remote-addr` and `local-addr` will be `None`.
- I would move the check of the args to `process_args()`. This is subjective, but I think validating the arguments is part of processing them.
- Instead of using `split("=")` and collecting to a `Vec` (by the way: prefer `split('=')` to `split("=")` as it is more performant, there is even a Clippy lint for that), you can use `split_once()`.
- You don't need to collect `std::env::args()` to a `Vec`, as you immediately iterate over them.
- Instead of using a variable `break_loop` because you cannot break from inside the closure, you can just give up on using `unwrap_or_else()` and use a simple `if handle_stream(&mut stream, &tx).is_err() { break; }`. Or even `while handle_stream(&mut stream, &tx).is_ok() {}` (although I dislike that, as I dislike empty loop bodies).
- This is rather minor, but I would replace:

      let mut incoming_stream = i_stream_c2.lock().unwrap();
      incoming_stream.push(stream);

  With:

      i_stream_c2.lock().unwrap().push(stream);

  This also means that the `MutexGuard` is a temporary, so the mutex is unlocked at the end of the statement rather than at the end of the block. If we add new statements afterwards, we won't hold the lock longer than needed.
- Instead of pushing `Result`s to `i_stream` and `unwrap()`ing them on the first thread, I would `unwrap()` them on the second thread, as soon as the connection is accepted, so it only needs to be done once.
- Instead of using `unwrap_or_else()` after `x.write()`, forcing you to return a dummy result, I would do:

      if x.write(received.as_slice()).is_err() {
          streams_to_delete.push(idx);
      }

- If two streams are scheduled for removal, you will remove the wrong stream, because the indices shift after the first removal. Instead, remove them in reverse order.
- I would use `streams_to_delete.into_iter()` instead of `streams_to_delete.iter()`, also saving a dereference. You can replace this with just `streams_to_delete`, but you cannot if you reverse it as per the previous comment.
- This is a neat one (IMO): instead of storing indices to remove, you can use `retain_mut()`:

      i_stream.lock().unwrap().retain_mut(|x| x.write(received.as_slice()).is_ok());

- `panic!("Sending error: {:?}", error.to_string())` -> `panic!("Sending error: {error}")`. Not exactly the same (the former prints the string in its debug representation), but this is probably what you wanted anyway.
- `&mut buf[..]` -> `&mut buf`, and also `received.as_slice()` -> `&received`, thanks to coercion (unsized coercion for the array, deref coercion for the `Vec`). But this is a matter of style; you can prefer your version.
- You send the whole `buf`, which is filled with zeroes after the message. You need to truncate it: `buf[..n].to_vec()` instead of `buf.to_vec()`.
- Instead of allocating a `Vec` for each packet, you can send the array. This saves an allocation per packet. You also need to send the size with it, otherwise you won't know how long the packet is (see the previous comment).
- Your imports come in two flavors: one `use` per module, and several modules combined in a single `use` (the last line). Choose one flavor and stick with it.
- `write()` is not guaranteed to write all data. Use `write_all()` instead. Did you run Clippy, get the error "written amount is not handled", and add the `_ =` to make it go away? This is a very bad idea. Don't silence warnings, listen to them. (And by the way, do run Clippy. I don't do that as an experienced Rustacean because I have my own preferences, but as a beginner it will help you a lot. It reports some warnings on your original code; try it yourself.)
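
The removal points above can be tried without any sockets. The following is a minimal sketch, assuming a hypothetical `FakeConn` type that stands in for `TcpStream` and two hypothetical helpers, `broadcast` and `remove_by_index`; none of these names come from the question.

```rust
use std::io::{self, Write};

/// Hypothetical stand-in for a `TcpStream`: it records what it receives,
/// or fails every write when `broken` is set.
struct FakeConn {
    broken: bool,
    received: Vec<u8>,
}

impl Write for FakeConn {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.broken {
            return Err(io::Error::new(io::ErrorKind::BrokenPipe, "peer went away"));
        }
        self.received.extend_from_slice(buf);
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

/// Sends `packet` to every connection and drops the ones whose write fails.
/// `write_all` keeps writing until the whole packet is sent or an error occurs.
fn broadcast<W: Write>(conns: &mut Vec<W>, packet: &[u8]) {
    conns.retain_mut(|conn| conn.write_all(packet).is_ok());
}

/// Index-based alternative. Assumes `doomed` is sorted in ascending order;
/// walking it from the back keeps the remaining indices valid.
fn remove_by_index<T>(items: &mut Vec<T>, doomed: Vec<usize>) {
    for idx in doomed.into_iter().rev() {
        items.remove(idx);
    }
}
```

With `retain_mut`, two adjacent failing connections are dropped in a single pass. The index variant is only correct because it removes from the back.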
The code after all changes:
```rust
use std::env;
use std::io::{Error, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::mpsc::{self, Sender};
use std::sync::Mutex;
use std::thread;

const PACKET_SIZE: usize = 1024;

fn handle_stream(
    stream: &mut TcpStream,
    tx: &Sender<([u8; PACKET_SIZE], usize)>,
) -> Result<usize, Error> {
    let mut buf = [0; PACKET_SIZE];
    let n = stream.read(&mut buf)?;
    tx.send((buf, n)).unwrap_or_else(|error| {
        panic!("Sending error: {error}");
    });
    Ok(n)
}

#[derive(Default)]
struct Args {
    remote_addr: String,
    local_addr: String,
}

fn process_args(cargs: impl Iterator<Item = String>) -> Args {
    let mut args = Args::default();
    for e in cargs {
        match e.split_once('=') {
            Some(("--remote-addr", value)) => {
                args.remote_addr = String::from(value);
            }
            Some(("--local-addr", value)) => {
                args.local_addr = String::from(value);
            }
            _ => (),
        }
    }
    if args.remote_addr.is_empty() || args.local_addr.is_empty() {
        panic!("expected args: remote and local addr")
    }
    args
}

fn main() -> std::io::Result<()> {
    let args = process_args(env::args());
    let (tx, rx) = mpsc::channel();
    let i_stream: Mutex<Vec<TcpStream>> = Mutex::new(Vec::new());
    thread::scope(|s| {
        s.spawn(|| {
            let rx = rx; // Force a move for `rx`, otherwise we get an error as it does not implement `Sync`.
            loop {
                let (received, n): ([u8; PACKET_SIZE], usize) = rx.recv().unwrap();
                let received = &received[..n];
                i_stream
                    .lock()
                    .unwrap()
                    .retain_mut(|x| x.write_all(received).is_ok());
            }
        });
        s.spawn(|| {
            let listener = TcpListener::bind(&args.local_addr).unwrap();
            println!("ready to accept connections at {:?}", &args.local_addr);
            for stream in listener.incoming() {
                let stream = stream.unwrap();
                println!("new connection from {:?}", stream.peer_addr().unwrap());
                i_stream.lock().unwrap().push(stream);
            }
        });
        let mut stream = TcpStream::connect(&args.remote_addr)?;
        println!("connecting to {:?}", &args.remote_addr);
        loop {
            if handle_stream(&mut stream, &tx).is_err() {
                break;
            }
        }
        Ok(())
    })
}
```
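
One detail the read loop above does not distinguish is end of stream: `read` returns `Ok(0)` once the peer has closed the connection, which is not an error, so a loop that only breaks on `Err` would keep sending empty packets. The following is a minimal sketch of a reader that stops at end of stream, assuming a hypothetical `pump` function that is generic over `Read` so it can be exercised with an in-memory source instead of a socket.

```rust
use std::io::{self, Read};
use std::sync::mpsc::Sender;

const PACKET_SIZE: usize = 1024;

/// Hypothetical helper: reads from `src` until end of stream and sends each
/// chunk, together with its real length, down the channel.
fn pump<R: Read>(src: &mut R, tx: &Sender<([u8; PACKET_SIZE], usize)>) -> io::Result<()> {
    loop {
        let mut buf = [0u8; PACKET_SIZE];
        let n = src.read(&mut buf)?;
        if n == 0 {
            // `Ok(0)` with a non-empty buffer means the source is exhausted.
            return Ok(());
        }
        if tx.send((buf, n)).is_err() {
            // The receiving side is gone, so there is nobody left to forward to.
            return Ok(());
        }
    }
}
```

Whether a closed receiver should be reported as an error rather than a clean stop is a design choice; the sketch treats both cases as a normal shutdown.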
Use existing libraries
Use e.g. `clap` to parse the command line arguments. I recommend its `derive` feature. This way, you only have to define one struct and the parsing magic is done for you by a proven and tested framework.
While it may be okay to manually yank the args from the cargs API in a small project, you'll soon get into the habit of copy-pasting such code snippets into future projects, resulting in WET code.
Naming
There are only two hard things in Computer Science: cache invalidation and naming things. -- Phil Karlton
Your function name `handle_stream` does not convey what it actually does.
It forwards one stream to another. So why not call it `forward_stream` instead? Also, its parameters are misleading. `tx` (the sender) is where the data goes, while `stream` is the stream being read from.
So a better fit for their names might be `src` and `dst` (or `source` and `destination`, if you like to be more verbose).
`main` should be lean
Your current `main` function is the longest function in your entire program. That is a sign of suboptimal design. `main` should just set up your program and invoke the functions that actually do the business. Consider splitting it up into further functions that are then invoked in `main`.
You will see that by carefully splitting responsibility while adhering to the principles of functional programming, your functions will get more concise and actually (unit-) testable.
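
As a sketch of what a small, testable function with the suggested `src`/`dst` naming could look like, the following assumes two hypothetical helpers, `forward_chunk` and `forward_all`. They are generic over `Read` and `Write` rather than tied to `TcpStream` or to the channel used in the question, so this illustrates the idea and is not a drop-in replacement.

```rust
use std::io::{self, Read, Write};

/// Reads one chunk from `src` and writes all of it to `dst`.
/// Returns the number of bytes forwarded; 0 means `src` is exhausted.
fn forward_chunk<R: Read, W: Write>(src: &mut R, dst: &mut W) -> io::Result<usize> {
    let mut buf = [0u8; 1024];
    let n = src.read(&mut buf)?;
    // Only the first `n` bytes are valid; the rest of `buf` is still zeroed.
    dst.write_all(&buf[..n])?;
    Ok(n)
}

/// Forwards chunks until `src` is exhausted and returns the total byte count.
fn forward_all<R: Read, W: Write>(src: &mut R, dst: &mut W) -> io::Result<u64> {
    let mut total = 0u64;
    loop {
        match forward_chunk(src, dst)? {
            0 => return Ok(total),
            n => total += n as u64,
        }
    }
}
```

Because neither function mentions sockets, a unit test can pass a `std::io::Cursor` as `src` and a `Vec<u8>` as `dst`.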
Rethink your design
Your program currently only forwards between TCP sockets. Since you're basically reinventing netcat here, why not follow its example and allow reading from STDIN and writing to STDOUT as well? This would hugely simplify your code, since you would no longer need threads to forward between two TCP sockets, but could use two separate processes instead:
$ ./tcp_forwarder --listen=0.0.0.0:2100 | ./tcp_forwarder --send=0.0.0.0:2100
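
A minimal sketch of that netcat-style design could look as follows. It assumes the `--listen` and `--send` flags from the command line above, a single accepted connection per process, and hypothetical `pipe`, `listen`, `send` and `run` functions.

```rust
use std::io::{self, Read, Write};
use std::net::{TcpListener, TcpStream};

/// Copies everything from `src` to `dst` until `src` reports end of input.
fn pipe<R: Read, W: Write>(mut src: R, mut dst: W) -> io::Result<u64> {
    let copied = io::copy(&mut src, &mut dst)?;
    dst.flush()?;
    Ok(copied)
}

/// `--listen=ADDR`: accept one connection and copy its data to STDOUT.
fn listen(addr: &str) -> io::Result<u64> {
    let (conn, _peer) = TcpListener::bind(addr)?.accept()?;
    pipe(conn, io::stdout().lock())
}

/// `--send=ADDR`: connect to ADDR and copy STDIN into the connection.
fn send(addr: &str) -> io::Result<u64> {
    pipe(io::stdin().lock(), TcpStream::connect(addr)?)
}

/// Dispatches on a single `--flag=ADDR` argument.
fn run(arg: &str) -> io::Result<u64> {
    match arg.split_once('=') {
        Some(("--listen", addr)) => listen(addr),
        Some(("--send", addr)) => send(addr),
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "expected --listen=ADDR or --send=ADDR",
        )),
    }
}
```

`run` would be called from `main` with the first command-line argument. Note that this sketch handles one connection per process, so it does not fan out to several clients the way the original program does.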
- `forward_stream()` is also not a good name IMO, as it does not actually forward anything, it just sends to the channel. Maybe `send_packet_to_channel()`? Also, in my opinion, a long `main()` is fine in short scripty programs. -- Chayim Refael Friedman, Dec 30, 2022 at 9:37
- Thank you for sharing your opinion. Given my review, you already know mine. -- Richard Neumann, Dec 30, 2022 at 9:39