Skip to content

gh-116738: Make grp module thread-safe #135434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions Lib/test/support/threading_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,27 @@ def requires_working_threading(*, module=False):
raise unittest.SkipTest(msg)
else:
return unittest.skipUnless(can_start_thread, msg)


def run_concurrently(worker_func, args, nthreads):
"""
Run the worker function concurrently in multiple threads.
"""
barrier = threading.Barrier(nthreads)

def wrapper_func(*args):
# Wait for all threads to reach this point before proceeding.
barrier.wait()
worker_func(*args)

with catch_threading_exception() as cm:
workers = (
threading.Thread(target=wrapper_func, args=args)
for _ in range(nthreads)
)
with start_threads(workers):
pass

# If a worker thread raises an exception, re-raise it.
if cm.exc_value is not None:
raise cm.exc_value
36 changes: 36 additions & 0 deletions Lib/test/test_free_threading/test_grp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

from test.support import import_helper, threading_helper
from test.support.threading_helper import run_concurrently

grp = import_helper.import_module("grp")

from test import test_grp


NTHREADS = 10


@threading_helper.requires_working_threading()
class TestGrp(unittest.TestCase):
def setUp(self):
self.test_grp = test_grp.GroupDatabaseTestCase()

def test_racing_test_values(self):
# test_grp.test_values() calls grp.getgrall() and checks the entries
run_concurrently(
worker_func=self.test_grp.test_values, args=(), nthreads=NTHREADS
)

def test_racing_test_values_extended(self):
# test_grp.test_values_extended() calls grp.getgrall(), grp.getgrgid(),
# grp.getgrnam() and checks the entries
run_concurrently(
worker_func=self.test_grp.test_values_extended,
args=(),
nthreads=NTHREADS,
)


if __name__ == "__main__":
unittest.main()
43 changes: 11 additions & 32 deletions Lib/test/test_free_threading/test_heapq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import heapq

from enum import Enum
from threading import Thread, Barrier
from random import shuffle, randint

from test.support import threading_helper
from test.support.threading_helper import run_concurrently
from test import test_heapq


Expand All @@ -28,7 +28,7 @@ def test_racing_heapify(self):
heap = list(range(OBJECT_COUNT))
shuffle(heap)

self.run_concurrently(
run_concurrently(
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
)
self.test_heapq.check_invariant(heap)
Expand All @@ -40,7 +40,7 @@ def heappush_func(heap):
for item in reversed(range(OBJECT_COUNT)):
heapq.heappush(heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
)
self.test_heapq.check_invariant(heap)
Expand All @@ -61,7 +61,7 @@ def heappop_func(heap, pop_count):
# Each local list should be sorted
self.assertTrue(self.is_sorted_ascending(local_list))

self.run_concurrently(
run_concurrently(
worker_func=heappop_func,
args=(heap, per_thread_pop_count),
nthreads=NTHREADS,
Expand All @@ -77,7 +77,7 @@ def heappushpop_func(heap, pushpop_items):
popped_item = heapq.heappushpop(heap, item)
self.assertTrue(popped_item <= item)

self.run_concurrently(
run_concurrently(
worker_func=heappushpop_func,
args=(heap, pushpop_items),
nthreads=NTHREADS,
Expand All @@ -93,7 +93,7 @@ def heapreplace_func(heap, replace_items):
for item in replace_items:
heapq.heapreplace(heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heapreplace_func,
args=(heap, replace_items),
nthreads=NTHREADS,
Expand All @@ -105,7 +105,7 @@ def test_racing_heapify_max(self):
max_heap = list(range(OBJECT_COUNT))
shuffle(max_heap)

self.run_concurrently(
run_concurrently(
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
)
self.test_heapq.check_max_invariant(max_heap)
Expand All @@ -117,7 +117,7 @@ def heappush_max_func(max_heap):
for item in range(OBJECT_COUNT):
heapq.heappush_max(max_heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
)
self.test_heapq.check_max_invariant(max_heap)
Expand All @@ -138,7 +138,7 @@ def heappop_max_func(max_heap, pop_count):
# Each local list should be sorted
self.assertTrue(self.is_sorted_descending(local_list))

self.run_concurrently(
run_concurrently(
worker_func=heappop_max_func,
args=(max_heap, per_thread_pop_count),
nthreads=NTHREADS,
Expand All @@ -154,7 +154,7 @@ def heappushpop_max_func(max_heap, pushpop_items):
popped_item = heapq.heappushpop_max(max_heap, item)
self.assertTrue(popped_item >= item)

self.run_concurrently(
run_concurrently(
worker_func=heappushpop_max_func,
args=(max_heap, pushpop_items),
nthreads=NTHREADS,
Expand All @@ -170,7 +170,7 @@ def heapreplace_max_func(max_heap, replace_items):
for item in replace_items:
heapq.heapreplace_max(max_heap, item)

self.run_concurrently(
run_concurrently(
worker_func=heapreplace_max_func,
args=(max_heap, replace_items),
nthreads=NTHREADS,
Expand Down Expand Up @@ -214,27 +214,6 @@ def create_random_list(a, b, size):
"""
return [randint(-a, b) for _ in range(size)]

def run_concurrently(self, worker_func, args, nthreads):
"""
Run the worker function concurrently in multiple threads.
"""
barrier = Barrier(nthreads)

def wrapper_func(*args):
# Wait for all threads to reach this point before proceeding.
barrier.wait()
worker_func(*args)

with threading_helper.catch_threading_exception() as cm:
workers = (
Thread(target=wrapper_func, args=args) for _ in range(nthreads)
)
with threading_helper.start_threads(workers):
pass

# Worker threads should not raise any exceptions
self.assertIsNone(cm.exc_value)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make methods in :mod:`grp` thread-safe on the :term:`free threaded <free threading>` build.
19 changes: 19 additions & 0 deletions Modules/grpmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,16 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)

Py_END_ALLOW_THREADS
#else
static PyMutex getgrgid_mutex = {0};
PyMutex_Lock(&getgrgid_mutex);
// The getgrgid() function need not be thread-safe.
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrgid.html
p = getgrgid(gid);
#endif
if (p == NULL) {
#ifndef HAVE_GETGRGID_R
PyMutex_Unlock(&getgrgid_mutex);
#endif
PyMem_RawFree(buf);
if (nomem == 1) {
return PyErr_NoMemory();
Expand All @@ -185,6 +192,8 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
retval = mkgrent(module, p);
#ifdef HAVE_GETGRGID_R
PyMem_RawFree(buf);
#else
PyMutex_Unlock(&getgrgid_mutex);
#endif
return retval;
}
Expand Down Expand Up @@ -249,9 +258,16 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)

Py_END_ALLOW_THREADS
#else
static PyMutex getgrnam_mutex = {0};
PyMutex_Lock(&getgrnam_mutex);
// The getgrnam() function need not be thread-safe.
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrnam.html
p = getgrnam(name_chars);
#endif
if (p == NULL) {
#ifndef HAVE_GETGRNAM_R
PyMutex_Unlock(&getgrnam_mutex);
#endif
if (nomem == 1) {
PyErr_NoMemory();
}
Expand All @@ -261,6 +277,9 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
goto out;
}
retval = mkgrent(module, p);
#ifndef HAVE_GETGRNAM_R
PyMutex_Unlock(&getgrnam_mutex);
#endif
out:
PyMem_RawFree(buf);
Py_DECREF(bytes);
Expand Down
Loading