Skip to content

Commit c5edd0f

Browse files
committed
Drop values when copying collections where we only care about keys
1 parent ac1b054 commit c5edd0f

File tree

7 files changed

+52
-26
lines changed

7 files changed

+52
-26
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_dict.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ def foo(**kwargs):
172172

173173
assert foo(a=5, b=6) == {'a': 1, 'b': 1}
174174

175+
d = dict.fromkeys({'a': 1, 'b': 2, 'c': 3})
176+
assert len(d) == 3
177+
assert set(d.keys()) == {'a', 'b', 'c'}
178+
assert set(d.values()) == {None}
179+
175180

176181
def test_init():
177182
d = dict(a=1, b=2, c=3)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/pickle/PUnpickler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ protected HashingStorage getClonedHashingStorage(VirtualFrame frame, Object obj)
769769
CompilerDirectives.transferToInterpreterAndInvalidate();
770770
getHashingStorageNode = insert(HashingCollectionNodes.GetClonedHashingStorageNode.create());
771771
}
772-
return getHashingStorageNode.doNoValueCached(frame, obj);
772+
return getHashingStorageNode.getForSetsCached(frame, obj);
773773
}
774774

775775
private void setItem(VirtualFrame frame, Object object, Object key, Object value) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingCollectionNodes.java

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import com.oracle.graal.python.builtins.objects.dict.DictNodes;
5656
import com.oracle.graal.python.builtins.objects.dict.PDict;
5757
import com.oracle.graal.python.builtins.objects.dict.PDictView;
58+
import com.oracle.graal.python.builtins.objects.set.PBaseSet;
5859
import com.oracle.graal.python.lib.IteratorExhausted;
5960
import com.oracle.graal.python.lib.PyIterNextNode;
6061
import com.oracle.graal.python.lib.PyObjectGetIter;
@@ -150,42 +151,62 @@ protected static boolean isEconomicMapStorage(Object o) {
150151
}
151152

