Skip to content

Commit 80aa06e

Browse files
authored
Windows: fix multi-thread safety (BoboTiG#159)
* Windows: fix multi-thread unsafe and update test On Windows, the handle of entire window device context is saved to `srcdc`. But the device context will be released once the thread who creates it has died, so that `srcdc` is no loner valid. Replace `srcdc` with `srcdc_dict` to maintain srcdc values created by multiple threads which ensure the validity of srcdc when it's used. A threading lock is add to prevent multiple threads from grabbing and modifying shared class attributes `bmp`/`srcdc`/`memdc` (their windows object in fact) at same time. Otherwise, unexpected screenshot or unpredictable error will occur. Add test_thread_safety in test_windows.py
1 parent a8cb2d2 commit 80aa06e

File tree

5 files changed

+60
-6
lines changed

5 files changed

+60
-6
lines changed

CHANGELOG

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ History:
77
- MSS: renamed again MSSMixin to MSSBase, now derived from abc.ABCMeta
88
- tools: force write of file when saving a PNG file
99
- tests: fix tests on macOS with Retina display
10+
- Windows: fixed multi-thread safety (fixes #150)
11+
- :heart: contributors: @narumishi
1012

1113
5.0.0 2019/12/31
1214
- removed support for Python 2.7

CHANGES.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ base.py
88
- ``MSSBase.monitor`` is now an abstract property
99
- ``MSSBase.grab()`` is now an abstract method
1010

11+
windows.py
12+
----------
13+
- Replaced ``MSS.srcdc`` with ``MSS.srcdc_dict``
14+
1115

1216
5.0.0 (2019-12-31)
1317
==================

CONTRIBUTORS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ Jochen 'cycomanic' Schroeder [https://github.com/cycomanic]
3030
Karan Lyons <[email protected]> [https://karanlyons.com] [https://github.com/karanlyons]
3131
- MacOS: Proper support for display scaling
3232

33+
narumi [https://github.com/narumishi]
34+
- Windows: fix multi-thread unsafe
35+
3336
Oros <[email protected]> [https://ecirtam.net]
3437
- GNU/Linux tester
3538

mss/tests/test_windows.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import platform
7+
import threading
78

89
import mss
910
import pytest
@@ -48,3 +49,23 @@ def test_region_caching():
4849
# Grab the area 2 again, the cached BMP is used
4950
sct.grab(region2)
5051
assert bmp2 is MSS.bmp
52+
53+
54+
def run_child_thread(loops):
55+
"""Every loop will take about 1 second."""
56+
for _ in range(loops):
57+
with mss.mss() as sct:
58+
sct.grab(sct.monitors[1])
59+
60+
61+
def test_thread_safety():
62+
"""Thread safety test for issue #150.
63+
The following code will throw a ScreenShotError exception if thread-safety is not guaranted.
64+
"""
65+
# Let thread 1 finished ahead of thread 2
66+
thread1 = threading.Thread(target=run_child_thread, args=(30,))
67+
thread2 = threading.Thread(target=run_child_thread, args=(50,))
68+
thread1.start()
69+
thread2.start()
70+
thread1.join()
71+
thread2.join()

mss/windows.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import sys
77
import ctypes
8+
import threading
89
from ctypes.wintypes import (
910
BOOL,
1011
DOUBLE,
@@ -70,10 +71,15 @@ class MSS(MSSBase):
7071

7172
__slots__ = {"_bbox", "_bmi", "_data", "gdi32", "monitorenumproc", "user32"}
7273

73-
# Class attributes instancied one time to prevent resource leaks.
74+
# Class attributes instanced one time to prevent resource leaks.
7475
bmp = None
7576
memdc = None
76-
srcdc = None
77+
78+
# A dict to maintain *srcdc* values created by multiple threads.
79+
srcdc_dict = {}
80+
81+
# A threading lock to lock resources(bmp/memdc/srcdc) inside .grab() method.
82+
_lock = threading.Lock()
7783

7884
def __init__(self, **_):
7985
# type: (Any) -> None
@@ -93,9 +99,9 @@ def __init__(self, **_):
9399
self._bbox = {"height": 0, "width": 0}
94100
self._data = ctypes.create_string_buffer(0) # type: ctypes.Array[ctypes.c_char]
95101

96-
if not MSS.srcdc or not MSS.memdc:
97-
MSS.srcdc = self.user32.GetWindowDC(0)
98-
MSS.memdc = self.gdi32.CreateCompatibleDC(MSS.srcdc)
102+
srcdc = self._get_srcdc()
103+
if not MSS.memdc:
104+
MSS.memdc = self.gdi32.CreateCompatibleDC(srcdc)
99105

100106
bmi = BITMAPINFO()
101107
bmi.bmiHeader.biSize = ctypes.sizeof(BITMAPINFOHEADER)
@@ -174,6 +180,20 @@ def _set_dpi_awareness(self):
174180
# Windows Vista, 7, 8 and Server 2012
175181
self.user32.SetProcessDPIAware()
176182

183+
def _get_srcdc(self):
184+
"""
185+
Retrieve a thread-safe HDC from GetWindowDC().
186+
In multithreading, if the thread who creates *srcdc* is dead, *srcdc* will
187+
no longer be valid to grab the screen. The *srcdc* attribute is replaced
188+
with *srcdc_dict* to maintain the *srcdc* values in multithreading.
189+
Since the current thread and main thread are always alive, reuse their *srcdc* value first.
190+
"""
191+
cur_thread, main_thread = threading.current_thread(), threading.main_thread()
192+
srcdc = MSS.srcdc_dict.get(cur_thread) or MSS.srcdc_dict.get(main_thread)
193+
if not srcdc:
194+
srcdc = MSS.srcdc_dict[cur_thread] = self.user32.GetWindowDC(0)
195+
return srcdc
196+
177197
@property
178198
def monitors(self):
179199
# type: () -> Monitors
@@ -251,6 +271,9 @@ def grab(self, monitor):
251271
Thanks to http://stackoverflow.com/a/3688682
252272
"""
253273

274+
# Acquire lock to prevent resources from being modified by multiple threads at same time.
275+
MSS._lock.acquire()
276+
254277
# Convert PIL bbox style
255278
if isinstance(monitor, tuple):
256279
monitor = {
@@ -260,7 +283,7 @@ def grab(self, monitor):
260283
"height": monitor[3] - monitor[1],
261284
}
262285

263-
srcdc, memdc = MSS.srcdc, MSS.memdc
286+
srcdc, memdc = self._get_srcdc(), MSS.memdc
264287
width, height = monitor["width"], monitor["height"]
265288

266289
if (self._bbox["height"], self._bbox["width"]) != (height, width):
@@ -287,6 +310,7 @@ def grab(self, monitor):
287310
bits = self.gdi32.GetDIBits(
288311
memdc, MSS.bmp, 0, height, self._data, self._bmi, DIB_RGB_COLORS
289312
)
313+
MSS._lock.release()
290314
if bits != height:
291315
raise ScreenShotError("gdi32.GetDIBits() failed.")
292316

0 commit comments

Comments
 (0)