drivers: nsos: handle multiple blocking APIs on single socket

So far only a single blocking API could be handled simultaneously, due to
epoll_ctl(..., EPOLL_CTL_ADD, ...) returning -EEXIST when same file
descriptor was added twice.

Follow 'man epoll' advice about using dup() syscall to create a duplicate
file descriptor, which can be used with different events masks. Use such
duplicate for each blocking API except ioctl() (for handling Zephyr poll()
syscall).

Signed-off-by: Marcin Niestroj <m.niestroj@emb.dev>
This commit is contained in:
Marcin Niestroj 2024-05-20 13:48:25 +02:00 committed by David Leach
parent 59a2d84f45
commit f45d3c81cc
3 changed files with 92 additions and 59 deletions

View file

@ -140,6 +140,8 @@ int nsos_adapt_fcntl_setfl(int fd, int flags);
int nsos_adapt_fionread(int fd, int *avail);
int nsos_adapt_dup(int oldfd);
int nsos_adapt_getaddrinfo(const char *node, const char *service,
const struct nsos_mid_addrinfo *hints,
struct nsos_mid_addrinfo **res,

View file

@ -1001,6 +1001,18 @@ int nsos_adapt_fionread(int fd, int *avail)
return 0;
}
int nsos_adapt_dup(int oldfd)
{
int ret;
ret = dup(oldfd);
if (ret < 0) {
return -errno_to_nsos_mid(errno);
}
return ret;
}
static void nsos_adapt_init(void)
{
nsos_epoll_fd = epoll_create(1);

View file

@ -37,17 +37,24 @@ BUILD_ASSERT(CONFIG_HEAP_MEM_POOL_SIZE > 0);
#define NSOS_IRQ_FLAGS (0)
#define NSOS_IRQ_PRIORITY (2)
struct nsos_socket {
int fd;
struct nsos_mid_pollfd pollfd;
struct k_poll_signal poll;
struct nsos_socket;
k_timeout_t recv_timeout;
struct nsos_socket_poll {
struct nsos_mid_pollfd mid;
struct k_poll_signal signal;
sys_dnode_t node;
};
static sys_dlist_t nsos_sockets = SYS_DLIST_STATIC_INIT(&nsos_sockets);
struct nsos_socket {
int fd;
k_timeout_t recv_timeout;
struct nsos_socket_poll poll;
};
static sys_dlist_t nsos_polls = SYS_DLIST_STATIC_INIT(&nsos_polls);
static int socket_family_to_nsos_mid(int family, int *family_mid)
{
@ -184,9 +191,9 @@ static int nsos_socket_create(int family, int type, int proto)
sock->fd = fd;
sock->recv_timeout = K_FOREVER;
sock->pollfd.fd = nsos_adapt_socket(family_mid, type_mid, proto_mid);
if (sock->pollfd.fd < 0) {
errno = errno_from_nsos_mid(-sock->pollfd.fd);
sock->poll.mid.fd = nsos_adapt_socket(family_mid, type_mid, proto_mid);
if (sock->poll.mid.fd < 0) {
errno = errno_from_nsos_mid(-sock->poll.mid.fd);
goto free_sock;
}
@ -213,7 +220,7 @@ static ssize_t nsos_read(void *obj, void *buf, size_t sz)
struct nsos_socket *sock = obj;
int ret;
ret = nsi_host_read(sock->pollfd.fd, buf, sz);
ret = nsi_host_read(sock->poll.mid.fd, buf, sz);
if (ret < 0) {
errno = nsos_adapt_get_zephyr_errno();
}
@ -226,7 +233,7 @@ static ssize_t nsos_write(void *obj, const void *buf, size_t sz)
struct nsos_socket *sock = obj;
int ret;
ret = nsi_host_write(sock->pollfd.fd, buf, sz);
ret = nsi_host_write(sock->poll.mid.fd, buf, sz);
if (ret < 0) {
errno = nsos_adapt_get_zephyr_errno();
}
@ -239,7 +246,7 @@ static int nsos_close(void *obj)
struct nsos_socket *sock = obj;
int ret;
ret = nsi_host_close(sock->pollfd.fd);
ret = nsi_host_close(sock->poll.mid.fd);
if (ret < 0) {
errno = nsos_adapt_get_zephyr_errno();
}
@ -247,33 +254,34 @@ static int nsos_close(void *obj)
return ret;
}
static void pollcb(struct nsos_mid_pollfd *pollfd)
static void pollcb(struct nsos_mid_pollfd *mid)
{
struct nsos_socket *sock = CONTAINER_OF(pollfd, struct nsos_socket, pollfd);
struct nsos_socket_poll *poll = CONTAINER_OF(mid, struct nsos_socket_poll, mid);
k_poll_signal_raise(&sock->poll, sock->pollfd.revents);
k_poll_signal_raise(&poll->signal, poll->mid.revents);
}
static int nsos_poll_prepare(struct nsos_socket *sock, struct zsock_pollfd *pfd,
struct k_poll_event **pev, struct k_poll_event *pev_end)
struct k_poll_event **pev, struct k_poll_event *pev_end,
struct nsos_socket_poll *poll)
{
unsigned int signaled;
int flags;
sock->pollfd.events = pfd->events;
sock->pollfd.revents = 0;
sock->pollfd.cb = pollcb;
poll->mid.events = pfd->events;
poll->mid.revents = 0;
poll->mid.cb = pollcb;
if (*pev == pev_end) {
return -ENOMEM;
}
k_poll_signal_init(&sock->poll);
k_poll_event_init(*pev, K_POLL_TYPE_SIGNAL, K_POLL_MODE_NOTIFY_ONLY, &sock->poll);
k_poll_signal_init(&poll->signal);
k_poll_event_init(*pev, K_POLL_TYPE_SIGNAL, K_POLL_MODE_NOTIFY_ONLY, &poll->signal);
sys_dlist_append(&nsos_sockets, &sock->node);
sys_dlist_append(&nsos_polls, &poll->node);
nsos_adapt_poll_add(&sock->pollfd);
nsos_adapt_poll_add(&poll->mid);
/* Let other sockets use another k_poll_event */
(*pev)++;
@ -281,7 +289,7 @@ static int nsos_poll_prepare(struct nsos_socket *sock, struct zsock_pollfd *pfd,
signaled = 0;
flags = 0;
k_poll_signal_check(&sock->poll, &signaled, &flags);
k_poll_signal_check(&poll->signal, &signaled, &flags);
if (!signaled) {
return 0;
}
@ -291,7 +299,7 @@ static int nsos_poll_prepare(struct nsos_socket *sock, struct zsock_pollfd *pfd,
}
static int nsos_poll_update(struct nsos_socket *sock, struct zsock_pollfd *pfd,
struct k_poll_event **pev)
struct k_poll_event **pev, struct nsos_socket_poll *poll)
{
unsigned int signaled;
int flags;
@ -301,15 +309,15 @@ static int nsos_poll_update(struct nsos_socket *sock, struct zsock_pollfd *pfd,
signaled = 0;
flags = 0;
if (!sys_dnode_is_linked(&sock->node)) {
nsos_adapt_poll_update(&sock->pollfd);
if (!sys_dnode_is_linked(&poll->node)) {
nsos_adapt_poll_update(&poll->mid);
return 0;
}
nsos_adapt_poll_remove(&sock->pollfd);
sys_dlist_remove(&sock->node);
nsos_adapt_poll_remove(&poll->mid);
sys_dlist_remove(&poll->node);
k_poll_signal_check(&sock->poll, &signaled, &flags);
k_poll_signal_check(&poll->signal, &signaled, &flags);
if (!signaled) {
return 0;
}
@ -333,7 +341,7 @@ static int nsos_ioctl(void *obj, unsigned int request, va_list args)
pev = va_arg(args, struct k_poll_event **);
pev_end = va_arg(args, struct k_poll_event *);
return nsos_poll_prepare(obj, pfd, pev, pev_end);
return nsos_poll_prepare(obj, pfd, pev, pev_end, &sock->poll);
}
case ZFD_IOCTL_POLL_UPDATE: {
@ -343,7 +351,7 @@ static int nsos_ioctl(void *obj, unsigned int request, va_list args)
pfd = va_arg(args, struct zsock_pollfd *);
pev = va_arg(args, struct k_poll_event **);
return nsos_poll_update(obj, pfd, pev);
return nsos_poll_update(obj, pfd, pev, &sock->poll);
}
case ZFD_IOCTL_POLL_OFFLOAD:
@ -352,7 +360,7 @@ static int nsos_ioctl(void *obj, unsigned int request, va_list args)
case F_GETFL: {
int flags;
flags = nsos_adapt_fcntl_getfl(sock->pollfd.fd);
flags = nsos_adapt_fcntl_getfl(sock->poll.mid.fd);
return fl_from_nsos_mid(flags);
}
@ -366,7 +374,7 @@ static int nsos_ioctl(void *obj, unsigned int request, va_list args)
return -errno_from_nsos_mid(-ret);
}
ret = nsos_adapt_fcntl_setfl(sock->pollfd.fd, flags);
ret = nsos_adapt_fcntl_setfl(sock->poll.mid.fd, flags);
return -errno_from_nsos_mid(-ret);
}
@ -375,7 +383,7 @@ static int nsos_ioctl(void *obj, unsigned int request, va_list args)
int *avail = va_arg(args, int *);
int ret;
ret = nsos_adapt_fionread(sock->pollfd.fd, avail);
ret = nsos_adapt_fionread(sock->poll.mid.fd, avail);
return -errno_from_nsos_mid(-ret);
}
@ -491,14 +499,22 @@ static int nsos_wait_for_poll(struct nsos_socket *sock, int events,
struct k_poll_event poll_events[1];
struct k_poll_event *pev = poll_events;
struct k_poll_event *pev_end = poll_events + ARRAY_SIZE(poll_events);
struct nsos_socket_poll socket_poll = {};
int ret;
ret = nsos_poll_prepare(sock, &pfd, &pev, pev_end);
ret = nsos_adapt_dup(sock->poll.mid.fd);
if (ret < 0) {
goto return_ret;
}
socket_poll.mid.fd = ret;
ret = nsos_poll_prepare(sock, &pfd, &pev, pev_end, &socket_poll);
if (ret == -EALREADY) {
ret = 0;
goto poll_update;
} else if (ret < 0) {
goto return_ret;
goto close_dup;
}
ret = k_poll(poll_events, ARRAY_SIZE(poll_events), timeout);
@ -510,7 +526,10 @@ static int nsos_wait_for_poll(struct nsos_socket *sock, int events,
poll_update:
pev = poll_events;
nsos_poll_update(sock, &pfd, &pev);
nsos_poll_update(sock, &pfd, &pev, &socket_poll);
close_dup:
nsi_host_close(socket_poll.mid.fd);
return_ret:
if (ret < 0) {
@ -529,7 +548,7 @@ static int nsos_poll_if_blocking(struct nsos_socket *sock, int events,
if (flags & ZSOCK_MSG_DONTWAIT) {
non_blocking = true;
} else {
sock_flags = nsos_adapt_fcntl_getfl(sock->pollfd.fd);
sock_flags = nsos_adapt_fcntl_getfl(sock->poll.mid.fd);
non_blocking = sock_flags & NSOS_MID_O_NONBLOCK;
}
@ -553,7 +572,7 @@ static int nsos_bind(void *obj, const struct sockaddr *addr, socklen_t addrlen)
goto return_ret;
}
ret = nsos_adapt_bind(sock->pollfd.fd, addr_mid, addrlen_mid);
ret = nsos_adapt_bind(sock->poll.mid.fd, addr_mid, addrlen_mid);
return_ret:
if (ret < 0) {
@ -577,7 +596,7 @@ static int nsos_connect(void *obj, const struct sockaddr *addr, socklen_t addrle
goto return_ret;
}
ret = nsos_adapt_connect(sock->pollfd.fd, addr_mid, addrlen_mid);
ret = nsos_adapt_connect(sock->poll.mid.fd, addr_mid, addrlen_mid);
return_ret:
if (ret < 0) {
@ -593,7 +612,7 @@ static int nsos_listen(void *obj, int backlog)
struct nsos_socket *sock = obj;
int ret;
ret = nsos_adapt_listen(sock->pollfd.fd, backlog);
ret = nsos_adapt_listen(sock->poll.mid.fd, backlog);
if (ret < 0) {
errno = errno_from_nsos_mid(-ret);
return -1;
@ -619,7 +638,7 @@ static int nsos_accept(void *obj, struct sockaddr *addr, socklen_t *addrlen)
goto return_ret;
}
ret = nsos_adapt_accept(accept_sock->pollfd.fd, addr_mid, &addrlen_mid);
ret = nsos_adapt_accept(accept_sock->poll.mid.fd, addr_mid, &addrlen_mid);
if (ret < 0) {
goto return_ret;
}
@ -644,7 +663,7 @@ static int nsos_accept(void *obj, struct sockaddr *addr, socklen_t *addrlen)
}
conn_sock->fd = zephyr_fd;
conn_sock->pollfd.fd = adapt_fd;
conn_sock->poll.mid.fd = adapt_fd;
z_finalize_fd(zephyr_fd, conn_sock, &nsos_socket_fd_op_vtable.fd_vtable);
@ -683,7 +702,7 @@ static ssize_t nsos_sendto(void *obj, const void *buf, size_t len, int flags,
goto return_ret;
}
ret = nsos_adapt_sendto(sock->pollfd.fd, buf, len, flags_mid,
ret = nsos_adapt_sendto(sock->poll.mid.fd, buf, len, flags_mid,
addr_mid, addrlen_mid);
return_ret:
@ -737,7 +756,7 @@ static ssize_t nsos_sendmsg(void *obj, const struct msghdr *msg, int flags)
msg_mid.msg_controllen = 0;
msg_mid.msg_flags = 0;
ret = nsos_adapt_sendmsg(sock->pollfd.fd, &msg_mid, flags_mid);
ret = nsos_adapt_sendmsg(sock->poll.mid.fd, &msg_mid, flags_mid);
k_free(msg_iov);
@ -772,7 +791,7 @@ static ssize_t nsos_recvfrom(void *obj, void *buf, size_t len, int flags,
goto return_ret;
}
ret = nsos_adapt_recvfrom(sock->pollfd.fd, buf, len, flags_mid,
ret = nsos_adapt_recvfrom(sock->poll.mid.fd, buf, len, flags_mid,
addr_mid, &addrlen_mid);
if (ret < 0) {
goto return_ret;
@ -878,7 +897,7 @@ static int nsos_getsockopt_int(struct nsos_socket *sock, int nsos_mid_level, int
return -1;
}
err = nsos_adapt_getsockopt(sock->pollfd.fd, NSOS_MID_SOL_SOCKET,
err = nsos_adapt_getsockopt(sock->poll.mid.fd, NSOS_MID_SOL_SOCKET,
NSOS_MID_SO_KEEPALIVE, optval, &nsos_mid_optlen);
if (err) {
errno = errno_from_nsos_mid(-err);
@ -907,7 +926,7 @@ static int nsos_getsockopt(void *obj, int level, int optname,
return -1;
}
err = nsos_adapt_getsockopt(sock->pollfd.fd, NSOS_MID_SOL_SOCKET,
err = nsos_adapt_getsockopt(sock->poll.mid.fd, NSOS_MID_SOL_SOCKET,
NSOS_MID_SO_ERROR, &nsos_mid_err, NULL);
if (err) {
errno = errno_from_nsos_mid(-err);
@ -927,7 +946,7 @@ static int nsos_getsockopt(void *obj, int level, int optname,
return -1;
}
err = nsos_adapt_getsockopt(sock->pollfd.fd, NSOS_MID_SOL_SOCKET,
err = nsos_adapt_getsockopt(sock->poll.mid.fd, NSOS_MID_SOL_SOCKET,
NSOS_MID_SO_TYPE, &nsos_mid_type, NULL);
if (err) {
errno = errno_from_nsos_mid(-err);
@ -951,7 +970,7 @@ static int nsos_getsockopt(void *obj, int level, int optname,
return -1;
}
err = nsos_adapt_getsockopt(sock->pollfd.fd, NSOS_MID_SOL_SOCKET,
err = nsos_adapt_getsockopt(sock->poll.mid.fd, NSOS_MID_SOL_SOCKET,
NSOS_MID_SO_PROTOCOL, &nsos_mid_proto, NULL);
if (err) {
errno = errno_from_nsos_mid(-err);
@ -975,7 +994,7 @@ static int nsos_getsockopt(void *obj, int level, int optname,
return -1;
}
err = nsos_adapt_getsockopt(sock->pollfd.fd, NSOS_MID_SOL_SOCKET,
err = nsos_adapt_getsockopt(sock->poll.mid.fd, NSOS_MID_SOL_SOCKET,
NSOS_MID_SO_DOMAIN, &nsos_mid_family, NULL);
if (err) {
errno = errno_from_nsos_mid(-err);
@ -1055,7 +1074,7 @@ static int nsos_setsockopt_int(struct nsos_socket *sock, int nsos_mid_level, int
return -1;
}
err = nsos_adapt_setsockopt(sock->pollfd.fd, nsos_mid_level, nsos_mid_optname,
err = nsos_adapt_setsockopt(sock->poll.mid.fd, nsos_mid_level, nsos_mid_optname,
optval, optlen);
if (err) {
errno = errno_from_nsos_mid(-err);
@ -1084,7 +1103,7 @@ static int nsos_setsockopt(void *obj, int level, int optname,
nsos_mid_priority = *(uint8_t *)optval;
err = nsos_adapt_setsockopt(sock->pollfd.fd, NSOS_MID_SOL_SOCKET,
err = nsos_adapt_setsockopt(sock->poll.mid.fd, NSOS_MID_SOL_SOCKET,
NSOS_MID_SO_PRIORITY, &nsos_mid_priority,
sizeof(nsos_mid_priority));
if (err) {
@ -1107,7 +1126,7 @@ static int nsos_setsockopt(void *obj, int level, int optname,
nsos_mid_tv.tv_sec = tv->tv_sec;
nsos_mid_tv.tv_usec = tv->tv_usec;
err = nsos_adapt_setsockopt(sock->pollfd.fd, NSOS_MID_SOL_SOCKET,
err = nsos_adapt_setsockopt(sock->poll.mid.fd, NSOS_MID_SOL_SOCKET,
NSOS_MID_SO_RCVTIMEO, &nsos_mid_tv,
sizeof(nsos_mid_tv));
if (err) {
@ -1359,11 +1378,11 @@ static const struct socket_dns_offload nsos_dns_ops = {
static void nsos_isr(const void *obj)
{
struct nsos_socket *sock;
struct nsos_socket_poll *poll;
SYS_DLIST_FOR_EACH_CONTAINER(&nsos_sockets, sock, node) {
if (sock->pollfd.revents) {
sock->pollfd.cb(&sock->pollfd);
SYS_DLIST_FOR_EACH_CONTAINER(&nsos_polls, poll, node) {
if (poll->mid.revents) {
poll->mid.cb(&poll->mid);
}
}
}