diff --git a/Include/pyport.h b/Include/pyport.h
index 614a2789fb0781..086ed4204e4ff1 100644
--- a/Include/pyport.h
+++ b/Include/pyport.h
@@ -26,8 +26,32 @@
 // _Py_CAST(const PyObject*, expr) fails with a compiler error.
 #ifdef __cplusplus
 #  define _Py_STATIC_CAST(type, expr) static_cast<type>(expr)
-#  define _Py_CAST(type, expr) \
-       const_cast<type>(reinterpret_cast<const type>(expr))
+
+extern "C++" {
+namespace {
+template <typename type, typename expr_type>
+inline type _Py_reinterpret_cast_impl(expr_type *expr) {
+  return reinterpret_cast<type>(expr);
+}
+
+template <typename type, typename expr_type>
+inline type _Py_reinterpret_cast_impl(expr_type const *expr) {
+  return reinterpret_cast<type>(const_cast<expr_type *>(expr));
+}
+
+template <typename type, typename expr_type>
+inline type _Py_reinterpret_cast_impl(expr_type &expr) {
+  return static_cast<type>(expr);
+}
+
+template <typename type, typename expr_type>
+inline type _Py_reinterpret_cast_impl(expr_type const &expr) {
+  return static_cast<type>(const_cast<expr_type &>(expr));
+}
+} // namespace
+}
+#  define _Py_CAST(type, expr) _Py_reinterpret_cast_impl<type>(expr)
+
 #else
 #  define _Py_STATIC_CAST(type, expr) ((type)(expr))
 #  define _Py_CAST(type, expr) ((type)(expr))
diff --git a/Lib/test/_testcppext.cpp b/Lib/test/_testcppext.cpp
index f38b4870e0edbc..f6049eedd00048 100644
--- a/Lib/test/_testcppext.cpp
+++ b/Lib/test/_testcppext.cpp
@@ -40,6 +40,15 @@ test_api_casts(PyObject *Py_UNUSED(module), PyObject *Py_UNUSED(args))
     PyTypeObject *type = Py_TYPE(const_obj);
     assert(Py_REFCNT(const_obj) >= 1);
 
+    struct PyObjectProxy {
+      PyObject* obj;
+      operator PyObject *() { return obj; }
+    } proxy_obj = { obj };
+    Py_INCREF(proxy_obj);
+    Py_DECREF(proxy_obj);
+    assert(Py_REFCNT(proxy_obj) >= 1);
+
+
     assert(type == &PyTuple_Type);
     assert(PyTuple_GET_SIZE(const_obj) == 2);
     PyObject *one = PyTuple_GET_ITEM(const_obj, 0);