Skip to content

Commit 50bd022

Browse files
Move descriptor pool containers to flat_hash_map to speed it up.
PiperOrigin-RevId: 826110301
1 parent ec86098 commit 50bd022

File tree

2 files changed

+46
-6
lines changed

2 files changed

+46
-6
lines changed

python/google/protobuf/internal/descriptor_pool_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
__author__ = '[email protected] (Matt Toia)'
1111

1212
import copy
13+
import timeit
1314
import unittest
1415
import warnings
1516

@@ -43,9 +44,51 @@
4344

4445
warnings.simplefilter('error', DeprecationWarning)
4546

47+
# Enable this to run the benchmarks.
48+
ALSO_RUN_BENCHMARKS = False
49+
4650

4751
class DescriptorPoolTestBase(object):
4852

53+
@unittest.skipIf(not ALSO_RUN_BENCHMARKS, 'Benchmarks are disabled.')
54+
def testDescriptorPoolBenchmark(self):
55+
if ALSO_RUN_BENCHMARKS:
56+
n_trials = 100
57+
58+
# FindFileByName
59+
name = 'google/protobuf/internal/factory_test1.proto'
60+
duration = timeit.timeit(
61+
lambda: self.pool.FindFileByName(name),
62+
number=n_trials,
63+
)
64+
print(f'FindFileByName: {duration / n_trials * 1000}ms')
65+
66+
# FindEnumTypeByName
67+
name = 'google.protobuf.python.internal.Factory1Enum'
68+
duration = timeit.timeit(
69+
lambda: self.pool.FindEnumTypeByName(name),
70+
number=n_trials,
71+
)
72+
print(f'FindEnumTypeByName: {duration / n_trials * 1000}ms')
73+
74+
# FindOneofByName
75+
name = 'google.protobuf.python.internal.Factory2Message.oneof_field'
76+
duration = timeit.timeit(
77+
lambda: self.pool.FindOneofByName(name),
78+
number=n_trials,
79+
)
80+
print(f'FindOneofByName: {duration / n_trials * 1000}ms')
81+
82+
# FindExtensionByName
83+
name = 'google.protobuf.python.internal.another_field'
84+
duration = timeit.timeit(
85+
lambda: self.pool.FindExtensionByName(name),
86+
number=n_trials,
87+
)
88+
print(f'FindExtensionByName: {duration / n_trials * 1000}ms')
89+
else:
90+
print('Skipping benchmark in non-benchmark mode.')
91+
4992
def testFindFileByName(self):
5093
name1 = 'google/protobuf/internal/factory_test1.proto'
5194
file_desc1 = self.pool.FindFileByName(name1)

python/google/protobuf/pyext/descriptor_pool.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
// Implements the DescriptorPool, which collects all descriptors.
99

1010
#include <string>
11-
#include <unordered_map>
1211
#include <utility>
1312
#include <vector>
1413

@@ -19,7 +18,6 @@
1918
#include "absl/container/flat_hash_map.h"
2019
#include "absl/status/status.h"
2120
#include "absl/strings/str_cat.h"
22-
#include "absl/strings/str_replace.h"
2321
#include "absl/strings/string_view.h"
2422
#include "google/protobuf/pyext/descriptor.h"
2523
#include "google/protobuf/pyext/descriptor_database.h"
@@ -45,7 +43,7 @@ namespace python {
4543

4644
// A map to cache Python Pools per C++ pointer.
4745
// Pointers are not owned here, and belong to the PyDescriptorPool.
48-
static std::unordered_map<const DescriptorPool*, PyDescriptorPool*>*
46+
static absl::flat_hash_map<const DescriptorPool*, PyDescriptorPool*>*
4947
descriptor_pool_map;
5048

5149
namespace cdescriptor_pool {
@@ -684,7 +682,7 @@ bool InitDescriptorPool() {
684682
// generated_pool() contains all messages already linked in C++ libraries, and
685683
// is used as underlay.
686684
descriptor_pool_map =
687-
new std::unordered_map<const DescriptorPool*, PyDescriptorPool*>;
685+
new absl::flat_hash_map<const DescriptorPool*, PyDescriptorPool*>;
688686
python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay(
689687
DescriptorPool::generated_pool());
690688
if (python_generated_pool == nullptr) {
@@ -714,8 +712,7 @@ PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) {
714712
pool == DescriptorPool::generated_pool()) {
715713
return python_generated_pool;
716714
}
717-
std::unordered_map<const DescriptorPool*, PyDescriptorPool*>::iterator it =
718-
descriptor_pool_map->find(pool);
715+
auto it = descriptor_pool_map->find(pool);
719716
if (it == descriptor_pool_map->end()) {
720717
PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool");
721718
return nullptr;

0 commit comments

Comments
 (0)