Skip to content

Commit 0165830

Browse files
authored
Merge pull request matplotlib#17358 from QuLogic/masked-CubicTriInterpolator
Fix masked CubicTriInterpolator
2 parents eb27af0 + e737d0d commit 0165830

File tree

3 files changed

+40
-27
lines changed

3 files changed

+40
-27
lines changed

lib/matplotlib/tests/test_triangulation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,24 @@ def test_trirefine():
935935
assert_array_almost_equal(xyz_data[0], xyz_data[1])
936936

937937

938+
@pytest.mark.parametrize('interpolator',
939+
[mtri.LinearTriInterpolator,
940+
mtri.CubicTriInterpolator],
941+
ids=['linear', 'cubic'])
942+
def test_trirefine_masked(interpolator):
943+
# Repeated points means we will have fewer triangles than points, and thus
944+
# get masking.
945+
x, y = np.mgrid[:2, :2]
946+
x = np.repeat(x.flatten(), 2)
947+
y = np.repeat(y.flatten(), 2)
948+
949+
z = np.zeros_like(x)
950+
tri = mtri.Triangulation(x, y)
951+
refiner = mtri.UniformTriRefiner(tri)
952+
interp = interpolator(tri, z)
953+
refiner.refine_field(z, triinterpolator=interp, subdiv=2)
954+
955+
938956
def meshgrid_triangles(n):
939957
"""
940958
Return (2*(N-1)**2, 3) array of triangles to mesh (N, N)-point np.meshgrid.

lib/matplotlib/tri/triinterpolate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,9 +400,8 @@ def __init__(self, triangulation, z, kind='min_E', trifinder=None,
400400
self._triangles = compressed_triangles
401401
self._tri_renum = tri_renum
402402
# Taking into account the node renumbering in self._z:
403-
node_mask = (node_renum == -1)
404-
self._z[node_renum[~node_mask]] = self._z
405-
self._z = self._z[~node_mask]
403+
valid_node = (node_renum != -1)
404+
self._z[node_renum[valid_node]] = self._z[valid_node]
406405

407406
# Computing scale factors
408407
self._unit_x = np.ptp(compressed_x)

lib/matplotlib/tri/tritools.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,17 @@ def _get_compressed_triangulation(self):
220220
tri_mask = self._triangulation.mask
221221
compressed_triangles = self._triangulation.get_masked_triangles()
222222
ntri = self._triangulation.triangles.shape[0]
223-
tri_renum = self._total_to_compress_renum(tri_mask, ntri)
223+
if tri_mask is not None:
224+
tri_renum = self._total_to_compress_renum(~tri_mask)
225+
else:
226+
tri_renum = np.arange(ntri, dtype=np.int32)
224227

225228
# Valid nodes and renumbering
226-
node_mask = (np.bincount(np.ravel(compressed_triangles),
227-
minlength=self._triangulation.x.size) == 0)
228-
compressed_x = self._triangulation.x[~node_mask]
229-
compressed_y = self._triangulation.y[~node_mask]
230-
node_renum = self._total_to_compress_renum(node_mask)
229+
valid_node = (np.bincount(np.ravel(compressed_triangles),
230+
minlength=self._triangulation.x.size) != 0)
231+
compressed_x = self._triangulation.x[valid_node]
232+
compressed_y = self._triangulation.y[valid_node]
233+
node_renum = self._total_to_compress_renum(valid_node)
231234

232235
# Now renumbering the valid triangles nodes
233236
compressed_triangles = node_renum[compressed_triangles]
@@ -236,32 +239,25 @@ def _get_compressed_triangulation(self):
236239
node_renum)
237240

238241
@staticmethod
239-
def _total_to_compress_renum(mask, n=None):
242+
def _total_to_compress_renum(valid):
240243
"""
241244
Parameters
242245
----------
243-
mask : 1d bool array or None
244-
mask
245-
n : int
246-
length of the mask. Useful only id mask can be None
246+
valid : 1d bool array
247+
Validity mask.
247248
248249
Returns
249250
-------
250251
int array
251-
array so that (`valid_array` being a compressed array
252-
based on a `masked_array` with mask *mask*):
252+
Array so that (`valid_array` being a compressed array
253+
based on a `masked_array` with mask ~*valid*):
253254
254-
- For all i such as mask[i] = False:
255+
- For all i with valid[i] = True:
255256
valid_array[renum[i]] = masked_array[i]
256-
- For all i such as mask[i] = True:
257+
- For all i with valid[i] = False:
257258
renum[i] = -1 (invalid value)
258259
"""
259-
if n is None:
260-
n = np.size(mask)
261-
if mask is not None:
262-
renum = np.full(n, -1, dtype=np.int32) # Default num is -1
263-
valid = np.arange(n, dtype=np.int32)[~mask]
264-
renum[valid] = np.arange(np.size(valid, 0), dtype=np.int32)
265-
return renum
266-
else:
267-
return np.arange(n, dtype=np.int32)
260+
renum = np.full(np.size(valid), -1, dtype=np.int32)
261+
n_valid = np.sum(valid)
262+
renum[valid] = np.arange(n_valid, dtype=np.int32)
263+
return renum

0 commit comments

Comments
 (0)