diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index 54b2f20c47..3fa7c1a807 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -24,11 +24,21 @@ pub(crate) enum VsockState { ReceiveRequest, Connected, Connecting, + /// The peer sent a graceful `Op::Shutdown`. Buffered data may still be + /// read; once drained, reads report EOF. Shutdown, + /// The peer (or the device) sent an abortive `Op::Rst`. Buffered data may + /// still be read; once drained, reads report `ECONNRESET`. + Reset, } pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 256 * 1024; +/// Identifies an established connection by its local (listen) port and the +/// remote endpoint `(remote_cid, remote_port)`. Multiple connections may share +/// one local port, mirroring how TCP demultiplexes by the connection 4-tuple. +pub(crate) type ConnKey = (u32, u32, u32); + #[derive(Debug)] pub(crate) struct RawSocket { pub remote_cid: u32, @@ -77,8 +87,20 @@ async fn vsock_run() { let type_ = Type::try_from(header.type_.to_ne()).unwrap(); let mut vsock_guard = VSOCK_MAP.lock(); let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap(); + let remote_port = header.src_port.to_ne(); - let Some(raw) = vsock_guard.get_mut_socket(port) else { + // Packets for an established connection address the local listen + // port but belong to a specific remote endpoint, so route them to + // the connection entry keyed by `(port, remote_cid, remote_port)`. + // `Op::Request` (and outbound-connect responses) have no such entry + // yet and fall back to the listener/connect socket in `port_map`. + let raw = if let Some(conn) = + vsock_guard.get_mut_connection((port, header_cid, remote_port)) + { + conn + } else if let Some(s) = vsock_guard.get_mut_socket(port) { + s + } else { return; }; @@ -113,6 +135,16 @@ async fn vsock_run() { } else if op == Op::Shutdown { if raw.remote_cid == header_cid { raw.state = VsockState::Shutdown; + raw.rx_waker.wake(); + raw.tx_waker.wake(); + } else { + trace!("Receive message from invalid source {header_cid}"); + } + } else if op == Op::Rst { + if raw.remote_cid == header_cid { + raw.state = VsockState::Reset; + raw.rx_waker.wake(); + raw.tx_waker.wake(); } else { trace!("Receive message from invalid source {header_cid}"); } @@ -123,6 +155,13 @@ async fn vsock_run() { raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); raw.tx_waker.wake(); } + } else if op == Op::Request { + // A connection request the listener cannot service right now + // (e.g. a previous request is still pending accept on this + // port). Reply with a reset so the peer fails fast instead of + // blocking until it times out. + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; } else if raw.remote_cid == header_cid { hdr = Some(*header); fwd_cnt = raw.fwd_cnt; @@ -161,13 +200,19 @@ async fn vsock_run() { } pub(crate) struct VsockMap { + /// Listeners (keyed by listen port) and outbound-connect sockets (keyed by + /// a synthetic ephemeral port). port_map: BTreeMap, + /// Established inbound connections, keyed by `(local_port, remote_cid, + /// remote_port)`, so several connections can share one listen port. + conn_map: BTreeMap, } impl VsockMap { pub const fn new() -> Self { Self { port_map: BTreeMap::new(), + conn_map: BTreeMap::new(), } } @@ -198,17 +243,47 @@ impl VsockMap { Err(Errno::Badf) } - pub fn get_socket(&self, port: u32) -> Option<&RawSocket> { - self.port_map.get(&port) - } - pub fn get_mut_socket(&mut self, port: u32) -> Option<&mut RawSocket> { self.port_map.get_mut(&port) } + pub fn get_mut_connection(&mut self, key: ConnKey) -> Option<&mut RawSocket> { + self.conn_map.get_mut(&key) + } + pub fn remove_socket(&mut self, port: u32) { self.port_map.remove(&port); } + + pub fn remove_connection(&mut self, key: ConnKey) { + self.conn_map.remove(&key); + } + + /// Move the pending connection on `listen_port` (in state `ReceiveRequest`) + /// into `conn_map` keyed by `(listen_port, remote_cid, remote_port)`, then + /// reset the listener entry to `Listen` so it can accept further + /// connections. Returns the new connection's key. + pub fn establish(&mut self, listen_port: u32) -> io::Result { + let listener = self.port_map.get_mut(&listen_port).ok_or(Errno::Inval)?; + let key = (listen_port, listener.remote_cid, listener.remote_port); + + // Build the connection entry from the negotiated handshake state, then + // reset the listener's fields in place. Resetting in place (rather than + // replacing the whole struct) preserves the listener's wakers, so an + // `accept()` future already parked on it is not lost. + let mut conn = RawSocket::new(VsockState::Connected); + conn.remote_cid = listener.remote_cid; + conn.remote_port = listener.remote_port; + conn.peer_buf_alloc = listener.peer_buf_alloc; + + listener.state = VsockState::Listen; + listener.remote_cid = 0; + listener.remote_port = 0; + listener.peer_buf_alloc = 0; + + self.conn_map.insert(key, conn); + Ok(key) + } } pub(crate) fn init() { diff --git a/src/fd/delegate.rs b/src/fd/delegate.rs index abf25fa4e1..27c1961b35 100644 --- a/src/fd/delegate.rs +++ b/src/fd/delegate.rs @@ -44,8 +44,6 @@ pub(crate) enum Fd { #[cfg(feature = "udp")] UdpSocket(udp::Socket), #[cfg(feature = "virtio-vsock")] - VsockNullSocket(vsock::NullSocket), - #[cfg(feature = "virtio-vsock")] VsockSocket(vsock::Socket), #[cfg(feature = "virtio-fs")] VirtioFsFileHandle(VirtioFsFileHandle), @@ -94,8 +92,6 @@ fd_from! { #[cfg(feature = "udp")] UdpSocket(udp::Socket), #[cfg(feature = "virtio-vsock")] - VsockNullSocket(vsock::NullSocket), - #[cfg(feature = "virtio-vsock")] VsockSocket(vsock::Socket), #[cfg(feature = "virtio-fs")] VirtioFsFileHandle(VirtioFsFileHandle), @@ -128,8 +124,6 @@ impl ObjectInterface for Fd { #[cfg(feature = "udp")] Self::UdpSocket(fd) => fd, #[cfg(feature = "virtio-vsock")] - Self::VsockNullSocket(fd) => fd, - #[cfg(feature = "virtio-vsock")] Self::VsockSocket(fd) => fd, #[cfg(feature = "virtio-fs")] Self::VirtioFsFileHandle(fd) => fd, diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index 514d97deba..ec5814cc97 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -11,10 +11,17 @@ use crate::arch::kernel::mmio as hardware; #[cfg(feature = "pci")] use crate::drivers::pci as hardware; use crate::errno::Errno; -use crate::executor::vsock::{VSOCK_MAP, VsockState}; +use crate::executor::vsock::{ConnKey, RawSocket, VSOCK_MAP, VsockState}; use crate::fd::{self, Endpoint, Fd, ListenEndpoint, ObjectInterface, PollEvent}; use crate::io; +/// Further receives will be disallowed +pub const SHUT_RD: i32 = 0; +/// Further sends will be disallowed +pub const SHUT_WR: i32 = 1; +/// Further sends and receives will be disallowed +pub const SHUT_RDWR: i32 = 2; + #[derive(Debug)] pub struct VsockListenEndpoint { pub port: u32, @@ -39,18 +46,14 @@ impl VsockEndpoint { } } -pub struct NullSocket; - -impl NullSocket { - pub const fn new() -> Self { - Self {} - } -} - -impl ObjectInterface for NullSocket {} - pub struct Socket { + /// The local port this socket is bound/listening on, or the synthetic + /// ephemeral port of an outbound connection. port: u32, + /// Set for sockets returned by `accept`: identifies the established + /// connection in the executor's connection map. `None` for listeners and + /// outbound-connect sockets, which are keyed by `port` instead. + conn: Option, cid: u32, is_nonblocking: bool, } @@ -59,17 +62,31 @@ impl Socket { pub fn new() -> Self { Self { port: 0, + conn: None, cid: u32::MAX, is_nonblocking: false, } } + + /// Borrow this socket's `RawSocket` from the executor map, whether it is a + /// listener/connect socket (keyed by `port`) or an accepted connection + /// (keyed by `conn`). + fn raw_mut<'a>( + &self, + guard: &'a mut crate::executor::vsock::VsockMap, + ) -> Option<&'a mut RawSocket> { + match self.conn { + Some(key) => guard.get_mut_connection(key), + None => guard.get_mut_socket(self.port), + } + } } impl ObjectInterface for Socket { async fn poll(&self, event: PollEvent) -> io::Result { future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); - let raw = guard.get_mut_socket(self.port).ok_or(Errno::Inval)?; + let raw = self.raw_mut(&mut guard).ok_or(Errno::Inval)?; match raw.state { VsockState::Shutdown | VsockState::ReceiveRequest => { @@ -88,6 +105,20 @@ impl ObjectInterface for Socket { Poll::Ready(Ok(ret)) } } + VsockState::Reset => { + // Reset is readable/writable so the next read/write + // observes ECONNRESET, and signals error + hangup. + let available = PollEvent::POLLIN + | PollEvent::POLLRDNORM + | PollEvent::POLLRDBAND + | PollEvent::POLLOUT + | PollEvent::POLLWRNORM + | PollEvent::POLLWRBAND; + + Poll::Ready(Ok((event & available) + | PollEvent::POLLERR + | PollEvent::POLLHUP)) + } VsockState::Listen | VsockState::Connecting => { raw.rx_waker.register(cx.waker()); raw.tx_waker.register(cx.waker()); @@ -138,12 +169,19 @@ impl ObjectInterface for Socket { async fn bind(&mut self, endpoint: ListenEndpoint) -> io::Result<()> { match endpoint { ListenEndpoint::Vsock(ep) => { - self.port = ep.port; - if let Some(cid) = ep.cid { - self.cid = cid; - } else { - self.cid = u32::MAX; + // A socket may only listen on `VMADDR_CID_ANY` or this guest's + // own CID. Binding to any other CID is rejected, mirroring Linux + // `AF_VSOCK` (which returns `EADDRNOTAVAIL`). + let cid = ep.cid.unwrap_or(u32::MAX); + if cid != u32::MAX { + let local_cid = hardware::get_vsock_driver().unwrap().lock().get_cid(); + if u64::from(cid) != local_cid { + return Err(Errno::Addrnotavail); + } } + + self.port = ep.port; + self.cid = cid; VSOCK_MAP.lock().bind(ep.port) } #[cfg(feature = "net")] @@ -200,6 +238,9 @@ impl ObjectInterface for Socket { raw.rx_waker.register(cx.waker()); Poll::Pending } + // A reset in response to our request means the peer + // refused the connection. + VsockState::Reset => Poll::Ready(Err(Errno::Connrefused)), _ => Poll::Ready(Err(Errno::Badf)), } }) @@ -211,8 +252,8 @@ impl ObjectInterface for Socket { } async fn getpeername(&self) -> io::Result> { - let guard = VSOCK_MAP.lock(); - let raw = guard.get_socket(self.port).ok_or(Errno::Inval)?; + let mut guard = VSOCK_MAP.lock(); + let raw = self.raw_mut(&mut guard).ok_or(Errno::Inval)?; Ok(Some(Endpoint::Vsock(VsockEndpoint::new( raw.remote_port, @@ -235,9 +276,8 @@ impl ObjectInterface for Socket { async fn accept(&mut self) -> io::Result<(Arc>, Endpoint)> { let port = self.port; - let cid = self.cid; - let endpoint = future::poll_fn(|cx| { + let (conn_key, endpoint) = future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); let raw = guard.get_mut_socket(port).ok_or(Errno::Inval)?; @@ -251,52 +291,63 @@ impl ObjectInterface for Socket { } } VsockState::ReceiveRequest => { - let result = { - const HEADER_SIZE: usize = size_of::(); - let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); - let local_cid = driver_guard.get_cid(); - - driver_guard.send_packet(HEADER_SIZE, |buffer| { - let response = unsafe { &mut *buffer.as_mut_ptr().cast::() }; - - response.src_cid = le64::from_ne(local_cid); - response.dst_cid = le64::from_ne(raw.remote_cid.into()); - response.src_port = le32::from_ne(port); - response.dst_port = le32::from_ne(raw.remote_port); - response.len = le32::from_ne(0); - response.type_ = le16::from_ne(Type::Stream.into()); - if local_cid != u64::from(cid) && cid != u32::MAX { - response.op = le16::from_ne(Op::Rst.into()); - } else { - response.op = le16::from_ne(Op::Response.into()); - } - response.flags = le32::from_ne(0); - response.buf_alloc = le32::from_ne( - crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32, - ); - response.fwd_cnt = le32::from_ne(raw.fwd_cnt); - }); - - raw.state = VsockState::Connected; - - Ok(VsockEndpoint::new(raw.remote_port, raw.remote_cid)) - }; - - Poll::Ready(result) + const HEADER_SIZE: usize = size_of::(); + let mut driver_guard = hardware::get_vsock_driver().unwrap().lock(); + let local_cid = driver_guard.get_cid(); + + driver_guard.send_packet(HEADER_SIZE, |buffer| { + let response = unsafe { &mut *buffer.as_mut_ptr().cast::() }; + + response.src_cid = le64::from_ne(local_cid); + response.dst_cid = le64::from_ne(raw.remote_cid.into()); + response.src_port = le32::from_ne(port); + response.dst_port = le32::from_ne(raw.remote_port); + response.len = le32::from_ne(0); + response.type_ = le16::from_ne(Type::Stream.into()); + response.op = le16::from_ne(Op::Response.into()); + response.flags = le32::from_ne(0); + response.buf_alloc = + le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32); + response.fwd_cnt = le32::from_ne(raw.fwd_cnt); + }); + + let endpoint = VsockEndpoint::new(raw.remote_port, raw.remote_cid); + + // Move the pending connection into the connection map keyed by its + // remote endpoint and reset the listener so it keeps accepting. + let conn_key = guard.establish(port)?; + + Poll::Ready(Ok((conn_key, endpoint))) } _ => Poll::Ready(Err(Errno::Badf)), } }) .await?; + // Return the accepted connection as a DISTINCT Socket addressing the + // established connection. The listener `self` is left untouched, so it + // keeps accepting further connections on the same port. + let conn = Socket { + port, + conn: Some(conn_key), + cid: self.cid, + is_nonblocking: self.is_nonblocking, + }; + Ok(( - Arc::new(async_lock::RwLock::new(NullSocket::new().into())), + Arc::new(async_lock::RwLock::new(conn.into())), Endpoint::Vsock(endpoint), )) } - async fn shutdown(&self, _how: i32) -> io::Result<()> { - Ok(()) + async fn shutdown(&self, how: i32) -> io::Result<()> { + // Validate `how` for parity with the other socket types. This does not + // yet emit an `Op::Shutdown` to the peer, so a remote `read` is not + // woken with EOF by a local shutdown. + match how { + SHUT_RD | SHUT_WR | SHUT_RDWR => Ok(()), + _ => Err(Errno::Inval), + } } async fn status_flags(&self) -> io::Result { @@ -315,10 +366,9 @@ impl ObjectInterface for Socket { } async fn read(&self, buf: &mut [u8]) -> io::Result { - let port = self.port; future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); - let raw = guard.get_mut_socket(port).ok_or(Errno::Inval)?; + let raw = self.raw_mut(&mut guard).ok_or(Errno::Inval)?; match raw.state { VsockState::Connected => { @@ -338,19 +388,29 @@ impl ObjectInterface for Socket { Poll::Ready(Ok(len)) } } - VsockState::Shutdown => { + VsockState::Shutdown | VsockState::Reset => { let len = core::cmp::min(buf.len(), raw.buffer.len()); - if len == 0 { - Poll::Ready(Ok(0)) - } else { + if len != 0 { + // Deliver any data buffered before the peer closed or + // reset the connection. let tmp: Vec<_> = raw.buffer.drain(..len).collect(); buf[..len].copy_from_slice(tmp.as_slice()); Poll::Ready(Ok(len)) + } else if raw.state == VsockState::Reset { + // Abortive close with no remaining data: surface + // ECONNRESET, matching Linux `AF_VSOCK`. + Poll::Ready(Err(Errno::Connreset)) + } else { + // Graceful shutdown, buffer drained: report EOF. + Poll::Ready(Ok(0)) } } - _ => Poll::Ready(Err(Errno::Io)), + // A connection-keyed socket only ever holds Connected, Shutdown, + // or Reset; the remaining states cannot occur here. Treat them + // as EOF defensively. + _ => Poll::Ready(Ok(0)), } }) .await @@ -360,7 +420,7 @@ impl ObjectInterface for Socket { let port = self.port; future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); - let raw = guard.get_mut_socket(port).ok_or(Errno::Inval)?; + let raw = self.raw_mut(&mut guard).ok_or(Errno::Inval)?; let diff = raw.tx_cnt.abs_diff(raw.peer_fwd_cnt); match raw.state { @@ -406,7 +466,12 @@ impl ObjectInterface for Socket { Poll::Ready(Ok(len)) } } - _ => Poll::Ready(Err(Errno::Io)), + // Peer reset the connection: writing fails with ECONNRESET. + VsockState::Reset => Poll::Ready(Err(Errno::Connreset)), + // Peer closed its receive half (graceful shutdown) or the + // connection is otherwise gone: writing fails with EPIPE, + // matching Linux `AF_VSOCK`. + _ => Poll::Ready(Err(Errno::Pipe)), } }) .await @@ -416,6 +481,9 @@ impl ObjectInterface for Socket { impl Drop for Socket { fn drop(&mut self) { let mut guard = VSOCK_MAP.lock(); - guard.remove_socket(self.port); + match self.conn { + Some(key) => guard.remove_connection(key), + None => guard.remove_socket(self.port), + } } } diff --git a/src/syscalls/socket/mod.rs b/src/syscalls/socket/mod.rs index d0090b3e80..f30af5d3bb 100644 --- a/src/syscalls/socket/mod.rs +++ b/src/syscalls/socket/mod.rs @@ -684,7 +684,10 @@ pub unsafe extern "C" fn sys_accept( |v| { block_on(async { v.write().await.accept().await }, None).map_or_else( |e| -i32::from(e), - #[cfg_attr(not(feature = "net"), expect(unused_variables))] + #[cfg_attr( + not(any(feature = "net", feature = "virtio-vsock")), + expect(unused_variables) + )] |(obj, endpoint)| match endpoint { #[cfg(feature = "net")] Endpoint::Ip(endpoint) => { @@ -717,7 +720,7 @@ pub unsafe extern "C" fn sys_accept( } #[cfg(feature = "virtio-vsock")] Endpoint::Vsock(endpoint) => { - let new_fd = insert_object(v.clone()).unwrap(); + let new_fd = insert_object(obj).unwrap(); if !addr.is_null() && !addrlen.is_null() { let addrlen = unsafe { &mut *addrlen }; diff --git a/xtask/src/ci/qemu.rs b/xtask/src/ci/qemu.rs index 0e60815491..1643bca362 100644 --- a/xtask/src/ci/qemu.rs +++ b/xtask/src/ci/qemu.rs @@ -94,7 +94,7 @@ impl Qemu { pub fn run( self, image: &Path, - features: &[String], + _features: &[String], smp: usize, arch: Arch, small: bool, @@ -169,13 +169,7 @@ impl Qemu { "mioudp" => test_mioudp(guest_ip)?, "poll" => test_poll(guest_ip)?, "stdin" => test_stdin(&mut qemu.0)?, - "vsock" => { - let has_client = features - .iter() - .flat_map(|s| s.split(&[' ', ','][..])) - .any(|feature| feature == "client"); - test_vsock(has_client)?; - } + "vsock" => test_vsock()?, _ => {} } @@ -598,28 +592,56 @@ fn test_stdin(child: &mut Child) -> Result<()> { Ok(()) } -fn test_vsock(has_client: bool) -> Result<()> { - let mut stream = if has_client { - let listener = VsockListener::bind_with_cid_port(vsock::VMADDR_CID_ANY, 9975)?; - let (stream, _addr) = listener.accept()?; - stream - } else { - thread::sleep(Duration::from_secs(10)); - VsockStream::connect_with_cid_port(3, 9975)? - }; +fn test_vsock() -> Result<()> { + // The VM first connects out to us; we listen, send some messages, and check + // the VM echoes them back (#880). Then the VM listens and we open + // PING_PONG_CONNECTIONS connections to it, each expecting a "pong" reply to + // our "ping". The VM accepting more than one connection is the regression + // test for hermit-os/kernel#2433. + const ECHO_PORT: u32 = 9975; + const PING_PONG_PORT: u32 = 9976; + const GUEST_CID: u32 = 3; + const PING_PONG_CONNECTIONS: usize = 2; + + // Example 1: we listen, the VM connects, we verify it echoes our messages. + { + let listener = VsockListener::bind_with_cid_port(vsock::VMADDR_CID_ANY, ECHO_PORT)?; + let (mut stream, _addr) = listener.accept()?; + + let messages = ["Hello, there!", "Hello, again!", "Bye-bye!"]; + for message in messages { + writeln!(&mut stream, "{message}")?; + thread::sleep(Duration::from_secs(1)); + } - let messages = ["Hello, there!", "Hello, again!", "Bye-bye!"]; - for message in messages { - writeln!(&mut stream, "{message}")?; - thread::sleep(Duration::from_secs(1)); + const BUF_SIZE: usize = 8 * 1024; + let mut buf = vec![0; BUF_SIZE]; + let n = stream.read(&mut buf)?; + let s = str::from_utf8(&buf[0..n])?; + let received_messages = s.trim().split('\n').collect::>(); + assert_eq!(received_messages, messages); + // Drop `stream`/`listener` here so the VM sees EOF and moves on to + // listening for the ping/pong phase. } - const BUF_SIZE: usize = 8 * 1024; - let mut buf = vec![0; BUF_SIZE]; - let n = stream.read(&mut buf)?; - let s = str::from_utf8(&buf[0..n])?; - let received_messages = s.trim().split('\n').collect::>(); - assert_eq!(received_messages, messages); + // Example 2: the VM listens; we connect, send "ping", expect "pong". + // Give the VM time to start listening on the ping/pong port before we + // connect. + thread::sleep(Duration::from_secs(1)); + + let do_ping_pong = || -> Result<()> { + let mut stream = VsockStream::connect_with_cid_port(GUEST_CID, PING_PONG_PORT)?; + stream.write_all(b"ping")?; + let mut buf = [0u8; 64]; + let n = stream.read(&mut buf)?; + let msg = from_utf8(&buf[..n])?; + ensure!(msg == "pong", "expected 'pong', got {msg:?}"); + Ok(()) + }; + + for _ in 0..PING_PONG_CONNECTIONS { + do_ping_pong()?; + } Ok(()) }