-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathpy_message_factory.cc
More file actions
119 lines (105 loc) · 3.79 KB
/
Copy pathpy_message_factory.cc
File metadata and controls
119 lines (105 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "py_message_factory.h"
#include <Python.h> // IWYU pragma: keep - Needed for PyObject
#include <string>
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
namespace cel_python {
// Creates a factory that resolves Python message classes through
// `descriptor_pool`. Only its FindMessageTypeByName method is used (presumably
// a google.protobuf DescriptorPool — not checked here).
//
// Takes a new reference to `descriptor_pool`, which the destructor releases.
// Aborts via ABSL_LOG(FATAL) if google.protobuf.message_factory or its
// GetMessageClass function cannot be loaded. Before aborting, it prints the
// pending Python error so the cause is visible. The caller must hold the GIL.
PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) {
  py_descriptor_pool_ = descriptor_pool;
  Py_INCREF(py_descriptor_pool_);

  PyObject* module = PyImport_ImportModule("google.protobuf.message_factory");
  if (!module) {
    PyErr_Print();  // Surface the underlying ImportError before aborting.
    ABSL_LOG(FATAL) << "Cannot load module 'google.protobuf.message_factory'";
  }
  // PyObject_GetAttrString returns a new (owned) reference. No extra
  // Py_INCREF is taken, because the destructor releases exactly one reference.
  py_func_GetMessageClass_ = PyObject_GetAttrString(module, "GetMessageClass");
  // The module is only needed for the attribute lookup; drop our reference
  // (sys.modules keeps the module itself alive).
  Py_DECREF(module);
  if (!py_func_GetMessageClass_) {
    PyErr_Print();
    ABSL_LOG(FATAL) << "Cannot find function "
                       "'google.protobuf.message_factory.GetMessageClass'";
  }
  // Used as the method name in FromString(); interning makes the repeated
  // attribute lookups cheaper.
  py_func_MergeFromString_ = PyUnicode_InternFromString("MergeFromString");
  if (!py_func_MergeFromString_) {
    PyErr_Print();
    ABSL_LOG(FATAL) << "Cannot create Python string 'MergeFromString'";
  }
}
// Drops every Python reference held by the factory:
// - the descriptor pool,
// - the GetMessageClass callable,
// - the "MergeFromString" method-name string,
// - each cached message class.
// PyGILState_Ensure/Release bracket the work, so the destructor is safe
// whether or not the calling thread already holds the GIL.
PyMessageFactory::~PyMessageFactory() {
  const PyGILState_STATE held_state = PyGILState_Ensure();
  Py_DECREF(py_descriptor_pool_);
  Py_XDECREF(py_func_GetMessageClass_);
  Py_XDECREF(py_func_MergeFromString_);
  // Each cached class is an owned reference stored by GetMessageClass().
  for (const auto& cache_entry : message_classes_) {
    Py_XDECREF(cache_entry.second);
  }
  PyGILState_Release(held_state);
}
// Returns the Python message class for the fully qualified `message_type`.
// On first use, the class is resolved via the descriptor pool's
// FindMessageTypeByName and google.protobuf.message_factory.GetMessageClass.
// Later calls serve it from `message_classes_`.
//
// The result is a *borrowed* reference: the cache owns it, and the caller
// must not Py_DECREF it. On failure, returns nullptr with a Python TypeError
// set. The caller must hold the GIL.
PyObject* PyMessageFactory::GetMessageClass(const std::string& message_type) {
  if (auto cached = message_classes_.find(message_type);
      cached != message_classes_.end()) {
    return cached->second;
  }

  PyObject* descriptor =
      PyObject_CallMethod(py_descriptor_pool_, "FindMessageTypeByName", "s",
                          message_type.c_str());
  if (!descriptor) {
    // Replace the pool's exception with a TypeError naming the requested type.
    PyErr_Format(PyExc_TypeError, "Message type not found: %s",
                 message_type.c_str());
    return nullptr;
  }
  PyObject* message_class =
      PyObject_CallFunction(py_func_GetMessageClass_, "O", descriptor);
  Py_DECREF(descriptor);
  if (!message_class) {
    PyErr_Format(PyExc_TypeError, "Couldn't find message class for type: %s",
                 message_type.c_str());
    return nullptr;
  }

  // The Python calls above run Python code. While they run, the interpreter
  // may switch threads, so another caller can cache this type in the meantime.
  // emplace() never overwrites an existing entry. If one is already present,
  // keep it and drop our duplicate reference. Overwriting would silently leak
  // the reference the cache already owned.
  auto [slot, inserted] = message_classes_.emplace(message_type, message_class);
  if (!inserted) {
    Py_DECREF(message_class);
  }
  return slot->second;
}
// Creates a new Python message of type `message_type` and merges the
// serialized bytes `serialized_proto` into it.
//
// On success, returns a new (owned) reference to the message. On failure,
// returns nullptr with a Python exception set:
// - the TypeError from GetMessageClass(),
// - a RuntimeError if the class cannot be instantiated,
// - otherwise, the exception raised by the failing Python call.
// The caller must hold the GIL; this is checked.
PyObject* PyMessageFactory::FromString(const std::string& message_type,
                                       const std::string& serialized_proto) {
  ABSL_CHECK(PyGILState_Check());
  PyObject* message_class = GetMessageClass(message_type);  // Borrowed.
  if (!message_class) {
    return nullptr;
  }
  PyObject* message = PyObject_CallObject(message_class, nullptr);
  if (!message) {
    PyErr_Format(PyExc_RuntimeError, "Cannot create message of type: %s",
                 message_type.c_str());
    return nullptr;
  }
  // Read-only, zero-copy view over the caller's bytes. Our reference is
  // released before returning, while `serialized_proto` is still alive.
  // NOTE(review): assumes MergeFromString does not retain the view — confirm.
  PyObject* serialized_proto_py = PyMemoryView_FromMemory(
      const_cast<char*>(serialized_proto.data()),
      static_cast<Py_ssize_t>(serialized_proto.size()), PyBUF_READ);
  if (!serialized_proto_py) {
    Py_DECREF(message);
    return nullptr;
  }
  PyObject* merge_result = PyObject_CallMethodObjArgs(
      message, py_func_MergeFromString_, serialized_proto_py, nullptr);
  Py_DECREF(serialized_proto_py);
  if (!merge_result) {
    Py_DECREF(message);
    return nullptr;
  }
  // MergeFromString's return value is a new reference that is not needed
  // here. Release it, otherwise it leaks on every successful call.
  Py_DECREF(merge_result);
  return message;
}
} // namespace cel_python