diff --git a/virtio-vsock/Cargo.toml b/virtio-vsock/Cargo.toml index 16719608..6a0296c3 100644 --- a/virtio-vsock/Cargo.toml +++ b/virtio-vsock/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "virtio-vsock" -version = "0.11.0" +version = "0.12.0" authors = ["rust-vmm community", "rust-vmm AWS maintainers "] description = "virtio vsock device implementation" repository = "https://github.com/rust-vmm/vm-virtio" diff --git a/virtio-vsock/README.md b/virtio-vsock/README.md index 0819108f..4efe75e0 100644 --- a/virtio-vsock/README.md +++ b/virtio-vsock/README.md @@ -24,31 +24,27 @@ operations are of the `VIRTIO_VSOCK_OP_RW` type, which means for data transfer, and the other ones are used for connection and buffer space management. `data` is non-empty only for the `VIRTIO_VSOCK_OP_RW` operations. -The abstraction used for the packet implementation is the `VsockPacket`. -It is using -[`VolatileSlice`](https://github.com/rust-vmm/vm-memory/blob/fc7153a4f63c352d1fa9419c4654a6c9aec408cb/src/volatile_memory.rs#L266)s -for representing the header and the data. We chose to use the `VolatileSlice` -because it's a safe wrapper over the unsafe Rust's raw pointers, and it is also -generic enough to allow creating packets from pointers to slices. Going with a -`GuestMemory` based approach would not make such configuration possible. -More details (including design -limitations) in [the `packet`'s module-level documentation](src/packet.rs). - -A `VsockPacket` instance is created by parsing a descriptor chain from either -the TX or the RX virtqueue. The `VsockPacket` API is also providing methods for -creating/setting up packets directly from pointers to slices. -It also offers setters and getters for each `virtio_vsock_hdr` field (e.g. -*src_cid*, *dst_port*, *op*). +The abstractions used for the packet implementation are `VsockPacketTx` and +`VsockPacketRx`. `VsockPacketTx` uses a `Reader` from `virtio_queue` to access +the device-readable packet data and stores a copy of the `PacketHeader`. +`VsockPacketRx` uses `Writer`s from `virtio_queue` for the header and data +portions of the device-writable buffers. More details in +[the `packet`'s module-level documentation](src/packet.rs). + +A `VsockPacketTx` or `VsockPacketRx` instance is created by parsing a +descriptor chain from the TX or the RX virtqueue respectively. The +`PacketHeader` struct offers setters and getters for each `virtio_vsock_hdr` +field (e.g. *src_cid*, *dst_port*, *op*). ### Usage The driver queues receive buffers on the RX virtqueue, and outgoing packets on -the TX virtqueue. The device processes the RX virtqueue using -`VsockPacket::from_rx_virtq_chain` and fills the buffers with data from the +the TX virtqueue. The device processes the RX virtqueue using +`VsockPacketRx::from_rx_virtq_chain` and fills the buffers with data from the vsock backend. On the TX side, the device processes the TX queue using -`VsockPacket::from_tx_virtq_chain`, packages the read buffers as vsock packets, -and then sends them to the backend. +`VsockPacketTx::from_tx_virtq_chain`, packages the read buffers as vsock +packets, and then sends them to the backend. ### Examples diff --git a/virtio-vsock/src/packet.rs b/virtio-vsock/src/packet.rs index 16dab0dc..bb99a8cc 100644 --- a/virtio-vsock/src/packet.rs +++ b/virtio-vsock/src/packet.rs @@ -4,67 +4,42 @@ //! Vsock packet abstraction. //! -//! This module provides the following abstraction for parsing a vsock packet, and working with it: +//! This module provides the following abstractions for parsing a vsock packet and working with it: +//! +//! - [`VsockPacketTx`](struct.VsockPacketTx.html) which handles parsing a vsock packet from a TX +//! descriptor chain via +//! [`VsockPacketTx::from_tx_virtq_chain`](struct.VsockPacketTx.html#method.from_tx_virtq_chain). +//! It uses a [`Reader`](virtio_queue::Reader) to access the device-readable packet data, and +//! stores a copy of the [`PacketHeader`](struct.PacketHeader.html). +//! - [`VsockPacketRx`](struct.VsockPacketRx.html) which handles parsing a vsock packet from an RX +//! descriptor chain via +//! [`VsockPacketRx::from_rx_virtq_chain`](struct.VsockPacketRx.html#method.from_rx_virtq_chain). +//! It uses [`Writer`](virtio_queue::Writer)s for the +//! header and data portions of the device-writable buffers. //! -//! - [`VsockPacket`](struct.VsockPacket.html) which handles the parsing of the vsock packet from -//! either a TX descriptor chain via -//! [`VsockPacket::from_tx_virtq_chain`](struct.VsockPacket.html#method.from_tx_virtq_chain), or an -//! RX descriptor chain via -//! [`VsockPacket::from_rx_virtq_chain`](struct.VsockPacket.html#method.from_rx_virtq_chain). //! The virtio vsock packet is defined in the standard as having a header of type `virtio_vsock_hdr` -//! and an optional `data` array of bytes. The methods mentioned above assume that both packet -//! elements are on the same descriptor, or each of the packet elements occupies exactly one -//! descriptor. For the usual drivers, this assumption stands, -//! but in the future we might make the implementation more generic by removing any constraint -//! regarding the number of descriptors that correspond to the header/data. The buffers associated -//! to the TX virtio queue are device-readable, and the ones associated to the RX virtio queue are -//! device-writable. -/// -/// The `VsockPacket` abstraction is using vm-memory's `VolatileSlice` for representing the header -/// and the data. `VolatileSlice` is a safe wrapper over a raw pointer, which also handles the dirty -/// page tracking behind the scenes. A limitation of the current implementation is that it does not -/// cover the scenario where the header or data buffer doesn't fit in a single `VolatileSlice` -/// because the guest memory regions of the buffer are contiguous in the guest physical address -/// space, but not in the host virtual one as well. If this becomes an use case, we can extend this -/// solution to use an array of `VolatileSlice`s for the header and data. -/// The `VsockPacket` abstraction is also storing a `virtio_vsock_hdr` instance (which is defined -/// here as `PacketHeader`). This is needed so that we always access the same data that was read the -/// first time from the descriptor chain. We avoid this way potential time-of-check time-of-use -/// problems that may occur when reading later a header field from the underlying memory itself -/// (i.e. from the header's `VolatileSlice` object). +//! and an optional `data` array of bytes. The descriptor chain layout is handled transparently by +//! the `Reader`/`Writer` abstractions from `virtio_queue`. The buffers associated to the TX virtio +//! queue are device-readable, and the ones associated to the RX virtio queue are device-writable. + use std::fmt::{self, Display}; use std::ops::Deref; -use virtio_queue::DescriptorChain; +use virtio_queue::{DescriptorChain, Reader, Writer}; use vm_memory::bitmap::{BitmapSlice, WithBitmapSlice}; -use vm_memory::{ - Address, ByteValued, Bytes, GuestAddress, GuestMemory, GuestMemoryError, Le16, Le32, Le64, - Permissions, VolatileMemoryError, VolatileSlice, -}; +use vm_memory::{ByteValued, GuestMemory, Le16, Le32, Le64}; /// Vsock packet parsing errors. #[derive(Debug)] pub enum Error { - /// Too few descriptors in a descriptor chain. - DescriptorChainTooShort, /// Descriptor that was too short to use. DescriptorLengthTooSmall, /// Descriptor that was too long to use. DescriptorLengthTooLong, - /// Data stretches over multiple memory fragments - FragmentedMemory, - /// The slice for creating a header has an invalid length. - InvalidHeaderInputSize(usize), + /// Invalid descriptor chain (e.g. missing descriptors, out-of-bounds memory, or overflow). + InvalidChain, /// The `len` header field value exceeds the maximum allowed data size. InvalidHeaderLen(u32), - /// Invalid guest memory access. - InvalidMemoryAccess(GuestMemoryError), - /// Invalid volatile memory access. - InvalidVolatileAccess(VolatileMemoryError), - /// Read only descriptor that protocol says to write to. - UnexpectedReadOnlyDescriptor, - /// Write only descriptor that protocol says to read from. - UnexpectedWriteOnlyDescriptor, } impl std::error::Error for Error {} @@ -72,9 +47,6 @@ impl std::error::Error for Error {} impl Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Error::DescriptorChainTooShort => { - write!(f, "There are not enough descriptors in the chain.") - } Error::DescriptorLengthTooSmall => write!( f, "The descriptor is pointing to a buffer that has a smaller length than expected." @@ -83,27 +55,12 @@ impl Display for Error { f, "The descriptor is pointing to a buffer that has a longer length than expected." ), - Error::FragmentedMemory => { - write!(f, "Data stretches over multiple memory fragments.") - } - Error::InvalidHeaderInputSize(size) => { - write!(f, "Invalid header input size: {size}") + Error::InvalidChain => { + write!(f, "Invalid descriptor chain.") } Error::InvalidHeaderLen(size) => { write!(f, "Invalid header `len` field value: {size}") } - Error::InvalidMemoryAccess(error) => { - write!(f, "Invalid guest memory access: {error}") - } - Error::InvalidVolatileAccess(error) => { - write!(f, "Invalid volatile memory access: {error}") - } - Error::UnexpectedReadOnlyDescriptor => { - write!(f, "Unexpected read-only descriptor.") - } - Error::UnexpectedWriteOnlyDescriptor => { - write!(f, "Unexpected write-only descriptor.") - } } } } @@ -128,421 +85,253 @@ pub struct PacketHeader { // and all accesses through safe `vm-memory` API will validate any garbage that could // be included in there. unsafe impl ByteValued for PacketHeader {} -// -// This structure will occupy the buffer pointed to by the head of the descriptor chain. Below are -// the offsets for each field, as well as the packed structure size. -// Note that these offsets are only used privately by the `VsockPacket` struct, the public interface -// consisting of getter and setter methods, for each struct field, that will also handle the correct -// endianness. - -/// The size of the header structure (when packed). -pub const PKT_HEADER_SIZE: usize = std::mem::size_of::(); -// Offsets of the header fields. -const SRC_CID_OFFSET: usize = 0; -const DST_CID_OFFSET: usize = 8; -const SRC_PORT_OFFSET: usize = 16; -const DST_PORT_OFFSET: usize = 20; -const LEN_OFFSET: usize = 24; -const TYPE_OFFSET: usize = 28; -const OP_OFFSET: usize = 30; -const FLAGS_OFFSET: usize = 32; -const BUF_ALLOC_OFFSET: usize = 36; -const FWD_CNT_OFFSET: usize = 40; - -/// Dedicated [`Result`](https://doc.rust-lang.org/std/result/) type. -pub type Result = std::result::Result; - -/// The vsock packet, implemented as a wrapper over a virtio descriptor chain: -/// - the chain head, holding the packet header; -/// - an optional data/buffer descriptor, only present for data packets (for VSOCK_OP_RW requests). -#[derive(Debug)] -pub struct VsockPacket<'a, B: BitmapSlice> { - // When writing to the header slice, we are using the `write` method of `VolatileSlice`s Bytes - // implementation. Because that can only return an error if we pass an invalid offset, we can - // safely use `unwraps` in the setters below. If we switch to a type different than - // `VolatileSlice`, this assumption can no longer hold. We also must always make sure the - // `VsockPacket` API is creating headers with PKT_HEADER_SIZE size. - header_slice: VolatileSlice<'a, B>, - header: PacketHeader, - data_slice: Option>, -} +impl PacketHeader { + /// Set the `src_cid` field. + pub fn set_src_cid(&mut self, src_cid: u64) -> &mut Self { + self.src_cid = src_cid.into(); + self + } -// This macro is intended to be used for setting a header field in both the `VolatileSlice` and the -// `PacketHeader` structure from a packet. `$offset` should be a valid offset in the `header_slice`, -// otherwise the macro will panic. -macro_rules! set_header_field { - ($packet:ident, $field:ident, $offset:ident, $value:ident) => { - $packet.header.$field = $value.into(); - $packet - .header_slice - .write(&$value.to_le_bytes(), $offset) - // This unwrap is safe only if `$offset` is a valid offset in the `header_slice`. - .unwrap(); - }; -} + /// Set the `dst_cid` field. + pub fn set_dst_cid(&mut self, dst_cid: u64) -> &mut Self { + self.dst_cid = dst_cid.into(); + self + } -/// Get a single slice for `[addr, addr + count)`. -/// -/// This is a replacement for the deprecated `GuestMemory::get_slice()` function: It calls -/// `mem.get_slices()` and will return the first slice from the iterator, if any. If that slice -/// does not cover the request length (i.e. the requested region would translate into multiple -/// slices), return `Err(Error::FragmentedMemory)`. -/// -/// If `count == 0`, this function will always return `Ok(None)`. Otherwise, it will always return -/// an error or `Ok(Some(slice))`. -fn get_single_slice<'a, M: GuestMemory, B: BitmapSlice>( - mem: &'a M, - addr: GuestAddress, - count: usize, - access: Permissions, -) -> Result>> -where - M::Bitmap: WithBitmapSlice<'a, S = B>, -{ - if count == 0 { - return Ok(None); + /// Set the `src_port` field. + pub fn set_src_port(&mut self, src_port: u32) -> &mut Self { + self.src_port = src_port.into(); + self } - let slice = mem - .get_slices(addr, count, access) - .map_err(Error::InvalidMemoryAccess)? - .next() - .expect("Expecting some result for a non-empty memory region") - .map_err(Error::InvalidMemoryAccess)?; - - if slice.len() == count { - Ok(Some(slice)) - } else { - Err(Error::FragmentedMemory) + /// Set the `dst_port` field. + pub fn set_dst_port(&mut self, dst_port: u32) -> &mut Self { + self.dst_port = dst_port.into(); + self } -} -impl<'a, B: BitmapSlice> VsockPacket<'a, B> { - /// Return a reference to the `header_slice` of the packet. - pub fn header_slice(&self) -> &VolatileSlice<'a, B> { - &self.header_slice + /// Set the `len` field. + pub fn set_len(&mut self, len: u32) -> &mut Self { + self.len = len.into(); + self } - /// Return a reference to the `data_slice` of the packet. - pub fn data_slice(&self) -> Option<&VolatileSlice<'a, B>> { - self.data_slice.as_ref() + /// Set the `type_` field. + pub fn set_type(&mut self, type_: u16) -> &mut Self { + self.type_ = type_.into(); + self } - /// Write to the packet header from an input of raw bytes. - /// - /// # Example - /// - /// ```rust - /// # use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE; - /// # use virtio_queue::mock::MockSplitQueue; - /// # use virtio_queue::{desc::{split::Descriptor as SplitDescriptor, RawDescriptor}, Queue, QueueT}; - /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; - /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap}; - /// - /// const MAX_PKT_BUF_SIZE: u32 = 64 * 1024; - /// - /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue { - /// # let vq = MockSplitQueue::new(m, 16); - /// # let mut q = vq.create_queue().unwrap(); - /// # - /// # let v = vec![ - /// # RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, VRING_DESC_F_WRITE as u16, 0)), - /// # RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, VRING_DESC_F_WRITE as u16, 0)), - /// # ]; - /// # let mut chain = vq.build_desc_chain(&v); - /// # q - /// # } - /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); - /// // Create a queue and populate it with a descriptor chain. - /// let mut queue = create_queue_with_chain(&mem); - /// - /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) { - /// let mut pkt = VsockPacket::from_rx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE).unwrap(); - /// pkt.set_header_from_raw(&[0u8; PKT_HEADER_SIZE]).unwrap(); - /// } - /// ``` - pub fn set_header_from_raw(&mut self, bytes: &[u8]) -> Result<()> { - if bytes.len() != PKT_HEADER_SIZE { - return Err(Error::InvalidHeaderInputSize(bytes.len())); - } - self.header_slice - .write(bytes, 0) - .map_err(Error::InvalidVolatileAccess)?; - let header = self - .header_slice() - .read_obj::(0) - .map_err(Error::InvalidVolatileAccess)?; - self.header = header; - Ok(()) + /// Set the `op` field. + pub fn set_op(&mut self, op: u16) -> &mut Self { + self.op = op.into(); + self } - /// Return the `src_cid` of the header. - pub fn src_cid(&self) -> u64 { - self.header.src_cid.into() + /// Set the `flags` field. + pub fn set_flags(&mut self, flags: u32) -> &mut Self { + self.flags = flags.into(); + self } - /// Set the `src_cid` of the header. - pub fn set_src_cid(&mut self, cid: u64) -> &mut Self { - set_header_field!(self, src_cid, SRC_CID_OFFSET, cid); + /// Set a single flag (bitwise OR with existing flags). + pub fn set_flag(&mut self, flag: u32) -> &mut Self { + self.flags = (u32::from(self.flags) | flag).into(); self } - /// Return the `dst_cid` of the header. - pub fn dst_cid(&self) -> u64 { - self.header.dst_cid.into() + /// Set the `buf_alloc` field. + pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self { + self.buf_alloc = buf_alloc.into(); + self } - /// Set the `dst_cid` of the header. - pub fn set_dst_cid(&mut self, cid: u64) -> &mut Self { - set_header_field!(self, dst_cid, DST_CID_OFFSET, cid); + /// Set the `fwd_cnt` field. + pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self { + self.fwd_cnt = fwd_cnt.into(); self } - /// Return the `src_port` of the header. - pub fn src_port(&self) -> u32 { - self.header.src_port.into() + /// Get the `src_cid` field. + pub fn src_cid(&self) -> u64 { + self.src_cid.into() } - /// Set the `src_port` of the header. - pub fn set_src_port(&mut self, port: u32) -> &mut Self { - set_header_field!(self, src_port, SRC_PORT_OFFSET, port); - self + /// Get the `dst_cid` field. + pub fn dst_cid(&self) -> u64 { + self.dst_cid.into() } - /// Return the `dst_port` of the header. - pub fn dst_port(&self) -> u32 { - self.header.dst_port.into() + /// Get the `src_port` field. + pub fn src_port(&self) -> u32 { + self.src_port.into() } - /// Set the `dst_port` of the header. - pub fn set_dst_port(&mut self, port: u32) -> &mut Self { - set_header_field!(self, dst_port, DST_PORT_OFFSET, port); - self + /// Get the `dst_port` field. + pub fn dst_port(&self) -> u32 { + self.dst_port.into() } - /// Return the `len` of the header. + /// Get the `len` field. pub fn len(&self) -> u32 { - self.header.len.into() + self.len.into() } - /// Returns whether the `len` field of the header is 0 or not. + /// Returns true if there is no payload pub fn is_empty(&self) -> bool { self.len() == 0 } - /// Set the `len` of the header. - pub fn set_len(&mut self, len: u32) -> &mut Self { - set_header_field!(self, len, LEN_OFFSET, len); - self - } - - /// Return the `type` of the header. + /// Get the `type_` field. pub fn type_(&self) -> u16 { - self.header.type_.into() + self.type_.into() } - /// Set the `type` of the header. - pub fn set_type(&mut self, type_: u16) -> &mut Self { - set_header_field!(self, type_, TYPE_OFFSET, type_); - self - } - - /// Return the `op` of the header. + /// Get the `op` field. pub fn op(&self) -> u16 { - self.header.op.into() - } - - /// Set the `op` of the header. - pub fn set_op(&mut self, op: u16) -> &mut Self { - set_header_field!(self, op, OP_OFFSET, op); - self + self.op.into() } - /// Return the `flags` of the header. + /// Get the `flags` field. pub fn flags(&self) -> u32 { - self.header.flags.into() + self.flags.into() } - /// Set the `flags` of the header. - pub fn set_flags(&mut self, flags: u32) -> &mut Self { - set_header_field!(self, flags, FLAGS_OFFSET, flags); - self + /// Get the `buf_alloc` field. + pub fn buf_alloc(&self) -> u32 { + self.buf_alloc.into() } - /// Set a specific flag of the header. - pub fn set_flag(&mut self, flag: u32) -> &mut Self { - self.set_flags(self.flags() | flag); - self + /// Get the `fwd_cnt` field. + pub fn fwd_cnt(&self) -> u32 { + self.fwd_cnt.into() } +} - /// Return the `buf_alloc` of the header. - pub fn buf_alloc(&self) -> u32 { - self.header.buf_alloc.into() - } +/// The size of the header structure (when packed). +pub const PKT_HEADER_SIZE: usize = std::mem::size_of::(); - /// Set the `buf_alloc` of the header. - pub fn set_buf_alloc(&mut self, buf_alloc: u32) -> &mut Self { - set_header_field!(self, buf_alloc, BUF_ALLOC_OFFSET, buf_alloc); - self +/// Dedicated [`Result`](https://doc.rust-lang.org/std/result/) type. +pub type Result = std::result::Result; + +/// The TX vsock packet, implemented as a wrapper over a virtio descriptor chain using a `Reader`: +/// - a [`PacketHeader`] parsed from the chain; +/// - an optional data `Reader`, only present for data packets (VSOCK_OP_RW). +#[derive(Clone)] +pub struct VsockPacketTx<'a, B: BitmapSlice> { + header: PacketHeader, + data_slice: Option>, +} + +impl<'a, B: BitmapSlice> VsockPacketTx<'a, B> { + /// Return a mutable reference to the `data_slice` of the packet, if present. + pub fn data_slice(&mut self) -> Option<&mut Reader<'a, B>> { + self.data_slice.as_mut() } - /// Return the `fwd_cnt` of the header. - pub fn fwd_cnt(&self) -> u32 { - self.header.fwd_cnt.into() + /// Return a reference to the packet header. + pub fn header(&self) -> &PacketHeader { + &self.header } - /// Set the `fwd_cnt` of the header. - pub fn set_fwd_cnt(&mut self, fwd_cnt: u32) -> &mut Self { - set_header_field!(self, fwd_cnt, FWD_CNT_OFFSET, fwd_cnt); - self + /// Return a mutable reference to the packet header. + pub fn header_mut(&mut self) -> &mut PacketHeader { + &mut self.header } /// Create the packet wrapper from a TX chain. /// - /// The chain head is expected to hold a valid packet header. A following packet data - /// descriptor can optionally end the chain. + /// The chain is expected to hold a valid packet header, optionally followed by packet data. /// /// # Arguments /// /// * `mem` - the `GuestMemory` object that can be used to access the queue buffers. /// * `desc_chain` - the descriptor chain corresponding to a packet. - /// * `max_data_size` - the maximum size allowed for the packet payload, that was negotiated between the device and the driver. Tracking issue for defining this feature in virtio-spec [here](https://github.com/oasis-tcs/virtio-spec/issues/140). - /// - /// # Example - /// - /// ```rust - /// # use virtio_queue::mock::MockSplitQueue; - /// # use virtio_queue::{desc::{split::Descriptor as SplitDescriptor, RawDescriptor}, Queue, QueueT}; - /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; - /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap}; - /// - /// const MAX_PKT_BUF_SIZE: u32 = 64 * 1024; - /// const OP_RW: u16 = 5; - /// - /// # fn create_queue_with_chain(m: &GuestMemoryMmap) -> Queue { - /// # let vq = MockSplitQueue::new(m, 16); - /// # let mut q = vq.create_queue().unwrap(); - /// # - /// # let v = vec![ - /// # RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)), - /// # RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)), - /// # ]; - /// # let mut chain = vq.build_desc_chain(&v); - /// # q - /// # } - /// let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(); - /// // Create a queue and populate it with a descriptor chain. - /// let mut queue = create_queue_with_chain(&mem); - /// - /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) { - /// let pkt = match VsockPacket::from_tx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE) { - /// Ok(pkt) => pkt, - /// Err(_e) => { - /// // Do some error handling. - /// queue.add_used(&mem, head.head_index(), 0); - /// continue; - /// } - /// }; - /// // Here we would send the packet to the backend. Depending on the operation type, a - /// // different type of action will be done. - /// - /// // For example, if it's a RW packet, we will forward the packet payload to the backend. - /// if pkt.op() == OP_RW { - /// // Send the packet payload to the backend. - /// } - /// queue.add_used(&mem, head.head_index(), 0); - /// } - /// ``` + /// * `max_data_size` - the maximum size allowed for the packet payload, that was negotiated + /// between the device and the driver. pub fn from_tx_virtq_chain( mem: &'a M, - desc_chain: &mut DescriptorChain, + desc_chain: DescriptorChain, max_data_size: u32, ) -> Result where M: GuestMemory, ::Bitmap: WithBitmapSlice<'a, S = B>, - T: Deref, + T: Deref, T::Target: GuestMemory, { - let chain_head = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?; - // All TX buffers must be device-readable. - if chain_head.is_write_only() { - return Err(Error::UnexpectedWriteOnlyDescriptor); - } - - // The packet header should fit inside the buffer corresponding to the head descriptor. - if (chain_head.len() as usize) < PKT_HEADER_SIZE { - return Err(Error::DescriptorLengthTooSmall); - } - - let header_slice = - get_single_slice(mem, chain_head.addr(), PKT_HEADER_SIZE, Permissions::Read)? - .expect("Received empty mapping for non-zero PKT_HEADER_SIZE"); - - let header = mem - .read_obj(chain_head.addr()) - .map_err(Error::InvalidMemoryAccess)?; + let mut reader = desc_chain.reader(mem).map_err(|_| Error::InvalidChain)?; + let header = reader + .read_obj::() + .map_err(|_| Error::DescriptorLengthTooSmall)?; let mut pkt = Self { - header_slice, header, data_slice: None, }; // If the `len` field of the header is zero, then the packet doesn't have a `data` element. - if pkt.is_empty() { + if pkt.header.is_empty() { return Ok(pkt); } // Reject packets that exceed the maximum allowed value for payload. - if pkt.len() > max_data_size { - return Err(Error::InvalidHeaderLen(pkt.len())); + if pkt.header.len() > max_data_size { + return Err(Error::InvalidHeaderLen(pkt.header.len())); } - // Starting from Linux 6.2 the virtio-vsock driver can use a single descriptor for both - // header and data. - let data_slice = - if !chain_head.has_next() && chain_head.len() - PKT_HEADER_SIZE as u32 >= pkt.len() { - get_single_slice( - mem, - chain_head - .addr() - .checked_add(PKT_HEADER_SIZE as u64) - .ok_or(Error::DescriptorLengthTooSmall)?, - pkt.len() as usize, - Permissions::Read, - )? - .expect("Received empty mapping for non-empty packet") - } else { - if !chain_head.has_next() { - return Err(Error::DescriptorChainTooShort); - } - - let data_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?; - - if data_desc.is_write_only() { - return Err(Error::UnexpectedWriteOnlyDescriptor); - } - - // The data buffer should be large enough to fit the size of the data, as described by - // the header descriptor. - if data_desc.len() < pkt.len() { - return Err(Error::DescriptorLengthTooSmall); - } - - get_single_slice(mem, data_desc.addr(), pkt.len() as usize, Permissions::Read)? - .expect("Received empty mapping for non-empty packet") - }; - - pkt.data_slice = Some(data_slice); + // Reject packets whose payload is bigger than the available space on the descriptor chain. + if pkt.header.len() as usize > reader.available_bytes() { + return Err(Error::DescriptorLengthTooSmall); + } + + // Limit the amount of data that can be read to the payload and not the full chain. + let _ = reader.split_at(pkt.header.len() as usize); + + pkt.data_slice = Some(reader); Ok(pkt) } +} + +impl fmt::Debug for VsockPacketTx<'_, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("VsockPacketTx") + .field("header", &self.header) + .field("has_data", &self.data_slice.is_some()) + .finish() + } +} + +/// The RX vsock packet, implemented as a wrapper over a virtio descriptor chain using `Writer`s: +/// - a header `Writer` for writing the packet header; +/// - a data `Writer` for writing the packet payload. +pub struct VsockPacketRx<'a, B: BitmapSlice> { + header_slice: Writer<'a, B>, + data_slice: Writer<'a, B>, +} + +impl fmt::Debug for VsockPacketRx<'_, B> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("VsockPacketRx").finish_non_exhaustive() + } +} + +impl<'a, B: BitmapSlice> VsockPacketRx<'a, B> { + /// Return a mutable reference to the data `Writer` of the packet. + pub fn data_slice(&mut self) -> &mut Writer<'a, B> { + &mut self.data_slice + } + + /// Return a mutable reference to the header `Writer` of the packet. + pub fn header_slice(&mut self) -> &mut Writer<'a, B> { + &mut self.header_slice + } /// Create the packet wrapper from an RX chain. /// - /// There must be two descriptors in the chain, both writable: a header descriptor and a data - /// descriptor. + /// The writable portion of the chain must be large enough to hold at least a packet header + /// and some data. /// /// # Arguments /// @@ -556,9 +345,11 @@ impl<'a, B: BitmapSlice> VsockPacket<'a, B> { /// # use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE; /// # use virtio_queue::mock::MockSplitQueue; /// # use virtio_queue::{desc::{split::Descriptor as SplitDescriptor, RawDescriptor}, Queue, QueueT}; - /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; + /// # use virtio_vsock::packet::{VsockPacketRx, PKT_HEADER_SIZE, PacketHeader}; /// # use vm_memory::{Bytes, GuestAddress, GuestAddressSpace, GuestMemoryMmap}; /// + /// # use std::io::Write; + /// /// # const MAX_PKT_BUF_SIZE: u32 = 64 * 1024; /// # const SRC_CID: u64 = 1; /// # const DST_CID: u64 = 2; @@ -587,137 +378,67 @@ impl<'a, B: BitmapSlice> VsockPacket<'a, B> { /// // Create a queue and populate it with a descriptor chain. /// let mut queue = create_queue_with_chain(&mem); /// - /// while let Some(mut head) = queue.pop_descriptor_chain(&mem) { - /// let used_len = match VsockPacket::from_rx_virtq_chain(&mem, &mut head, MAX_PKT_BUF_SIZE) { + /// while let Some(head) = queue.pop_descriptor_chain(&mem) { + /// let head_index = head.head_index(); + /// let used_len = match VsockPacketRx::from_rx_virtq_chain(&mem, head, MAX_PKT_BUF_SIZE) { /// Ok(mut pkt) => { /// // Make sure the header is zeroed out first. - /// pkt.header_slice() - /// .write(&[0u8; PKT_HEADER_SIZE], 0) - /// .unwrap(); - /// // Write data to the packet, using the setters. - /// pkt.set_src_cid(SRC_CID) + /// let mut header = PacketHeader::default(); + /// header.set_src_cid(SRC_CID) /// .set_dst_cid(DST_CID) /// .set_src_port(SRC_PORT) /// .set_dst_port(DST_PORT) /// .set_type(TYPE_STREAM) /// .set_buf_alloc(BUF_ALLOC) - /// .set_fwd_cnt(FWD_CNT); + /// .set_fwd_cnt(FWD_CNT) + /// .set_op(OP_RW) + /// .set_len(LEN); + /// pkt.header_slice().write_obj(header).unwrap(); /// // In this example, we are sending a RW packet. /// pkt.data_slice() - /// .unwrap() - /// .write_slice(&[1u8; LEN as usize], 0); - /// pkt.set_op(OP_RW).set_len(LEN); - /// pkt.header_slice().len() as u32 + LEN + /// .write(&[1u8; LEN as usize]).unwrap(); + /// size_of::() as u32 + LEN /// } /// Err(_e) => { /// // Do some error handling. /// 0 /// } /// }; - /// queue.add_used(&mem, head.head_index(), used_len); + /// queue.add_used(&mem, head_index, used_len); /// } /// ``` pub fn from_rx_virtq_chain( mem: &'a M, - desc_chain: &mut DescriptorChain, + desc_chain: DescriptorChain, max_data_size: u32, ) -> Result where M: GuestMemory, ::Bitmap: WithBitmapSlice<'a, S = B>, - T: Deref, + T: Deref, T::Target: GuestMemory, { - let chain_head = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?; - // All RX buffers must be device-writable. - if !chain_head.is_write_only() { - return Err(Error::UnexpectedReadOnlyDescriptor); - } + let mut header_writer = desc_chain.writer(mem).map_err(|_| Error::InvalidChain)?; - // The packet header should fit inside the head descriptor. - if (chain_head.len() as usize) < PKT_HEADER_SIZE { - return Err(Error::DescriptorLengthTooSmall); + if header_writer.available_bytes() == 0 { + return Err(Error::InvalidChain); } - let header_slice = - get_single_slice(mem, chain_head.addr(), PKT_HEADER_SIZE, Permissions::Write)? - .expect("Received empty mapping for non-zero PKT_HEADER_SIZE"); - - // Starting from Linux 6.2 the virtio-vsock driver can use a single descriptor for both - // header and data. - let data_slice = if !chain_head.has_next() && chain_head.len() as usize > PKT_HEADER_SIZE { - get_single_slice( - mem, - chain_head - .addr() - .checked_add(PKT_HEADER_SIZE as u64) - .ok_or(Error::DescriptorLengthTooSmall)?, - chain_head.len() as usize - PKT_HEADER_SIZE, - Permissions::Write, - )? - } else { - if !chain_head.has_next() { - return Err(Error::DescriptorChainTooShort); - } - - let data_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?; + let data_writer = header_writer + .split_at(size_of::()) + .map_err(|_| Error::DescriptorLengthTooSmall)?; - if !data_desc.is_write_only() { - return Err(Error::UnexpectedReadOnlyDescriptor); - } - - if data_desc.len() > max_data_size { - return Err(Error::DescriptorLengthTooLong); - } + if data_writer.available_bytes() as u32 > max_data_size { + return Err(Error::DescriptorLengthTooLong); + } - get_single_slice( - mem, - data_desc.addr(), - data_desc.len() as usize, - Permissions::Write, - )? - }; + if data_writer.available_bytes() == 0 { + return Err(Error::DescriptorLengthTooSmall); + } Ok(Self { - header_slice, - header: Default::default(), - // `None` if and only if the length is 0 - data_slice, - }) - } -} - -impl<'a> VsockPacket<'a, ()> { - /// Create a packet based on one pointer for the header, and an optional one for data. - /// - /// # Safety - /// - /// To use this safely, the caller must guarantee that the memory pointed to by the `hdr` and - /// `data` slices is available for the duration of the lifetime of the new `VolatileSlice`. The - /// caller must also guarantee that all other users of the given chunk of memory are using - /// volatile accesses. - /// - /// # Example - /// - /// ```rust - /// use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; - /// - /// const LEN: usize = 16; - /// - /// let mut pkt_raw = [0u8; PKT_HEADER_SIZE + LEN]; - /// let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE); - /// // Safe because `hdr_raw` and `data_raw` live for as long as the scope of the current - /// // example. - /// let packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() }; - /// ``` - pub unsafe fn new(header: &mut [u8], data: Option<&mut [u8]>) -> Result> { - if header.len() != PKT_HEADER_SIZE { - return Err(Error::InvalidHeaderInputSize(header.len())); - } - Ok(VsockPacket { - header_slice: VolatileSlice::new(header.as_mut_ptr(), PKT_HEADER_SIZE), - header: Default::default(), - data_slice: data.map(|data| VolatileSlice::new(data.as_mut_ptr(), data.len())), + header_slice: header_writer, + data_slice: data_writer, }) } } @@ -726,7 +447,7 @@ impl<'a> VsockPacket<'a, ()> { mod tests { use super::*; - use vm_memory::{GuestAddress, GuestMemoryMmap}; + use vm_memory::{Bytes, GuestAddress, GuestMemoryMmap}; use virtio_bindings::bindings::virtio_ring::VRING_DESC_F_WRITE; use virtio_queue::desc::{split::Descriptor as SplitDescriptor, RawDescriptor}; @@ -736,22 +457,10 @@ mod tests { fn eq(&self, other: &Self) -> bool { use self::Error::*; match (self, other) { - (DescriptorChainTooShort, DescriptorChainTooShort) => true, (DescriptorLengthTooSmall, DescriptorLengthTooSmall) => true, (DescriptorLengthTooLong, DescriptorLengthTooLong) => true, - (FragmentedMemory, FragmentedMemory) => true, - (InvalidHeaderInputSize(size), InvalidHeaderInputSize(other_size)) => { - size == other_size - } + (InvalidChain, InvalidChain) => true, (InvalidHeaderLen(size), InvalidHeaderLen(other_size)) => size == other_size, - (InvalidMemoryAccess(ref e), InvalidMemoryAccess(ref other_e)) => { - format!("{e}").eq(&format!("{other_e}")) - } - (InvalidVolatileAccess(ref e), InvalidVolatileAccess(ref other_e)) => { - format!("{e}").eq(&format!("{other_e}")) - } - (UnexpectedReadOnlyDescriptor, UnexpectedReadOnlyDescriptor) => true, - (UnexpectedWriteOnlyDescriptor, UnexpectedWriteOnlyDescriptor) => true, _ => false, } } @@ -772,35 +481,6 @@ mod tests { const MAX_PKT_BUF_SIZE: u32 = 64 * 1024; - /// For `get_mem_ptr()`: Whether we access the RX or TX ring. - #[derive(Copy, Clone, Debug, Eq, PartialEq)] - enum RxTx { - /// Receive ring - Rx, - /// Transmission ring - Tx, - } - - /// Return a host pointer to the slice at `[addr, addr + length)`. Use this only for - /// comparison in `assert_eq!()`. - fn get_mem_ptr( - mem: &M, - addr: GuestAddress, - length: usize, - rx_tx: RxTx, - ) -> Result<*const u8> { - let access = match rx_tx { - RxTx::Rx => Permissions::Write, - RxTx::Tx => Permissions::Read, - }; - - assert!(length > 0); - Ok(get_single_slice(mem, addr, length, access)? - .unwrap() - .ptr_guard() - .as_ptr()) - } - #[test] fn test_from_rx_virtq_chain() { let mem: GuestMemoryMmap = @@ -810,18 +490,14 @@ mod tests { let v = vec![ // A device-readable packet header descriptor should be invalid. RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)), - RawDescriptor::from(SplitDescriptor::new( - 0x20_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), + RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)), ]; let queue = MockSplitQueue::new(&mem, 16); - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); + assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::UnexpectedReadOnlyDescriptor + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); let v = vec![ @@ -832,16 +508,10 @@ mod tests { VRING_DESC_F_WRITE as u16, 0, )), - RawDescriptor::from(SplitDescriptor::new( - 0x20_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), Error::DescriptorLengthTooSmall ); @@ -859,9 +529,9 @@ mod tests { 0, )), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), Error::DescriptorLengthTooLong ); @@ -874,25 +544,19 @@ mod tests { 0, )), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::DescriptorChainTooShort + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::DescriptorLengthTooSmall ); - let v = vec![ - RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)), - RawDescriptor::from(SplitDescriptor::new( - 0x20_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), - ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let v = vec![RawDescriptor::from(SplitDescriptor::new( + 0x20_0000, 0x100, 0, 0, + ))]; + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::UnexpectedReadOnlyDescriptor + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); let mem: GuestMemoryMmap = @@ -914,10 +578,10 @@ mod tests { )), ]; let queue = MockSplitQueue::new(&mem, 16); - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::FragmentedMemory + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); let v = vec![ @@ -935,29 +599,12 @@ mod tests { 0, )), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress( - 0x20_0000 - ))) + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); - let v = vec![ - RawDescriptor::from(SplitDescriptor::new( - 0x5_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), - // A device-readable packet data descriptor should be invalid. - RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)), - ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); - assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::UnexpectedReadOnlyDescriptor - ); let v = vec![ RawDescriptor::from(SplitDescriptor::new( 0x5_0000, @@ -973,10 +620,11 @@ mod tests { 0, )), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::FragmentedMemory + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); let v = vec![ @@ -994,12 +642,10 @@ mod tests { 0, )), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress( - 0x20_0000 - ))) + VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); // Let's also test a valid descriptor chain. @@ -1017,29 +663,14 @@ mod tests { 0, )), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); - - let packet = VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); - assert_eq!(packet.header, PacketHeader::default()); - let header = packet.header_slice(); - assert_eq!( - header.ptr_guard().as_ptr(), - get_mem_ptr(&mem, GuestAddress(0x5_0000), header.len(), RxTx::Rx).unwrap() - ); - assert_eq!(header.len(), PKT_HEADER_SIZE); + let chain = queue.build_desc_chain(&v).unwrap(); - let data = packet.data_slice().unwrap(); - assert_eq!( - data.ptr_guard().as_ptr(), - get_mem_ptr(&mem, GuestAddress(0x8_0000), data.len(), RxTx::Rx).unwrap() - ); - assert_eq!(data.len(), 0x100); + let mut packet = VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap(); - // If we try to get a vsock packet again, it fails because we already consumed all the - // descriptors from the chain. + assert_eq!(packet.header_slice().available_bytes(), PKT_HEADER_SIZE); assert_eq!( - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::DescriptorChainTooShort + packet.data_slice().available_bytes(), + 0x200 - PKT_HEADER_SIZE ); // Let's also test a valid descriptor chain, with both header and data on a single @@ -1050,29 +681,12 @@ mod tests { VRING_DESC_F_WRITE as u16, 0, ))]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); - let packet = VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); - assert_eq!(packet.header, PacketHeader::default()); + let mut packet = VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap(); let header = packet.header_slice(); - assert_eq!( - header.ptr_guard().as_ptr(), - get_mem_ptr(&mem, GuestAddress(0x5_0000), header.len(), RxTx::Rx).unwrap() - ); - assert_eq!(header.len(), PKT_HEADER_SIZE); - - let data = packet.data_slice().unwrap(); - assert_eq!( - data.ptr_guard().as_ptr(), - get_mem_ptr( - &mem, - GuestAddress(0x5_0000 + PKT_HEADER_SIZE as u64), - data.len(), - RxTx::Rx - ) - .unwrap() - ); - assert_eq!(data.len(), 0x100); + assert_eq!(header.available_bytes(), PKT_HEADER_SIZE); + assert_eq!(packet.data_slice().available_bytes(), 0x100); } #[test] @@ -1089,13 +703,12 @@ mod tests { VRING_DESC_F_WRITE as u16, 0, )), - RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)), ]; let queue = MockSplitQueue::new(&mem, 16); - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::UnexpectedWriteOnlyDescriptor + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::DescriptorLengthTooSmall ); let v = vec![ @@ -1106,11 +719,10 @@ mod tests { 0, 0, )), - RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), Error::DescriptorLengthTooSmall ); @@ -1121,7 +733,7 @@ mod tests { 0, 0, ))]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); let header = PacketHeader { src_cid: SRC_CID.into(), @@ -1137,14 +749,8 @@ mod tests { }; mem.write_obj(header, GuestAddress(0x10_0000)).unwrap(); - let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); + let mut packet = VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap(); assert_eq!(packet.header, header); - let header_slice = packet.header_slice(); - assert_eq!( - header_slice.ptr_guard().as_ptr(), - get_mem_ptr(&mem, GuestAddress(0x10_0000), header_slice.len(), RxTx::Tx).unwrap() - ); - assert_eq!(header_slice.len(), PKT_HEADER_SIZE); assert!(packet.data_slice().is_none()); let mem: GuestMemoryMmap = @@ -1156,10 +762,10 @@ mod tests { RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)), ]; let queue = MockSplitQueue::new(&mem, 16); - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::FragmentedMemory + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); let v = vec![ @@ -1167,12 +773,10 @@ mod tests { RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)), RawDescriptor::from(SplitDescriptor::new(0x30_0000, 0x100, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress( - 0x20_0000 - ))) + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); // Write some non-zero value to the `len` field of the header, which means there is also @@ -1194,9 +798,9 @@ mod tests { RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)), RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), Error::InvalidHeaderLen(MAX_PKT_BUF_SIZE + 1) ); @@ -1215,13 +819,14 @@ mod tests { }; mem.write_obj(header, GuestAddress(0x5_0000)).unwrap(); let v = vec![ - // The data descriptor is missing. + // No room for data. RawDescriptor::from(SplitDescriptor::new(0x5_0000, PKT_HEADER_SIZE as u32, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); + assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::DescriptorChainTooShort + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::DescriptorLengthTooSmall ); let v = vec![ @@ -1229,10 +834,11 @@ mod tests { // The data array doesn't fit entirely in the memory bounds. RawDescriptor::from(SplitDescriptor::new(0x10_0000, 0x100, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); + assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::FragmentedMemory + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); let v = vec![ @@ -1240,38 +846,21 @@ mod tests { // The data array is outside the memory bounds. RawDescriptor::from(SplitDescriptor::new(0x20_0000, 0x100, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); - assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::InvalidMemoryAccess(GuestMemoryError::InvalidGuestAddress(GuestAddress( - 0x20_0000 - ))) - ); + let chain = queue.build_desc_chain(&v).unwrap(); - let v = vec![ - RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)), - // A device-writable packet data descriptor should be invalid. - RawDescriptor::from(SplitDescriptor::new( - 0x8_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), - ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::UnexpectedWriteOnlyDescriptor + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), + Error::InvalidChain ); let v = vec![ - RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)), + RawDescriptor::from(SplitDescriptor::new(0x5_0000, PKT_HEADER_SIZE as u32, 0, 0)), // A data length < the length of data as described by the header. RawDescriptor::from(SplitDescriptor::new(0x8_0000, LEN - 1, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), + VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap_err(), Error::DescriptorLengthTooSmall ); @@ -1280,32 +869,16 @@ mod tests { RawDescriptor::from(SplitDescriptor::new(0x5_0000, 0x100, 0, 0)), RawDescriptor::from(SplitDescriptor::new(0x8_0000, 0x100, 0, 0)), ]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); - let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); + let mut packet = VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap(); assert_eq!(packet.header, header); - let header_slice = packet.header_slice(); - assert_eq!( - header_slice.ptr_guard().as_ptr(), - get_mem_ptr(&mem, GuestAddress(0x5_0000), header_slice.len(), RxTx::Tx).unwrap() - ); - assert_eq!(header_slice.len(), PKT_HEADER_SIZE); + // The `len` field of the header was set to 16. - assert_eq!(packet.len(), LEN); + assert_eq!(packet.header().len(), LEN); let data = packet.data_slice().unwrap(); - assert_eq!( - data.ptr_guard().as_ptr(), - get_mem_ptr(&mem, GuestAddress(0x8_0000), data.len(), RxTx::Tx).unwrap() - ); - assert_eq!(data.len(), LEN as usize); - - // If we try to get a vsock packet again, it fails because we already consumed all the - // descriptors from the chain. - assert_eq!( - VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap_err(), - Error::DescriptorChainTooShort - ); + assert_eq!(data.available_bytes(), LEN as usize); // Let's also test a valid descriptor chain, with both header and data on a single // descriptor. @@ -1315,31 +888,15 @@ mod tests { 0, 0, ))]; - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); - let packet = VsockPacket::from_tx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); + let mut packet = VsockPacketTx::from_tx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap(); assert_eq!(packet.header, header); - let header_slice = packet.header_slice(); - assert_eq!( - header_slice.ptr_guard().as_ptr(), - get_mem_ptr(&mem, GuestAddress(0x5_0000), header_slice.len(), RxTx::Tx).unwrap() - ); - assert_eq!(header_slice.len(), PKT_HEADER_SIZE); // The `len` field of the header was set to 16. - assert_eq!(packet.len(), LEN); + assert_eq!(packet.header().len(), LEN); let data = packet.data_slice().unwrap(); - assert_eq!( - data.ptr_guard().as_ptr(), - get_mem_ptr( - &mem, - GuestAddress(0x5_0000 + PKT_HEADER_SIZE as u64), - data.len(), - RxTx::Tx - ) - .unwrap() - ); - assert_eq!(data.len(), LEN as usize); + assert_eq!(data.available_bytes(), LEN as usize); } #[test] @@ -1362,11 +919,12 @@ mod tests { )), ]; let queue = MockSplitQueue::new(&mem, 16); - let mut chain = queue.build_desc_chain(&v).unwrap(); + let chain = queue.build_desc_chain(&v).unwrap(); - let mut packet = - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); - packet + let mut packet = VsockPacketRx::from_rx_virtq_chain(&mem, chain, MAX_PKT_BUF_SIZE).unwrap(); + + let mut header = PacketHeader::default(); + header .set_src_cid(SRC_CID) .set_dst_cid(DST_CID) .set_src_port(SRC_PORT) @@ -1379,226 +937,103 @@ mod tests { .set_buf_alloc(BUF_ALLOC) .set_fwd_cnt(FWD_CNT); - assert_eq!(packet.flags(), FLAGS | FLAG); - assert_eq!(packet.op(), OP); - assert_eq!(packet.type_(), TYPE); - assert_eq!(packet.dst_cid(), DST_CID); - assert_eq!(packet.dst_port(), DST_PORT); - assert_eq!(packet.src_cid(), SRC_CID); - assert_eq!(packet.src_port(), SRC_PORT); - assert_eq!(packet.fwd_cnt(), FWD_CNT); - assert_eq!(packet.len(), LEN); - assert_eq!(packet.buf_alloc(), BUF_ALLOC); - - let expected_header = PacketHeader { - src_cid: SRC_CID.into(), - dst_cid: DST_CID.into(), - src_port: SRC_PORT.into(), - dst_port: DST_PORT.into(), - len: LEN.into(), - type_: TYPE.into(), - op: OP.into(), - flags: (FLAGS | FLAG).into(), - buf_alloc: BUF_ALLOC.into(), - fwd_cnt: FWD_CNT.into(), - }; - - assert_eq!(packet.header, expected_header); + // Verify PacketHeader getters. + assert_eq!(header.src_cid(), SRC_CID); + assert_eq!(header.dst_cid(), DST_CID); + assert_eq!(header.src_port(), SRC_PORT); + assert_eq!(header.dst_port(), DST_PORT); + assert_eq!(header.len(), LEN); + assert_eq!(header.type_(), TYPE); + assert_eq!(header.op(), OP); + assert_eq!(header.flags(), FLAGS | FLAG); + assert_eq!(header.buf_alloc(), BUF_ALLOC); + assert_eq!(header.fwd_cnt(), FWD_CNT); + + // Write header through the Writer, then read back from guest memory. + packet.header_slice().write_obj(header).unwrap(); + + let read_back: PacketHeader = mem.read_obj(GuestAddress(0x10_0000)).unwrap(); + assert_eq!(read_back, header); + + // Offsets of the header fields. + const SRC_CID_OFFSET: u64 = 0; + const DST_CID_OFFSET: u64 = 8; + const SRC_PORT_OFFSET: u64 = 16; + const DST_PORT_OFFSET: u64 = 20; + const LEN_OFFSET: u64 = 24; + const TYPE_OFFSET: u64 = 28; + const OP_OFFSET: u64 = 30; + const FLAGS_OFFSET: u64 = 32; + const BUF_ALLOC_OFFSET: u64 = 36; + const FWD_CNT_OFFSET: u64 = 40; + + let base = 0x10_0000; assert_eq!( u64::from_le( - packet - .header_slice() - .read_obj::(SRC_CID_OFFSET) + mem.read_obj::(GuestAddress(base + SRC_CID_OFFSET)) .unwrap() ), SRC_CID ); assert_eq!( u64::from_le( - packet - .header_slice() - .read_obj::(DST_CID_OFFSET) + mem.read_obj::(GuestAddress(base + DST_CID_OFFSET)) .unwrap() ), DST_CID ); assert_eq!( u32::from_le( - packet - .header_slice() - .read_obj::(SRC_PORT_OFFSET) + mem.read_obj::(GuestAddress(base + SRC_PORT_OFFSET)) .unwrap() ), SRC_PORT ); assert_eq!( u32::from_le( - packet - .header_slice() - .read_obj::(DST_PORT_OFFSET) + mem.read_obj::(GuestAddress(base + DST_PORT_OFFSET)) .unwrap() ), - DST_PORT, + DST_PORT ); assert_eq!( - u32::from_le(packet.header_slice().read_obj::(LEN_OFFSET).unwrap()), + u32::from_le( + mem.read_obj::(GuestAddress(base + LEN_OFFSET)) + .unwrap() + ), LEN ); assert_eq!( - u16::from_le(packet.header_slice().read_obj::(TYPE_OFFSET).unwrap()), + u16::from_le( + mem.read_obj::(GuestAddress(base + TYPE_OFFSET)) + .unwrap() + ), TYPE ); assert_eq!( - u16::from_le(packet.header_slice().read_obj::(OP_OFFSET).unwrap()), + u16::from_le(mem.read_obj::(GuestAddress(base + OP_OFFSET)).unwrap()), OP ); assert_eq!( - u32::from_le(packet.header_slice().read_obj::(FLAGS_OFFSET).unwrap()), + u32::from_le( + mem.read_obj::(GuestAddress(base + FLAGS_OFFSET)) + .unwrap() + ), FLAGS | FLAG ); assert_eq!( u32::from_le( - packet - .header_slice() - .read_obj::(BUF_ALLOC_OFFSET) + mem.read_obj::(GuestAddress(base + BUF_ALLOC_OFFSET)) .unwrap() ), BUF_ALLOC ); assert_eq!( u32::from_le( - packet - .header_slice() - .read_obj::(FWD_CNT_OFFSET) + mem.read_obj::(GuestAddress(base + FWD_CNT_OFFSET)) .unwrap() ), FWD_CNT ); } - - #[test] - fn test_set_header_from_raw() { - let mem: GuestMemoryMmap = - GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap(); - // The `build_desc_chain` function will populate the `NEXT` related flags and field. - let v = vec![ - RawDescriptor::from(SplitDescriptor::new( - 0x10_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), - RawDescriptor::from(SplitDescriptor::new( - 0x20_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), - ]; - let queue = MockSplitQueue::new(&mem, 16); - let mut chain = queue.build_desc_chain(&v).unwrap(); - - let mut packet = - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); - - let header = PacketHeader { - src_cid: SRC_CID.into(), - dst_cid: DST_CID.into(), - src_port: SRC_PORT.into(), - dst_port: DST_PORT.into(), - len: LEN.into(), - type_: TYPE.into(), - op: OP.into(), - flags: (FLAGS | FLAG).into(), - buf_alloc: BUF_ALLOC.into(), - fwd_cnt: FWD_CNT.into(), - }; - - // SAFETY: created from an existing packet header. - let slice = unsafe { - std::slice::from_raw_parts( - (&header as *const PacketHeader) as *const u8, - std::mem::size_of::(), - ) - }; - assert_eq!(packet.header, PacketHeader::default()); - packet.set_header_from_raw(slice).unwrap(); - assert_eq!(packet.header, header); - let header_from_slice: PacketHeader = packet.header_slice().read_obj(0).unwrap(); - assert_eq!(header_from_slice, header); - - let invalid_slice = [0; PKT_HEADER_SIZE - 1]; - assert_eq!( - packet.set_header_from_raw(&invalid_slice).unwrap_err(), - Error::InvalidHeaderInputSize(PKT_HEADER_SIZE - 1) - ); - } - - #[test] - fn test_packet_new() { - let mut pkt_raw = [0u8; PKT_HEADER_SIZE + LEN as usize]; - let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE); - // SAFETY: safe because ``hdr_raw` and `data_raw` live for as long as - // the scope of the current test. - let packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() }; - assert_eq!( - packet.header_slice.ptr_guard().as_ptr(), - hdr_raw.as_mut_ptr(), - ); - assert_eq!(packet.header_slice.len(), PKT_HEADER_SIZE); - assert_eq!(packet.header, PacketHeader::default()); - assert_eq!( - packet.data_slice.unwrap().ptr_guard().as_ptr(), - data_raw.as_mut_ptr(), - ); - assert_eq!(packet.data_slice.unwrap().len(), LEN as usize); - - // SAFETY: Safe because ``hdr_raw` and `data_raw` live as long as the - // scope of the current test. - let packet = unsafe { VsockPacket::new(hdr_raw, None).unwrap() }; - assert_eq!( - packet.header_slice.ptr_guard().as_ptr(), - hdr_raw.as_mut_ptr(), - ); - assert_eq!(packet.header, PacketHeader::default()); - assert!(packet.data_slice.is_none()); - - let mut hdr_raw = [0u8; PKT_HEADER_SIZE - 1]; - assert_eq!( - // SAFETY: Safe because ``hdr_raw` lives for as long as the scope of the current test. - unsafe { VsockPacket::new(&mut hdr_raw, None).unwrap_err() }, - Error::InvalidHeaderInputSize(PKT_HEADER_SIZE - 1) - ); - } - - #[test] - #[should_panic] - fn test_set_header_field_with_invalid_offset() { - const INVALID_OFFSET: usize = 50; - - let mem: GuestMemoryMmap = - GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x30_0000)]).unwrap(); - // The `build_desc_chain` function will populate the `NEXT` related flags and field. - let v = vec![ - RawDescriptor::from(SplitDescriptor::new( - 0x10_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), - RawDescriptor::from(SplitDescriptor::new( - 0x20_0000, - 0x100, - VRING_DESC_F_WRITE as u16, - 0, - )), - ]; - let queue = MockSplitQueue::new(&mem, 16); - let mut chain = queue.build_desc_chain(&v).unwrap(); - - let mut packet = - VsockPacket::from_rx_virtq_chain(&mem, &mut chain, MAX_PKT_BUF_SIZE).unwrap(); - // Set the `src_cid` of the header, but use an invalid offset for that. - set_header_field!(packet, src_cid, INVALID_OFFSET, SRC_CID); - } }