Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Lib/test/test_capi/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def test_dict_getitemwitherror(self):
# CRASHES getitem(NULL, 'a')

def test_dict_contains(self):
# Test PyDict_Contains()
contains = _testlimitedcapi.dict_contains
dct = {'a': 1, '\U0001f40d': 2}
self.assertTrue(contains(dct, 'a'))
Expand All @@ -235,11 +236,12 @@ def test_dict_contains(self):

self.assertRaises(TypeError, contains, {}, []) # unhashable
# CRASHES contains({}, NULL)
# CRASHES contains(UserDict(), 'a')
# CRASHES contains(42, 'a')
self.assertRaises(SystemError, contains, UserDict(), 'a')
self.assertRaises(SystemError, contains, 42, 'a')
# CRASHES contains(NULL, 'a')

def test_dict_contains_string(self):
# Test PyDict_ContainsString()
contains_string = _testcapi.dict_containsstring
dct = {'a': 1, '\U0001f40d': 2}
self.assertTrue(contains_string(dct, b'a'))
Expand All @@ -251,6 +253,8 @@ def test_dict_contains_string(self):
self.assertTrue(contains_string(dct2, b'a'))
self.assertFalse(contains_string(dct2, b'b'))

self.assertRaises(SystemError, contains_string, UserDict(), 'a')
self.assertRaises(SystemError, contains_string, 42, 'a')
# CRASHES contains({}, NULL)
# CRASHES contains(NULL, b'a')

Expand Down
28 changes: 20 additions & 8 deletions Objects/dictobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
PyObject *kwds);
static PyObject* dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
static int dict_merge(PyObject *a, PyObject *b, int override);
static int dict_contains(PyObject *op, PyObject *key);
static int dict_merge_from_seq2(PyObject *d, PyObject *seq2, int override);


Expand Down Expand Up @@ -4126,7 +4127,7 @@ dict_merge(PyObject *a, PyObject *b, int override)

for (key = PyIter_Next(iter); key; key = PyIter_Next(iter)) {
if (override != 1) {
status = PyDict_Contains(a, key);
status = dict_contains(a, key);
if (status != 0) {
if (status > 0) {
if (override == 0) {
Expand Down Expand Up @@ -4484,7 +4485,7 @@ static PyObject *
dict___contains___impl(PyDictObject *self, PyObject *key)
/*[clinic end generated code: output=1b314e6da7687dae input=fe1cb42ad831e820]*/
{
int contains = PyDict_Contains((PyObject *)self, key);
int contains = dict_contains((PyObject *)self, key);
if (contains < 0) {
return NULL;
}
Expand Down Expand Up @@ -4984,9 +4985,8 @@ static PyMethodDef mapp_methods[] = {
{NULL, NULL} /* sentinel */
};

/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
int
PyDict_Contains(PyObject *op, PyObject *key)
static int
dict_contains(PyObject *op, PyObject *key)
{
Py_hash_t hash = _PyObject_HashFast(key);
if (hash == -1) {
Expand All @@ -4997,6 +4997,18 @@ PyDict_Contains(PyObject *op, PyObject *key)
return _PyDict_Contains_KnownHash(op, key, hash);
}

/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */
int
PyDict_Contains(PyObject *op, PyObject *key)
{
if (!PyAnyDict_Check(op)) {
PyErr_BadInternalCall();
return -1;
}
Comment on lines +5004 to +5007
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding a versionchanged for this in the docs? This looks good otherwise.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not document SystemError errors in the documentation. It's not really part of the API, but provided to be kind with developers.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that's necessarily a good thing. It's helpful for extension authors to know whether they need to validate their input.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They always need to validate their input. You're not supposed to hit SystemError in regular code.


return dict_contains(op, key);
}

int
PyDict_ContainsString(PyObject *op, const char *key)
{
Expand All @@ -5013,7 +5025,7 @@ PyDict_ContainsString(PyObject *op, const char *key)
int
_PyDict_Contains_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash)
{
PyDictObject *mp = (PyDictObject *)op;
PyDictObject *mp = _PyAnyDict_CAST(op);
PyObject *value;
Py_ssize_t ix;

Expand Down Expand Up @@ -5042,7 +5054,7 @@ static PySequenceMethods dict_as_sequence = {
0, /* sq_slice */
0, /* sq_ass_item */
0, /* sq_ass_slice */
PyDict_Contains, /* sq_contains */
dict_contains, /* sq_contains */
0, /* sq_inplace_concat */
0, /* sq_inplace_repeat */
};
Expand Down Expand Up @@ -6292,7 +6304,7 @@ dictkeys_contains(PyObject *self, PyObject *obj)
_PyDictViewObject *dv = (_PyDictViewObject *)self;
if (dv->dv_dict == NULL)
return 0;
return PyDict_Contains((PyObject *)dv->dv_dict, obj);
return dict_contains((PyObject *)dv->dv_dict, obj);
}

static PySequenceMethods dictkeys_as_sequence = {
Expand Down
Loading