diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index 52ef0c14ea3a64..164169b30fcd7f 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -11,6 +11,15 @@ from test.support import import_helper, get_c_recursion_limit +class CustomHash: + def __init__(self, hash): + self.hash = hash + def __hash__(self): + return self.hash + def __repr__(self): + return f'' + + class DictTest(unittest.TestCase): def test_invalid_keyword_arguments(self): @@ -1701,6 +1710,29 @@ class MyClass: pass d[MyStr("attr1")] = 2 self.assertIsInstance(list(d)[0], MyStr) + def test_hash_collision_remove_add(self): + self.maxDiff = None + # There should be enough space, so all elements with unique hash + # will be placed in corresponding cells without collision. + n = 64 + items = [(CustomHash(h), h) for h in range(n)] + # Keys with hash collision. + a = CustomHash(n) + b = CustomHash(n) + items += [(a, 'a'), (b, 'b')] + d = dict(items) + self.assertEqual(len(d), len(items), d) + del d[a] + # "a" has been replaced with a dummy. + del items[n] + self.assertEqual(len(d), len(items), d) + self.assertEqual(d, dict(items)) + d[b] = 'c' + # "b" should not replace the dummy. + items[n] = (b, 'c') + self.assertEqual(len(d), len(items), d) + self.assertEqual(d, dict(items)) + class CAPITest(unittest.TestCase): diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 2c2c8702b6c011..bc0fa558960317 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -19,6 +19,14 @@ def check_pass_thru(): raise PassThru yield 1 +class CustomHash: + def __init__(self, hash): + self.hash = hash + def __hash__(self): + return self.hash + def __repr__(self): + return f'' + class BadCmp: def __hash__(self): return 1 @@ -635,6 +643,38 @@ def __le__(self, some_set): myset >= myobj self.assertTrue(myobj.le_called) + def test_set_membership(self): + myfrozenset = frozenset(range(3)) + myset = {myfrozenset, "abc", 1} + self.assertIn(set(range(3)), myset) + self.assertNotIn(set(range(1)), myset) + myset.discard(set(range(3))) + self.assertEqual(myset, {"abc", 1}) + self.assertRaises(KeyError, myset.remove, set(range(1))) + self.assertRaises(KeyError, myset.remove, set(range(3))) + + def test_hash_collision_remove_add(self): + self.maxDiff = None + # There should be enough space, so all elements with unique hash + # will be placed in corresponding cells without collision. + n = 64 + elems = [CustomHash(h) for h in range(n)] + # Elements with hash collision. + a = CustomHash(n) + b = CustomHash(n) + elems += [a, b] + s = self.thetype(elems) + self.assertEqual(len(s), len(elems), s) + s.remove(a) + # "a" has been replaced with a dummy. + del elems[n] + self.assertEqual(len(s), len(elems), s) + self.assertEqual(s, set(elems)) + s.add(b) + # "b" should not replace the dummy. + self.assertEqual(len(s), len(elems), s) + self.assertEqual(s, set(elems)) + class SetSubclass(set): pass