152153
/**
153-
* Gets clone of the keys of the storage with all values either set to given value or with no
154-
* guarantees about the values if {@link PNone#NO_VALUE} is passed as {@code value}.
154+
* Gets clone of the keys of the storage with all values set to given value or (when used to
155+
* create a set or frozenset) to NO_VALUE.
155156
*/
156157
@GenerateInline(inlineByDefault = true)
157158
public abstract static class GetClonedHashingStorageNode extends PNodeWithContext {
158-
public abstract HashingStorage execute(VirtualFrame frame, Node inliningTarget, Object iterator, Object value);
159+
protected abstract HashingStorage execute(VirtualFrame frame, Node inliningTarget, Object iterator, Object value);
159160

160-
public final HashingStorage doNoValue(VirtualFrame frame, Node inliningTarget, Object iterator) {
161+
/**
162+
* Gets clone of the keys of the storage with all values either set to given value or, if
163+
* that is PNone.NO_VALUE, all values set to PNone.NONE. Use this method to clone into a
164+
* dict or other object where the values may be accessible from Python to avoid a)
165+
* PNone.NO_VALUE leaking to Python.
166+
*/
167+
public final HashingStorage getForDictionaries(VirtualFrame frame, Node inliningTarget, Object iterator, Object value) {
168+
return execute(frame, inliningTarget, iterator, value == PNone.NO_VALUE ? PNone.NONE : value);
169+
}
170+
171+
/**
172+
* Gets a clone of the keys of the storage with all values set to NO_VALUE. This must be
173+
* used *only* to create new storages for use in sets and frozensets where the values cannot
174+
* be accessed from user code.
175+
*/
176+
public final HashingStorage getForSets(VirtualFrame frame, Node inliningTarget, Object iterator) {
161177
return execute(frame, inliningTarget, iterator, PNone.NO_VALUE);
162178
}
163179

164-
public final HashingStorage doNoValueCached(VirtualFrame frame, Object iterator) {
180+
/**
181+
* IMPORTANT: Only for sets and frozensets.
182+
*
183+
* @see #getForSets(VirtualFrame, Node, Object)
184+
*/
185+
public final HashingStorage getForSetsCached(VirtualFrame frame, Object iterator) {
165186
return execute(frame, null, iterator, PNone.NO_VALUE);
166187
}
167188

168-
@Specialization(guards = "isNoValue(value)")
169-
static HashingStorage doHashingCollectionNoValue(Node inliningTarget, PHashingCollection other, @SuppressWarnings("unused") Object value,
170-
@Shared("copyNode") @Cached HashingStorageCopy copyNode) {
189+
// This for cloning sets (we come here from doNoValue or doNoValueCached). If we clone from
190+
// some other PHashingCollection, we would hold on to keys in the sets, and if we were to
191+
// clone for some other PHashingCollection (not PBaseSet), we might leak NO_VALUE into user
192+
// code.
193+
@Specialization(guards = "isNoValue(givenValue)")
194+
static HashingStorage doSet(Node inliningTarget, PBaseSet other, @SuppressWarnings("unused") Object givenValue,
195+
@Cached HashingStorageCopy copyNode) {
171196
return copyNode.execute(inliningTarget, other.getDictStorage());
172197
}
173198

174-
@Specialization(guards = "isNoValue(value)")
175-
static HashingStorage doPDictKeyViewNoValue(Node inliningTarget, PDictView.PDictKeysView other, Object value,
176-
@Shared("copyNode") @Cached HashingStorageCopy copyNode) {
177-
return copyNode.execute(inliningTarget, other.getWrappedStorage());
178-
}
179-
180-
@Specialization(guards = "!isNoValue(value)")
181-
static HashingStorage doHashingCollection(VirtualFrame frame, PHashingCollection other, Object value,
199+
@Specialization(replaces = "doSet")
200+
static HashingStorage doHashingCollection(VirtualFrame frame, PHashingCollection other, Object givenValue,
182201
@Shared @Cached(inline = false) GetClonedHashingCollectionNode hashingCollectionNode) {
202+
Object value = givenValue == PNone.NO_VALUE ? PNone.NONE : givenValue;
183203
return hashingCollectionNode.execute(frame, other.getDictStorage(), value);
184204
}
185205

186-
@Specialization(guards = "!isNoValue(value)")
187-
static HashingStorage doPDictView(VirtualFrame frame, PDictView.PDictKeysView other, Object value,
206+
@Specialization
207+
static HashingStorage doPDictView(VirtualFrame frame, PDictView.PDictKeysView other, Object givenValue,
188208
@Shared @Cached(inline = false) GetClonedHashingCollectionNode hashingCollectionNode) {
209+
Object value = givenValue == PNone.NO_VALUE ? PNone.NONE : givenValue;
189210
return hashingCollectionNode.execute(frame, other.getWrappedStorage(), value);
190211
}
191212

@@ -282,7 +303,7 @@ static HashingStorage doPDictView(PDictView.PDictKeysView other) {
282303
@InliningCutoff
283304
static HashingStorage doGeneric(VirtualFrame frame, Node inliningTarget, Object other,
284305
@Cached GetClonedHashingStorageNode getHashingStorageNode) {
285-
return getHashingStorageNode.doNoValue(frame, inliningTarget, other);
306+
return getHashingStorageNode.getForSets(frame, inliningTarget, other);
286307
}
287308
}
288309
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DictBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ static Object doKeys(VirtualFrame frame, @SuppressWarnings("unused") Object cls,
606606
@SuppressWarnings("unused") @Cached IsSameTypeNode isSameTypeNode,
607607
@Cached HashingCollectionNodes.GetClonedHashingStorageNode getHashingStorageNode,
608608
@Bind PythonLanguage language) {
609-
HashingStorage s = getHashingStorageNode.execute(frame, inliningTarget, iterable, value);
609+
HashingStorage s = getHashingStorageNode.getForDictionaries(frame, inliningTarget, iterable, value);
610610
return PFactory.createDict(language, s);
611611
}
612612

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ static PFrozenSet frozensetIterable(VirtualFrame frame, Object cls, Object itera
119119
@Cached HashingCollectionNodes.GetClonedHashingStorageNode getHashingStorageNode,
120120
@Bind PythonLanguage language,
121121
@Cached TypeNodes.GetInstanceShape getInstanceShape) {
122-
HashingStorage storage = getHashingStorageNode.doNoValue(frame, inliningTarget, iterable);
122+
HashingStorage storage = getHashingStorageNode.getForSets(frame, inliningTarget, iterable);
123123
return PFactory.createFrozenSet(language, cls, getInstanceShape.execute(cls), storage);
124124
}
125125
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ static PNone doNoValue(PSet self, @SuppressWarnings("unused") PNone iterable,
155155
static PNone doGeneric(VirtualFrame frame, PSet self, Object iterable,
156156
@Bind("this") Node inliningTarget,
157157
@Cached HashingCollectionNodes.GetClonedHashingStorageNode getHashingStorageNode) {
158-
HashingStorage storage = getHashingStorageNode.doNoValue(frame, inliningTarget, iterable);
158+
HashingStorage storage = getHashingStorageNode.getForSets(frame, inliningTarget, iterable);
159159
self.setDictStorage(storage);
160160
return PNone.NONE;
161161
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/types/UnionTypeBuiltins.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ static long hash(VirtualFrame frame, PUnionType self,
180180
@Cached PyObjectHashNode hashNode,
181181
@Cached HashingCollectionNodes.GetClonedHashingStorageNode getHashingStorageNode,
182182
@Bind PythonLanguage language) {
183-
PFrozenSet argSet = PFactory.createFrozenSet(language, getHashingStorageNode.doNoValue(frame, inliningTarget, self.getArgs()));
183+
PFrozenSet argSet = PFactory.createFrozenSet(language, getHashingStorageNode.getForSets(frame, inliningTarget, self.getArgs()));
184184
return hashNode.execute(frame, inliningTarget, argSet);
185185
}
186186
}
@@ -273,8 +273,8 @@ static boolean eq(VirtualFrame frame, PUnionType self, PUnionType other, RichCmp
273273
@Cached HashingCollectionNodes.GetClonedHashingStorageNode getHashingStorageNode,
274274
@Cached PyObjectRichCompareBool eqNode,
275275
@Bind PythonLanguage language) {
276-
PFrozenSet argSet1 = PFactory.createFrozenSet(language, getHashingStorageNode.doNoValue(frame, inliningTarget, self.getArgs()));
277-
PFrozenSet argSet2 = PFactory.createFrozenSet(language, getHashingStorageNode.doNoValue(frame, inliningTarget, other.getArgs()));
276+
PFrozenSet argSet1 = PFactory.createFrozenSet(language, getHashingStorageNode.getForSets(frame, inliningTarget, self.getArgs()));
277+
PFrozenSet argSet2 = PFactory.createFrozenSet(language, getHashingStorageNode.getForSets(frame, inliningTarget, other.getArgs()));
278278
return eqNode.execute(frame, inliningTarget, argSet1, argSet2, op);
279279
}
280280

0 commit comments

Comments
 (0)