net: socket: socketpair: mitigate possible race condition

There was a possible race condition between sock_is_nonblock()
and k_sem_take() in spair_read() and spair_write() that was
mitigated.

Also clarified some of the conditional branching in those
functions.

Signed-off-by: Christopher Friedt <chrisfriedt@gmail.com>
This commit is contained in:
Christopher Friedt 2020-05-19 16:36:20 -04:00 committed by Carles Cufí
parent 2593a919ee
commit 6161ea2542

View file

@ -369,13 +369,13 @@ out:
static ssize_t spair_write(void *obj, const void *buffer, size_t count)
{
int res;
bool is_connected;
int key;
size_t avail;
bool is_nonblock;
bool will_block;
size_t bytes_written;
bool have_local_sem = false;
bool have_remote_sem = false;
bool will_block = false;
struct spair *const spair = (struct spair *)obj;
struct spair *remote = NULL;
@ -385,9 +385,10 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count)
goto out;
}
key = irq_lock();
is_nonblock = sock_is_nonblock(spair);
res = k_sem_take(&spair->sem, K_NO_WAIT);
irq_unlock(key);
if (res < 0) {
if (is_nonblock) {
errno = EAGAIN;
@ -401,6 +402,7 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count)
res = -1;
goto out;
}
is_nonblock = sock_is_nonblock(spair);
}
have_local_sem = true;
@ -408,10 +410,7 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count)
remote = z_get_fd_obj(spair->remote,
(const struct fd_op_vtable *)&spair_fd_op_vtable, 0);
is_connected = sock_is_connected(spair);
is_nonblock = sock_is_nonblock(spair);
if (!is_connected) {
if (remote == NULL) {
errno = EPIPE;
res = -1;
goto out;
@ -434,14 +433,17 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count)
have_remote_sem = true;
avail = is_connected ? spair_write_avail(spair) : 0;
if (avail == 0 && is_nonblock) {
errno = EAGAIN;
res = -1;
goto out;
avail = spair_write_avail(spair);
if (avail == 0) {
if (is_nonblock) {
errno = EAGAIN;
res = -1;
goto out;
}
will_block = true;
}
will_block = (count > avail) && !is_nonblock;
if (will_block) {
for (int signaled = false, result = -1; !signaled;
@ -464,6 +466,16 @@ static ssize_t spair_write(void *obj, const void *buffer, size_t count)
goto out;
}
remote = z_get_fd_obj(spair->remote,
(const struct fd_op_vtable *)
&spair_fd_op_vtable, 0);
if (remote == NULL) {
errno = EPIPE;
res = -1;
goto out;
}
res = k_sem_take(&remote->sem, K_NO_WAIT);
if (res < 0) {
if (is_nonblock) {
@ -569,14 +581,13 @@ out:
static ssize_t spair_read(void *obj, void *buffer, size_t count)
{
int res;
int key;
bool is_connected;
size_t avail;
bool is_nonblock;
bool will_block;
size_t bytes_read;
bool have_local_sem = false;
bool will_block = false;
struct spair *const spair = (struct spair *)obj;
if (obj == NULL || buffer == NULL || count == 0) {
@ -585,9 +596,10 @@ static ssize_t spair_read(void *obj, void *buffer, size_t count)
goto out;
}
key = irq_lock();
is_nonblock = sock_is_nonblock(spair);
res = k_sem_take(&spair->sem, K_NO_WAIT);
irq_unlock(key);
if (res < 0) {
if (is_nonblock) {
errno = EAGAIN;
@ -601,24 +613,28 @@ static ssize_t spair_read(void *obj, void *buffer, size_t count)
res = -1;
goto out;
}
is_nonblock = sock_is_nonblock(spair);
}
have_local_sem = true;
is_connected = sock_is_connected(spair);
avail = spair_read_avail(spair);
will_block = (avail == 0) && !is_nonblock;
if (avail == 0 && !is_connected) {
/* signal EOF */
res = 0;
goto out;
}
if (avail == 0) {
if (!is_connected) {
/* signal EOF */
res = 0;
goto out;
}
if (avail == 0 && is_nonblock) {
errno = EAGAIN;
res = -1;
goto out;
if (is_nonblock) {
errno = EAGAIN;
res = -1;
goto out;
}
will_block = true;
}
if (will_block) {