diff --git a/Lib/test/test_capi/test_dict.py b/Lib/test/test_capi/test_dict.py index d3cc279cd3f955..f69ccbdbd1117d 100644 --- a/Lib/test/test_capi/test_dict.py +++ b/Lib/test/test_capi/test_dict.py @@ -223,6 +223,7 @@ def test_dict_getitemwitherror(self): # CRASHES getitem(NULL, 'a') def test_dict_contains(self): + # Test PyDict_Contains() contains = _testlimitedcapi.dict_contains dct = {'a': 1, '\U0001f40d': 2} self.assertTrue(contains(dct, 'a')) @@ -235,11 +236,12 @@ def test_dict_contains(self): self.assertRaises(TypeError, contains, {}, []) # unhashable # CRASHES contains({}, NULL) - # CRASHES contains(UserDict(), 'a') - # CRASHES contains(42, 'a') + self.assertRaises(SystemError, contains, UserDict(), 'a') + self.assertRaises(SystemError, contains, 42, 'a') # CRASHES contains(NULL, 'a') def test_dict_contains_string(self): + # Test PyDict_ContainsString() contains_string = _testcapi.dict_containsstring dct = {'a': 1, '\U0001f40d': 2} self.assertTrue(contains_string(dct, b'a')) @@ -251,6 +253,8 @@ def test_dict_contains_string(self): self.assertTrue(contains_string(dct2, b'a')) self.assertFalse(contains_string(dct2, b'b')) + self.assertRaises(SystemError, contains_string, UserDict(), 'a') + self.assertRaises(SystemError, contains_string, 42, 'a') # CRASHES contains({}, NULL) # CRASHES contains(NULL, b'a') diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 276e1df21a80d8..0a8ba74c2287c1 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -140,6 +140,7 @@ static PyObject* frozendict_new(PyTypeObject *type, PyObject *args, PyObject *kwds); static PyObject* dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds); static int dict_merge(PyObject *a, PyObject *b, int override); +static int dict_contains(PyObject *op, PyObject *key); static int dict_merge_from_seq2(PyObject *d, PyObject *seq2, int override); @@ -4126,7 +4127,7 @@ dict_merge(PyObject *a, PyObject *b, int override) for (key = PyIter_Next(iter); key; key = PyIter_Next(iter)) { if (override != 1) { - status = PyDict_Contains(a, key); + status = dict_contains(a, key); if (status != 0) { if (status > 0) { if (override == 0) { @@ -4484,7 +4485,7 @@ static PyObject * dict___contains___impl(PyDictObject *self, PyObject *key) /*[clinic end generated code: output=1b314e6da7687dae input=fe1cb42ad831e820]*/ { - int contains = PyDict_Contains((PyObject *)self, key); + int contains = dict_contains((PyObject *)self, key); if (contains < 0) { return NULL; } @@ -4984,9 +4985,8 @@ static PyMethodDef mapp_methods[] = { {NULL, NULL} /* sentinel */ }; -/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */ -int -PyDict_Contains(PyObject *op, PyObject *key) +static int +dict_contains(PyObject *op, PyObject *key) { Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { @@ -4997,6 +4997,18 @@ PyDict_Contains(PyObject *op, PyObject *key) return _PyDict_Contains_KnownHash(op, key, hash); } +/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */ +int +PyDict_Contains(PyObject *op, PyObject *key) +{ + if (!PyAnyDict_Check(op)) { + PyErr_BadInternalCall(); + return -1; + } + + return dict_contains(op, key); +} + int PyDict_ContainsString(PyObject *op, const char *key) { @@ -5013,7 +5025,7 @@ PyDict_ContainsString(PyObject *op, const char *key) int _PyDict_Contains_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) { - PyDictObject *mp = (PyDictObject *)op; + PyDictObject *mp = _PyAnyDict_CAST(op); PyObject *value; Py_ssize_t ix; @@ -5042,7 +5054,7 @@ static PySequenceMethods dict_as_sequence = { 0, /* sq_slice */ 0, /* sq_ass_item */ 0, /* sq_ass_slice */ - PyDict_Contains, /* sq_contains */ + dict_contains, /* sq_contains */ 0, /* sq_inplace_concat */ 0, /* sq_inplace_repeat */ }; @@ -6292,7 +6304,7 @@ dictkeys_contains(PyObject *self, PyObject *obj) _PyDictViewObject *dv = (_PyDictViewObject *)self; if (dv->dv_dict == NULL) return 0; - return PyDict_Contains((PyObject *)dv->dv_dict, obj); + return dict_contains((PyObject *)dv->dv_dict, obj); } static PySequenceMethods dictkeys_as_sequence = {