Traits for beginners
This post is for those who are starting their journey with coding and Rust, it is an introduction to traits, and I will try to explain what are traits and how to use them.
My goal is to try to shed some light on how one can use traits to write code that is simpler to understand and maintain.
We are going to implement network encryption using chacha20poly1305. First, we will implement it without traits. Then, we will refactor it to use traits and see how much better the code will be.
I am assuming you have some familiarity with Rust basics, the book is a great place to start if you don’t.
A full async implemenation of the encryption and encryption traits can be found on my async-encrypted-stream Github repo.
Reading and Sending the data
As a first step, it is necessary to build the basics to send and receive the data. Due to the goal of this post, the code will not be async, because this would lead to unnecessary complexity.
The code will be a simple implementation that connects to a host and encrypts all the information exchanged. I will split the execution flow into “reading” and “writing” threads. The reading part will wait for messages from the network and write them to an internal channel. The writing part will wait for messages from an internal channel and then send them to the network.
The code below contains comments explaining each section.
// Imports all the necessary types from the standard library.
// This is important mostly for the traits
use std::{
error::Error,
io::{Read, Write},
net::TcpStream,
sync::{
mpsc::{Receiver, Sender},
Arc,
},
thread,
};
// Message types to avoid confusion with the channels
pub struct InboundMessage(Vec<u8>);
pub struct OutboundMessage(Vec<u8>);
fn connect_to_host(
addr: &str,
) -> (Sender<OutboundMessage>, Receiver<InboundMessage>) {
// Creates a TcpStream by connecting to a host and
// wraps it an Arc to share between two threads
let stream = Arc::new(TcpStream::connect(addr).expect("Failed to connect"));
let inbound_channel = receive_data(stream.clone());
let outbound_channel = send_data(stream);
(outbound_channel, inbound_channel)
}
fn send_data(stream: Arc<TcpStream>) -> Sender<OutboundMessage> {
// Spawn a thread that will read some bytes from the channel
// and send to the stream
let (tx, rx) = std::sync::mpsc::channel();
thread::spawn(move || {
let stream = &*stream;
while let Ok(OutboundMessage(bytes)) = rx.recv() {
if stream.write_all(&bytes).is_err() {
break;
}
}
});
tx
}
fn receive_data(stream: Arc<TcpStream>) -> Receiver<InboundMessage> {
// Spawn a thread that will read some bytes from the network
// and send to a channel
let (tx, rx) = std::sync::mpsc::channel();
thread::spawn(move || {
let stream = &*stream;
let mut buf = [0u8; 4096];
while let Ok(n) = stream.read(&mut buf) {
if n == 0 {
break;
}
if tx.send(InboundMessage(buf[..n].to_vec())).is_err() {
break;
}
}
});
rx
}
The most important part in the code above is the read
call, it populates the given buffer and returns the number of bytes read from the network, this is ok for now.
Adding encryption
With the basics in place, it is time to add encryption. For that, it is necessary to add a new dependency in the Cargo.toml file.
[dependencies]
chacha20poly1305 = { version = "0.10.1", features = ["stream", "std"] }
With the new dependency added, now it is time to create the encryptor and decryptor structs.
use chacha20poly1305::{
aead::{
stream::{DecryptorLE31, EncryptorLE31},
KeyInit,
},
XChaCha20Poly1305,
};
pub fn get_encryptor_and_decryptor(
key: [u8; 32],
nonce: [u8; 20],
) -> (
EncryptorLE31<XChaCha20Poly1305>,
DecryptorLE31<XChaCha20Poly1305>,
) {
// Creates the encryptor and decryptor pair with provided key and nonce.
let encryptor: EncryptorLE31<XChaCha20Poly1305> =
chacha20poly1305::aead::stream::EncryptorLE31::from_aead(
XChaCha20Poly1305::new(key.as_ref().into()),
nonce.as_ref().into(),
);
let decryptor: DecryptorLE31<XChaCha20Poly1305> =
chacha20poly1305::aead::stream::DecryptorLE31::from_aead(
XChaCha20Poly1305::new(key.as_ref().into()),
nonce.as_ref().into(),
);
(encryptor, decryptor)
}
Now that we know how to get the encryptor and decryptor, it is time to change the send_data
and receive_data
to use them to apply encryption to the data. Luckily, the encryption crate is very easy to use, we just need to call encrypt_next
and decrypt_next
.
fn send_data(
stream: Arc<TcpStream>,
// receives the encryptor as a parameter.
mut encryptor: EncryptorLE31<XChaCha20Poly1305>,
// flag to enable encryption.
encryption_enabled: bool,
) -> Sender<OutboundMessage> {
let (tx, rx) = std::sync::mpsc::channel();
thread::spawn(move || {
let stream = &*stream;
while let Ok(OutboundMessage(bytes)) = rx.recv() {
if encryption_enabled {
// if encryption is enabled, encrypts the whole message
// and send the encrypted result to the tcp stream
let Ok(encrypted) = encryptor.encrypt_next(&bytes[..]) else {
break;
};
if stream.write_all(&encrypted).is_err() {
break;
}
} else if stream.write_all(&bytes).is_err() {
break;
}
}
});
tx
}
fn receive_data(
stream: Arc<TcpStream>,
// receives the decryptor as a parameter.
mut decryptor: DecryptorLE31<XChaCha20Poly1305>,
// flag to enable encryption.
encryption_enabled: bool,
) -> Receiver<InboundMessage> {
let (tx, rx) = std::sync::mpsc::channel();
thread::spawn(move || {
let stream = &*stream;
let mut buf = [0u8; 4096];
while let Ok(n) = stream.read(&mut buf) {
if n == 0 {
break;
}
if tx.send(InboundMessage(buf[..n].to_vec())).is_err() {
break;
}
if encryption_enabled {
// If encryption is enabled, decrypt what came
// from the stream and send to the channel
let Ok(decrypted) = decryptor.decrypt_next(&buf[..n]) else {
break;
};
if tx.send(InboundMessage(decrypted)).is_err() {
break;
}
} else if tx.send(InboundMessage(buf[..n].to_vec())).is_err() {
break;
}
}
});
rx
}
Both functions send_data
and receive_data
were changed to optionally use encryption over the stream.
Amazing, another well done job! Everything works fine! It is time to pack the bag and go home. Except, it doesn’t work! I mean, the code would work fine with a simple “hello” message. The data would be encrypted, sent over the wire, and decrypted on the other side. However, the code has a fundamental flaw.
The issue is that the encrypt_next
and decrypt_next
calls must be equal on both ends. For each call to encrypt_next
, decrypt_next
must be called exactly once, because the nonce value changes every time it is used.
Imagine a message of 100 bytes. After encryption, the message will be longer, the exact size will depend on the encryption used, but let’s say it is 120 bytes. When read
is called on the other side, it will read what is available in the buffer, which may or may not be 120 bytes. The function decrypt_next
must be called with the same data generated by encrypt_next
.
To solve this issue, it is necessary to know exactly how many bytes must be sent to the next decrypt_next
call. This can be done by sending the length of the message before sending the message itself.
fn send_data(
stream: Arc<TcpStream>,
mut encryptor: EncryptorLE31<XChaCha20Poly1305>,
encryption_enabled: bool,
) -> Sender<OutboundMessage> {
let (tx, rx) = std::sync::mpsc::channel();
thread::spawn(move || {
let stream = &*stream;
while let Ok(OutboundMessage(bytes)) = rx.recv() {
if encryption_enabled {
let Ok(encrypted) = encryptor.encrypt_next(&bytes[..]) else {
break;
};
// Get the size of the message and send to the stream.
// This is the only necessary change.
let size = (encrypted.len() as u16).to_be_bytes();
if stream.write_all(&size).is_err() {
break;
}
if stream.write_all(&encrypted).is_err() {
break;
}
} else {
// Get the size of the message and send to the stream.
// This is the only necessary change.
let size = (bytes.len() as u16).to_be_bytes();
if stream.write_all(&size).is_err() {
break;
}
if stream.write_all(&bytes).is_err() {
break;
}
}
}
});
tx
}
Sending is the easiest part. Just get the length of the message, cast it to a bytes array, and write it to the stream before the actual data.
Just a few observations before we proceed:
- Never use
usize
when sending binary data over the network. Different architectures may have differentusize
sizes. It is better to be explicit here. - 16 bits is a good size for most scenarios, but it may be necessary to use a bigger value if you need to deal with very big data.
to_be_bytes()
returns the big endian bytes representation of the number- Calling
write_all
multiple times without buffering will result in poor performance.
Welcome to network programming!
fn receive_data(stream: Arc<TcpStream>,
mut decryptor: DecryptorLE31<XChaCha20Poly1305>,
encryption_enabled: bool,
) -> Receiver<InboundMessage> {
let (tx, rx) = std::sync::mpsc::channel();
thread::spawn(move || {
let stream = &*stream;
let mut buf = vec![0u8; 4096];
// Keep track of how many bytes across multiple reads
let mut position = 0;
while let Ok(n) = stream.read(&mut buf[position..]) {
if n == 0 {
break;
}
// Move the position to account for the new data.
position += n;
// buf may have enough data for multiple messages
// we are going to try to process everything.
loop {
// Since the length is u16, we need at least 2 bytes.
if position < 2 {
break;
}
// Read the first 2 bytes to get the length of the message
let length_bytes = [buf[0], buf[1]];
let end = u16::from_be_bytes(length_bytes) as usize + 2;
// Verify if the whole message is in the buffer.
if position < end {
break;
}
let payload = &buf[2..end];
if encryption_enabled {
let Ok(decrypted) = decryptor.decrypt_next(payload) else {
break;
};
if tx.send(InboundMessage(decrypted)).is_err() {
break;
}
} else if tx.send(InboundMessage(payload.to_vec())).is_err() {
break;
}
// Shift the contents of the buffer to the start
if position > end {
buf.copy_within(end..position, 0);
position -= end;
} else {
position = 0;
}
}
}
});
rx
}
With the changes in place, the decryption should work, but the code is becoming complex and error-prone, and it has huge limitations.
Two bugs are present in the current implementation.
- The code fails if a message is longer than 4094 bytes. The solution for this is to use a dynamic-sized buffer.
- If decryption fails,
receive_data
will loop without resetting the buffer position, leading to an infinite loop.
On top of the bugs, the current approach mixes encrypted and plain communication in the same place, and also, extending this code is becoming hard due to the complexity.
I also mentioned before that calling write_all
repeatedly was a bad idea, let’s address that issue by using BufWriter.
BufWriter
is quite easy to use, we just need to pass the stream to it.
fn send_data(...) -> Sender<OutboundMessage> {
// ommited for brevity.
let stream = &*stream;
// We just need to add this line to the code
let mut stream = BufWriter::new(stream);
while let Ok(OutboundMessage(bytes)) = rx.recv() {
// ommited for brevity.
}
}
Neat. But how can the change be that simple? Is it not even necessary to change the rest of the code?
The reason this works is because both TcpStream
and BufWriter
types implement a trait called Write, in fact, the function write_all
is defined in the Write
trait, not in the TcpStream
struct.
Now, it is time to talk about traits!
Traits
What are traits anyway?
The rust book defines traits as a way to define shared behavior in an abstract way. A different way of defining a trait is a way to isolate an aspect of something, because when you work with a trait, you only care about a specific aspect of the whole.
I will try to explain this with an analogy.
Back to Rust, the Write trait is the “socket” from our analogy. When a type implements the Write trait, it gains the ability to be treated by its Write aspect, which creates the possibility of ignoring every other aspect of the struct. When the code sees a Write
type, it doesn’t care where or how the data is written, it only cares that the type can write the data somewhere, be it a file, a network connection, or a memory buffer.
Using Traits in the code
Now that we have an idea of what traits are, it is time to rewrite the code with traits to make it better.
Let’s start with the plain text communication by implementing two structs, PlainReadHalf and PlainWriteHalf. The two structs will implement the Read and Write traits respectively.
/// PlainReadHalf has an inner reader to read from
/// This reads as "inner can be of any type T as long as T implements Read"
pub struct PlainReadHalf<T> where T: Read {
inner: T,
}
impl<T> PlainReadHalf<T> where T: Read {
pub fn new(inner: T) -> Self {
Self { inner }
}
}
/// Implements Read for the struct PlainReadHalf<T>
impl<T> Read for PlainReadHalf<T> where T: Read {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// Reads exactly 2 bytes from the inner reader
let mut len_buf = [0u8; 2];
self.inner.read_exact(&mut len_buf)?;
let len = u16::from_be_bytes(len_buf) as usize;
if buf.len() < len {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Provided buffer is not big enough",
));
}
// Reads exactly the amount of bytes necessary for the message
self.inner.read_exact(&mut buf[..len])?;
Ok(len)
}
}
In the code above, there are two important details about Traits.
The PlainReadHalf
struct has a generic type parameter T
that is constrained to anything that implements Read, which means T can be a BufReader
, a TcpStream
, or anything that implements the Read
trait.
Then comes the implementation of the Read
trait for the struct itself. With this, PlainReadHalf
can be used anywhere a Read
type is expected.
The where T: Read
syntax is where the code isolates the Read aspect of T
. By writing this, T
can be anything, but we only care about the Read
part.
Now let’s look the PlainWriteHalf
struct
/// PlainWriteHalf has an inner writer to write to,
/// It also has a buffer to hold the information before sending to the writer
pub struct PlainWriteHalf<T> where T: Write {
inner: T,
buffer: Vec<u8>,
}
impl<T> PlainWriteHalf<T> where T: Write {
pub fn new(inner: T) -> Self {
Self {
inner,
buffer: Vec::new(),
}
}
}
// Implements Write for the PlainWriteHalf struct
impl<T> Write for PlainWriteHalf<T> where T: Write {
// Write function just write the contents to the internal buffer
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.buffer.extend_from_slice(buf);
Ok(buf.len())
}
// Flush is the function that actually send
// the information to the inner writer
fn flush(&mut self) -> std::io::Result<()> {
let bytes = (self.buffer.len() as u16).to_be_bytes().to_vec();
self.inner.write_all(&bytes)?;
self.inner.write_all(&self.buffer)?;
self.inner.flush()?;
self.buffer.clear();
Ok(())
}
}
One thing you may be wondering is, what is the benefit of this? That seems to be a lot more code to do the same thing as before. Why go through all this trouble?
Let’s check the first benefit of traits by writing a test for this struct.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
use std::io::Write;
#[test]
pub fn test_plain_write_read() {
// Cursor<&mut Vec<T>> implements Write,
// it can be used with our PlainWriteHalf struct
let mut buf = vec![0u8; 512];
let mut writer = PlainWriteHalf::new(Cursor::new(&mut buf));
let _ = writer.write_all(&b"some bytes"[..]);
let _ = writer.flush();
assert_eq!(buf[2..12], b"some bytes"[..]);
// Cursor<& Vec<T>> implements Read,
// it can be used with our PlainReadHalf struct
let mut reader = PlainReadHalf::new(Cursor::new(&buf));
let mut output = String::new();
let _ = reader.read_to_string(&mut output);
assert_eq!(output, "some bytes");
}
}
Notice that we don’t need a TCP connection to test the code, we did manage to write the test for our struct implementation by using another type that implements the necessary traits.
The encrypted implementation will be very similar to the plain text implementation, except it will have an extra step.
/// EncryptedReadHalf has the inner writer and a decryptor
pub struct EncryptedReadHalf<T>
where
T: Read,
{
inner: T,
decryptor: DecryptorLE31<XChaCha20Poly1305>,
}
impl<T> EncryptedReadHalf<T>
where
T: Read,
{
pub fn new(inner: T, decryptor: DecryptorLE31<XChaCha20Poly1305>) -> Self {
Self { inner, decryptor }
}
}
impl<T> Read for EncryptedReadHalf<T>
where
T: Read,
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// First read the length of the payload
let mut len_buf = [0u8; 2];
self.inner.read_exact(&mut len_buf)?;
let payload_len = u16::from_be_bytes(len_buf) as usize;
// Reads the payload
let mut payload = vec![0u8; payload_len];
self.inner.read_exact(&mut payload)?;
// Decrypts the payload
let decrypted = self
.decryptor
.decrypt_next(&payload[..])
.map_err(|err| {
std::io::Error::new(std::io::ErrorKind::InvalidData, err)
})?;
let decrypted_len = decrypted.len();
if buf.len() < decrypted_len {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"Provided buffer is not big enough",
));
}
// Writes the plain content to the buffer
buf[..decrypted_len].copy_from_slice(&decrypted[..]);
Ok(decrypted_len)
}
}
The WriteHalf implementation is also very similar to the PlainWriteHalf, with just an extra step.
pub struct EncryptedWriteHalf<T>
where
T: Write,
{
inner: T,
encryptor: EncryptorLE31<XChaCha20Poly1305>,
buffer: Vec<u8>,
}
impl<T> EncryptedWriteHalf<T>
where
T: Write,
{
pub fn new(inner: T, encryptor: EncryptorLE31<XChaCha20Poly1305>) -> Self {
Self {
inner,
encryptor,
buffer: Vec::new(),
}
}
}
impl<T> Write for EncryptedWriteHalf<T>
where
T: Write,
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.buffer.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
// Encrypts the contents of the internal buffer
let encrypted = self
.encryptor
.encrypt_next(&self.buffer[..])
.map_err(|err| {
std::io::Error::new(std::io::ErrorKind::InvalidInput, err)
})?;
// Writes the encrypted value to the internal writer
let len_bytes = (encrypted.len() as u16).to_be_bytes().to_vec();
self.inner.write_all(&len_bytes)?;
self.inner.write_all(&encrypted)?;
self.inner.flush()?;
// Clean the internal buffer
self.buffer.clear();
Ok(())
}
}
And again, let’s write a test case.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
use std::io::Write;
#[test]
pub fn test_encryption() {
let mut buf = vec![0u8; 512];
let (encryptor, decryptor) =
super::super::get_encryptor_and_decryptor([0u8; 32], [0u8; 20]);
let mut writer = EncryptedWriteHalf::new(Cursor::new(&mut buf), encryptor);
let _ = writer.write_all(&b"some bytes"[..]);
let _ = writer.flush();
let mut reader = EncryptedReadHalf::new(Cursor::new(&buf), decryptor);
let mut output = String::new();
let _ = reader.read_to_string(&mut output);
assert_eq!(output, "some bytes");
}
}
The only thing left is to adapt the rest of the code to use the new structs. With the new structs, the thread
handling will change because of the rust borrower checker.
/// send_data receives a writer, it can be anything that implements Write
/// It will work with PlainWriteHalf or EncryptedWriteHalf
fn send_data(mut stream: impl Write, rx: Receiver<OutboundMessage>) {
while let Ok(OutboundMessage(bytes)) = rx.recv() {
if stream.write_all(&bytes).is_err() {
break;
}
if stream.flush().is_err() {
break;
}
}
}
/// receive_data receives a reader, it can be anything that implements Read
/// It will work with PlainReadHalf or EncryptedReadHalf
fn receive_data(mut stream: impl Read, tx: Sender<InboundMessage>) {
let mut buf = vec![0u8; 4096];
while let Ok(n) = stream.read(&mut buf[..]) {
if n == 0 {
break;
}
if tx.send(InboundMessage(buf[..n].to_vec())).is_err() {
break;
}
}
}
pub fn connect_to_host(
addr: &str,
key: [u8; 32],
nonce: [u8; 20],
encryption_enabled: bool,
) -> (Sender<OutboundMessage>, Receiver<InboundMessage>) {
// Connects to the host
// &mut &TcpStream implements both Read and Write traits
// So we create two copies of it to send to each thread
let stream = Arc::new(TcpStream::connect(addr).expect("Failed to connect"));
// Each copy will be moved to a different thread
let read_half = stream.clone();
let write_half = stream.clone();
// Initialize the channels to be returned from this function
let (out_tx, out_rx) = std::sync::mpsc::channel();
let (in_tx, in_rx) = std::sync::mpsc::channel();
if encryption_enabled {
let (encryptor, decryptor) = get_encryptor_and_decryptor(key, nonce);
// Moves read_half and decryptor to a new thread
thread::spawn(move || {
let read_half = &*read_half;
receive_data(
encrypted::EncryptedReadHalf::new(read_half, decryptor),
in_tx,
);
});
// Moves write_half and encryptor to a new thread
thread::spawn(move || {
let write_half = &*write_half;
send_data(
encrypted::EncryptedWriteHalf::new(write_half, encryptor),
out_rx,
);
});
} else {
// Moves read_half to a new thread
thread::spawn(move || {
let read_half = &*read_half;
receive_data(plain::PlainReadHalf::new(read_half), in_tx);
});
// Moves write_half to a new thread
thread::spawn(move || {
let write_half = &*write_half;
send_data(plain::PlainWriteHalf::new(write_half), out_rx);
});
}
(out_tx, in_rx)
}
In this new implementation, the code can be easily extended to add new features.
For instance, if you want to data compression, a new struct can be used to implement the Read and Write traits.
thread::spawn(move || {
let read_half = &*read_half;
receive_data(
compression::CompressedReadHalf::new(
encrypted::EncryptedReadHalf::new(read_half, decryptor)
),
in_tx,
);
});
thread::spawn(move || {
let write_half = &*write_half;
send_data(
compression::CompressedWriteHalf::new(
encrypted::EncryptedWriteHalf::new(write_half, encryptor)
),
out_rx,
);
});
Dynamic use of traits.
There is one thing in the connect_to_host
function above that I don’t like, there is a lot of repetition. It is possible to slim it down by spawning the threads only once.
let (reader, writer) = if encryption_enabled {
let (encryptor, decryptor) = get_encryptor_and_decryptor(key, nonce);
let reader = encrypted::EncryptedReadHalf::new(read_half, decryptor);
ler writer = encrypted::EncryptedWriteHalf::new(write_half, encryptor);
(reader, writer)
} else {
let reader = plain::PlainReadHalf::new(read_half);
ler writer = plain::PlainWriteHalf::new(write_half);
(reader, writer)
}
thread::spawn(move || {
receive_data(reader, in_tx);
});
thread::spawn(move || {
send_data(writer, out_rx);
});
The code above is a lot cleaner, however, it does not compile because both paths of the if statement must return the same type, which is not happening now. The same problem happens when calling receive_data
and send_data
functions, the type must be known before calling. To understand how that can be fixed, we need to first talk about generics and dynamic dispatch.
When you write rust code with a generic parameter <T>
, the compiler will use monomorphization to generate code with concrete types, for instance:
fn my_func<T>() {}
my_func<u32>(); // This will generate my_func_u32 when compiled
my_func<u16>(); // This will generate my_func_u16 when compiled
struct MyStruct<T> {
inner: T
}
var s = MyStruct { inner: 0u32 }; // MyStruct_u32
my_func<MyStruct<u32>>(); // This will generate my_func_my_struct_u32
The only restriction with generics is that you need to know the concrete types at compile time, which is not always the case. When the type cannot be defined, such as when it depends on a user-provided parameter, we can use dynamic dispatch. The way to do that in Rust is to use the dyn Trait
syntax. For instance.
fn send_data<T: Write>(writer: T) // Concrete generic type
fn send_data(writer: impl Write) // This is equivalent to the one above
fn send_data(writer: &dyn Write) // Dynamic dispatch with a reference
fn send_data(writer: Box<dyn Write>) // Dynamic dispatch with owned type
// Returning a type that implements a trait
// Concrete generic type, no branching is allowed inside the function
fn get_writer<T: Write>() -> T;
fn get_writer() -> impl Write;
// Dynamic dispatch, it is allowed to branch and return different types
fn get_writer() -> Box<dyn Write>
Let’s see how it is possible to change the connect_to_host
function to use dynamic dispatch.
pub fn connect_to_host(
addr: &str,
key: [u8; 32],
nonce: [u8; 20],
encryption_enabled: bool,
out_rx: Receiver<OutboundMessage>,
in_tx: Sender<InboundMessage>,
) {
let stream = Arc::new(TcpStream::connect(addr).expect("Failed to connect"));
// explicitly declare reader: Box<dyn Read + Send> and writer: Box<dyn Write + Send>
let (reader, writer): (Box<dyn Read + Send>, Box<dyn Write + Send>) = if encryption_enabled {
let (encryptor, decryptor) = get_encryptor_and_decryptor(key, nonce);
let reader = encrypted::EncryptedReadHalf::new(&*stream, decryptor);
let writer = encrypted::EncryptedWriteHalf::new(&*stream, encryptor);
// Send reader and writer to the heap, by boxing them
// this will allow the dynamic dispatch
(Box::new(reader), Box::new(writer))
} else {
let reader = plain::PlainReadHalf::new(&*stream);
let writer = plain::PlainWriteHalf::new(&*stream);
// Send reader and writer to the heap, by boxing them
// this will allow the dynamic dispatch
(Box::new(reader), Box::new(writer))
};
thread::scope(|s| {
s.spawn(move || {
send_data(writer, out_rx);
});
s.spawn(move || {
receive_data(reader, in_tx);
});
})
}
The syntax Trait + AnotherTrait
specifies how multiple trait constraints are defined in Rust. In our case, Read + Send
is must implement Read and Send. Send is how Rust knows that a type can be sent to another thread.
In this particular case, it is not necessary to change send_data
and receive_data
signatures because Box implements both Write and Read, but this is how it would look like.
fn send_data<'a>(mut stream: Box<dyn Write + 'a>, rx: Receiver<OutboundMessage>)
fn receive_data<'a>(mut stream: Box<dyn Read + 'a>, tx: Sender<InboundMessage>)
One last observation before we wrap up. Since we are talking about improving the code with traits, it wouldn’t be fair if I don’t show how to make the encryption algorithm generic as well.
use std::io::{Read, Write};
use std::ops::Sub;
use chacha20poly1305::{
aead::{
generic_array::ArrayLength,
stream::{Decryptor, Encryptor, NonceSize, StreamPrimitive},
},
AeadInPlace,
};
// Introduces the U parameter for the decryptor.
// Notice we don't need to add the constraints here.
pub struct EncryptedReadHalf<T, U> {
inner: T,
decryptor: U,
}
// Introduces all the constraints necessary for the Decryptor to work
// This allow us to use the struct with any stream decryption algorithm
impl<T, A, S> EncryptedReadHalf<T, Decryptor<A, S>>
where
T: Read,
S: StreamPrimitive<A>,
A: AeadInPlace,
A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
NonceSize<A, S>: ArrayLength<u8>,
{
pub fn new(inner: T, decryptor: Decryptor<A, S>) -> Self {
Self { inner, decryptor }
}
}
impl<T, A, S> Read for EncryptedReadHalf<T, Decryptor<A, S>>
where
T: Read,
S: StreamPrimitive<A>,
A: AeadInPlace,
A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
NonceSize<A, S>: ArrayLength<u8>,
{
// ...
}
// Introduces the U parameter for the encryptor.
// Notice we don't need to add the constraints here.
pub struct EncryptedWriteHalf<T, U> {
inner: T,
encryptor: U,
}
// Introduces all the constraints necessary for the Encryptro to work
// This allow us to use the struct with any stream encryption algorithm
impl<T, A, S> EncryptedWriteHalf<T, Encryptor<A, S>>
where
T: Write,
S: StreamPrimitive<A>,
A: AeadInPlace,
A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
NonceSize<A, S>: ArrayLength<u8>,
{
pub fn new(inner: T, encryptor: Encryptor<A, S>) -> Self {
Self {
inner,
encryptor,
}
}
}
impl<T, A, S> Write for EncryptedWriteHalf<T, Encryptor<A, S>>
where
T: Write,
S: StreamPrimitive<A>,
A: AeadInPlace,
A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
NonceSize<A, S>: ArrayLength<u8>,
{
// ...
}
Conclusion
In this post, we have learned how to use traits to isolate code behavior and write cleaner code.
With the use of Traits, we can:
- Isolate and decouple code behavior.
- Test specific parts of the code.
- Reuse and share specific behaviors.
- Compose behavior by using multiple traits.
If you were struggling to understand what are Traits, or what are the benefits of using them, I hope this post have helped you.