From 01cf35785d462aa94f35fe6458b6b8285a31fc38 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 21 Feb 2026 16:50:56 +0100 Subject: [PATCH] gh-141510: Check argument in PyDict_Contains() PyDict_Contains() and PyDict_ContainsString() now fail with SystemError if the first argument is not a dict, frozendict, dict subclass or frozendict subclass. --- Lib/test/test_capi/test_dict.py | 8 ++++++-- Objects/dictobject.c | 28 ++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/Lib/test/test_capi/test_dict.py b/Lib/test/test_capi/test_dict.py index e726e3d813d888..494981cae88445 100644 --- a/Lib/test/test_capi/test_dict.py +++ b/Lib/test/test_capi/test_dict.py @@ -210,6 +210,7 @@ def test_dict_getitemwitherror(self): # CRASHES getitem(NULL, 'a') def test_dict_contains(self): + # Test PyDict_Contains() contains = _testlimitedcapi.dict_contains dct = {'a': 1, '\U0001f40d': 2} self.assertTrue(contains(dct, 'a')) @@ -222,11 +223,12 @@ def test_dict_contains(self): self.assertRaises(TypeError, contains, {}, []) # unhashable # CRASHES contains({}, NULL) - # CRASHES contains(UserDict(), 'a') - # CRASHES contains(42, 'a') + self.assertRaises(SystemError, contains, UserDict(), 'a') + self.assertRaises(SystemError, contains, 42, 'a') # CRASHES contains(NULL, 'a') def test_dict_contains_string(self): + # Test PyDict_ContainsString() contains_string = _testcapi.dict_containsstring dct = {'a': 1, '\U0001f40d': 2} self.assertTrue(contains_string(dct, b'a')) @@ -238,6 +240,8 @@ def test_dict_contains_string(self): self.assertTrue(contains_string(dct2, b'a')) self.assertFalse(contains_string(dct2, b'b')) + self.assertRaises(SystemError, contains_string, UserDict(), 'a') + self.assertRaises(SystemError, contains_string, 42, 'a') # CRASHES contains({}, NULL) # CRASHES contains(NULL, b'a') diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 8f960352fa4824..10f0d70ebc81f5 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -140,6 +140,7 @@ static PyObject* frozendict_new(PyTypeObject *type, PyObject *args, PyObject *kwds); static PyObject* dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds); static int dict_merge(PyObject *a, PyObject *b, int override); +static int dict_contains(PyObject *op, PyObject *key); /*[clinic input] @@ -4121,7 +4122,7 @@ dict_merge(PyObject *a, PyObject *b, int override) for (key = PyIter_Next(iter); key; key = PyIter_Next(iter)) { if (override != 1) { - status = PyDict_Contains(a, key); + status = dict_contains(a, key); if (status != 0) { if (status > 0) { if (override == 0) { @@ -4464,7 +4465,7 @@ static PyObject * dict___contains___impl(PyDictObject *self, PyObject *key) /*[clinic end generated code: output=1b314e6da7687dae input=fe1cb42ad831e820]*/ { - int contains = PyDict_Contains((PyObject *)self, key); + int contains = dict_contains((PyObject *)self, key); if (contains < 0) { return NULL; } @@ -4964,9 +4965,8 @@ static PyMethodDef mapp_methods[] = { {NULL, NULL} /* sentinel */ }; -/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */ -int -PyDict_Contains(PyObject *op, PyObject *key) +static int +dict_contains(PyObject *op, PyObject *key) { Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { @@ -4977,6 +4977,18 @@ PyDict_Contains(PyObject *op, PyObject *key) return _PyDict_Contains_KnownHash(op, key, hash); } +/* Return 1 if `key` is in dict `op`, 0 if not, and -1 on error. */ +int +PyDict_Contains(PyObject *op, PyObject *key) +{ + if (!PyAnyDict_Check(op)) { + PyErr_BadInternalCall(); + return -1; + } + + return dict_contains(op, key); +} + int PyDict_ContainsString(PyObject *op, const char *key) { @@ -4993,7 +5005,7 @@ PyDict_ContainsString(PyObject *op, const char *key) int _PyDict_Contains_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) { - PyDictObject *mp = (PyDictObject *)op; + PyDictObject *mp = _PyAnyDict_CAST(op); PyObject *value; Py_ssize_t ix; @@ -5022,7 +5034,7 @@ static PySequenceMethods dict_as_sequence = { 0, /* sq_slice */ 0, /* sq_ass_item */ 0, /* sq_ass_slice */ - PyDict_Contains, /* sq_contains */ + dict_contains, /* sq_contains */ 0, /* sq_inplace_concat */ 0, /* sq_inplace_repeat */ }; @@ -6272,7 +6284,7 @@ dictkeys_contains(PyObject *self, PyObject *obj) _PyDictViewObject *dv = (_PyDictViewObject *)self; if (dv->dv_dict == NULL) return 0; - return PyDict_Contains((PyObject *)dv->dv_dict, obj); + return dict_contains((PyObject *)dv->dv_dict, obj); } static PySequenceMethods dictkeys_as_sequence = {