Skip to content

Commit a835245

Browse files
authored
Add support for Eigen's Tensor module (#1320)
This PR adds support for the types in Eigen's "unsupported" Tensor module, from the header `<unsupported/Eigen/CXX11/Tensor>` (we keep `CXX11/` in the include to support Eigen3).
1 parent cfbb02f commit a835245

File tree

8 files changed

+597
-7
lines changed

8 files changed

+597
-7
lines changed

cmake/nanobind-config.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ function (nanobind_build_library TARGET_NAME)
192192
${NB_DIR}/include/nanobind/stl/vector.h
193193
${NB_DIR}/include/nanobind/eigen/dense.h
194194
${NB_DIR}/include/nanobind/eigen/sparse.h
195+
${NB_DIR}/include/nanobind/eigen/tensor.h
195196

196197
${NB_DIR}/src/buffer.h
197198
${NB_DIR}/src/hash.h

docs/eigen.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,18 @@ case.
178178

179179
There is no support for Eigen sparse vectors because an equivalent type does
180180
not exist as part of ``scipy.sparse``.
181+
182+
Tensors
183+
-------
184+
185+
Add the following include directive to your binding code to exchange Eigen Tensor
186+
types:
187+
188+
.. code-block:: cpp
189+
190+
#include <nanobind/eigen/tensor.h>
191+
192+
The ``Eigen::Tensor<..>``, ``Eigen::TensorMap<..>`` and ``Eigen::TensorRef<..>``
193+
types are all supported, and map to `numpy.ndarray` with the appropriate sizes.
194+
Both column-major and row-major tensors are supported. Note that taking
195+
non-contiguous NumPy arrays as arguments is not supported for the Map and Ref types.

docs/exchanging.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ to external projects that provide further casters:
143143
- ``#include <nanobind/eigen/dense.h>``
144144
* - ``Eigen::SparseMatrix<..>``
145145
- ``#include <nanobind/eigen/sparse.h>``
146+
* - ``Eigen::Tensor<..>``, ``Eigen::TensorMap<..>``, ``Eigen::TensorRef<..>``
147+
- ``#include <nanobind/eigen/tensor.h>``
146148
* - Apache Arrow types
147149
- `https://github.com/maximiliank/nanobind_pyarrow <https://github.com/maximiliank/nanobind_pyarrow>`__
148150
* - ...

include/nanobind/eigen/tensor.h

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
/*
2+
nanobind/eigen/tensor.h: type casters for Eigen tensors
3+
4+
Copyright (c) 2026 INRIA
5+
6+
Author(s): Wilson Jallet
7+
8+
All rights reserved. Use of this source code is governed by a
9+
BSD-style license that can be found in the LICENSE file.
10+
*/
11+
12+
#pragma once
13+
14+
#include <nanobind/ndarray.h>
15+
#include <unsupported/Eigen/CXX11/Tensor>
16+
17+
NAMESPACE_BEGIN(NB_NAMESPACE)
18+
NAMESPACE_BEGIN(detail)
19+
20+
/// As of April 2026, Eigen::Tensor types support 16-byte alignment or no alignment.
21+
inline bool is_tensor_aligned(const void *data, std::size_t align = Eigen::Aligned) {
22+
return (reinterpret_cast<std::size_t>(data) % align) == 0;
23+
}
24+
25+
/// Type trait for inheriting from Eigen::TensorBase.
26+
/// All TensorBase specializations inherit from TensorBase<T, ReadOnlyAccessors>.
27+
template<typename T> constexpr bool is_eigen_tensor_v =
28+
std::is_base_of_v<Eigen::TensorBase<T, Eigen::ReadOnlyAccessors>, T>;
29+
30+
template<typename T> constexpr bool eigen_tensor_is_row_major_v = T::Layout == Eigen::RowMajor;
31+
template<typename T> constexpr bool eigen_tensor_is_col_major_v = T::Layout == Eigen::ColMajor;
32+
33+
template<typename T>
34+
constexpr bool is_eigen_tensor_map_v = false;
35+
36+
// Covers const case
37+
template<typename T, int Options, template<class> class MakePointer>
38+
constexpr bool is_eigen_tensor_map_v<Eigen::TensorMap<T, Options, MakePointer>> = true;
39+
40+
template<typename T>
41+
constexpr bool is_eigen_tensor_ref_v = false;
42+
43+
// Covers const case
44+
template<typename T>
45+
constexpr bool is_eigen_tensor_ref_v<Eigen::TensorRef<T>> = true;
46+
47+
template<typename T>
48+
constexpr bool is_eigen_tensor_plain_v = false;
49+
50+
template<typename Scalar, int NumIndices, int Options, typename IndexType>
51+
constexpr bool is_eigen_tensor_plain_v<Eigen::Tensor<Scalar, NumIndices, Options, IndexType>> = true;
52+
53+
template<typename Scalar, std::ptrdiff_t... Indices, int Options, typename IndexType>
54+
constexpr bool is_eigen_tensor_plain_v<Eigen::TensorFixedSize<Scalar, Eigen::Sizes<Indices...>, Options, IndexType>> = true;
55+
56+
template<typename T>
57+
constexpr bool is_eigen_tensor_xpr_v =
58+
is_eigen_tensor_v<T> &&
59+
!is_eigen_tensor_plain_v<T> &&
60+
!is_eigen_tensor_map_v<T> &&
61+
!is_eigen_tensor_ref_v<T>;
62+
63+
template<typename T, typename Scalar = typename T::Scalar>
64+
using ndarray_for_eigen_tensor_t = ndarray<
65+
Scalar,
66+
numpy,
67+
ndim<T::NumDimensions>,
68+
std::conditional_t<
69+
eigen_tensor_is_row_major_v<T>,
70+
c_contig,
71+
f_contig>>;
72+
73+
/** \brief Type caster for ``Eigen::TensorMap<T>``
74+
*/
75+
template<typename T, int MapOptions, template<class> class MakePointer>
76+
struct type_caster<
77+
Eigen::TensorMap<T, MapOptions, MakePointer>,
78+
enable_if_t<is_ndarray_scalar_v<typename T::Scalar>>> {
79+
80+
using Scalar = typename T::Scalar;
81+
using IndexType = typename T::Index;
82+
static constexpr int NumIndices = T::NumIndices;
83+
static constexpr int Options = T::Options;
84+
using PlainTensor = Eigen::Tensor<Scalar, NumIndices, Options, IndexType>;
85+
using Dimensions = typename T::Dimensions;
86+
using MapType = Eigen::TensorMap<T, MapOptions, MakePointer>;
87+
static constexpr bool IsAligned = MapType::IsAligned;
88+
89+
// Only partial specification. Dimensions not known at compile time...
90+
using NDArray =
91+
ndarray_for_eigen_tensor_t<T, std::conditional_t<std::is_const_v<T>,
92+
const Scalar,
93+
Scalar>>;
94+
using NDArrayCaster = make_caster<NDArray>;
95+
static constexpr auto Name = NDArrayCaster::Name;
96+
template<typename T_> using Cast = MapType;
97+
template<typename T_> static constexpr bool can_cast() { return true; };
98+
99+
NDArrayCaster caster;
100+
101+
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
102+
// Disable implicit conversions
103+
flags &= ~(uint8_t)cast_flags::convert;
104+
// Do not accept None
105+
flags &= ~(uint8_t)cast_flags::accepts_none;
106+
107+
if (!caster.from_python(src, flags, cleanup))
108+
return false;
109+
if(IsAligned && !is_tensor_aligned(caster.value.data()))
110+
return false;
111+
112+
return true;
113+
}
114+
115+
static handle from_cpp(const MapType &v, rv_policy policy, cleanup_list *cleanup) noexcept {
116+
size_t shape[NumIndices];
117+
for (size_t i = 0 ; i < NumIndices; i++) {
118+
shape[i] = (size_t) v.dimension(i);
119+
}
120+
121+
void* ptr = (void *)v.data();
122+
if (policy == rv_policy::automatic || policy == rv_policy::automatic_reference)
123+
policy = rv_policy::reference;
124+
return NDArrayCaster::from_cpp(
125+
NDArray {ptr, NumIndices, shape, handle()},
126+
policy,
127+
cleanup);
128+
}
129+
130+
operator MapType() {
131+
NDArray &t = caster.value;
132+
std::array<long, NumIndices> shape;
133+
for (size_t i = 0 ; i < NumIndices; i++) {
134+
shape[i] = t.shape(i);
135+
}
136+
return MapType(t.data(), shape);
137+
}
138+
};
139+
140+
141+
/** \brief Type caster for plain ``Eigen::Tensor<T>`` types.
142+
*/
143+
template<typename Scalar, int NumIndices, int Options, typename IndexType>
144+
struct type_caster<
145+
Eigen::Tensor<Scalar, NumIndices, Options, IndexType>,
146+
enable_if_t<is_ndarray_scalar_v<Scalar>>> {
147+
148+
using PlainTensor = Eigen::Tensor<Scalar, NumIndices, Options, IndexType>;
149+
using Dimensions = typename PlainTensor::Dimensions;
150+
using Coeffs = typename PlainTensor::CoeffReturnType;
151+
static constexpr bool IsRowMajor = bool(Options & Eigen::RowMajorBit);
152+
using NDArray = ndarray_for_eigen_tensor_t<PlainTensor>;
153+
using NDArrayCaster = make_caster<NDArray>;
154+
155+
// PlainTensor value;
156+
NB_TYPE_CASTER(PlainTensor, NDArrayCaster::Name);
157+
158+
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
159+
using NDArrayConst = ndarray_for_eigen_tensor_t<PlainTensor, const Scalar>;
160+
make_caster<NDArrayConst> caster;
161+
// Do not accept None
162+
if (!caster.from_python(src, flags & ~(uint8_t)cast_flags::accepts_none, cleanup))
163+
return false;
164+
165+
const NDArrayConst &array = caster.value;
166+
// copy tensor dims
167+
std::array<long, NumIndices> out_dims;
168+
for(size_t i = 0; i < NumIndices; i++) {
169+
out_dims[i] = array.shape(i);
170+
}
171+
value.resize(out_dims);
172+
173+
memcpy(value.data(), array.data(), array.size() * sizeof(Scalar));
174+
175+
return true;
176+
}
177+
178+
template<typename T2>
179+
static handle from_cpp(T2 &&v, rv_policy policy, cleanup_list *cleanup) noexcept {
180+
policy = infer_policy<T2>(policy);
181+
if constexpr (std::is_pointer_v<T2>)
182+
return from_cpp_internal((const PlainTensor &) *v, policy, cleanup);
183+
else
184+
return from_cpp_internal((const PlainTensor &) v, policy, cleanup);
185+
}
186+
187+
static handle from_cpp_internal(const PlainTensor &v, rv_policy policy, cleanup_list *cleanup) noexcept {
188+
size_t shape[NumIndices];
189+
190+
for (size_t i = 0 ; i < NumIndices; i++) {
191+
shape[i] = (size_t) v.dimension(i);
192+
}
193+
194+
void *ptr = (void *)v.data();
195+
196+
object owner;
197+
if (policy == rv_policy::move) {
198+
PlainTensor *tmp = new PlainTensor((PlainTensor&&)v);
199+
owner = capsule(tmp, [](void* p) noexcept {
200+
delete (PlainTensor*) p;
201+
});
202+
ptr = tmp->data();
203+
policy = rv_policy::reference;
204+
} else if (policy == rv_policy::reference_internal && cleanup->self()) {
205+
owner = borrow(cleanup->self());
206+
policy = rv_policy::reference;
207+
}
208+
return NDArrayCaster::from_cpp(
209+
NDArray {ptr, NumIndices, shape, owner},
210+
policy, cleanup);
211+
}
212+
};
213+
214+
/** \brief Type caster for Tensor expressions. From-cpp conversion just converts the expression to a plain Tensor object.
215+
*/
216+
template<typename T>
217+
struct type_caster<T, enable_if_t<is_eigen_tensor_xpr_v<T> && is_ndarray_scalar_v<typename T::Scalar>>> {
218+
static constexpr int NumDimensions = T::NumDimensions;
219+
static constexpr int Options = T::Options;
220+
using IndexType = typename T::Index;
221+
using XprTraits = typename Eigen::internal::traits<T>;
222+
static constexpr int Layout = XprTraits::Layout;
223+
using PlainTensor = Eigen::Tensor<typename T::Scalar, NumDimensions, Layout, IndexType>;
224+
using Caster = make_caster<PlainTensor>;
225+
static constexpr auto Name = Caster::Name;
226+
template<typename T_> using Cast = T;
227+
template<typename T_> static constexpr bool can_cast() { return true; }
228+
229+
/// Generating an expression template from a Python object is impossible.
230+
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept = delete;
231+
232+
template <typename T2>
233+
static handle from_cpp(T2 &&v, rv_policy policy, cleanup_list *cleanup) noexcept {
234+
return Caster::from_cpp(std::forward<T2>(v), policy, cleanup);
235+
}
236+
};
237+
238+
/** \brief Type caster for ``Eigen::TensorRef<T>``
239+
*/
240+
template<typename T>
241+
struct type_caster<
242+
Eigen::TensorRef<T>,
243+
enable_if_t<is_ndarray_scalar_v<typename T::Scalar>>> {
244+
245+
using Scalar = typename T::Scalar;
246+
using IndexType = typename T::Index;
247+
static constexpr int NumIndices = T::NumIndices;
248+
static constexpr int Options = T::Options;
249+
using PlainTensor = Eigen::Tensor<Scalar, NumIndices, Options, IndexType>;
250+
using Dimensions = typename T::Dimensions;
251+
252+
// Only partial specification. Dimensions not known at compile time...
253+
using NDArray =
254+
ndarray_for_eigen_tensor_t<T, std::conditional_t<std::is_const_v<T>,
255+
const Scalar,
256+
Scalar>>;
257+
using NDArrayCaster = make_caster<NDArray>;
258+
259+
using MapType = Eigen::TensorMap<T>;
260+
using MapCaster = make_caster<MapType>;
261+
262+
using RefType = Eigen::TensorRef<T>;
263+
264+
265+
static constexpr bool MaybeConvert = std::is_const_v<T>;
266+
using PlainCaster = make_caster<PlainTensor>;
267+
268+
static constexpr auto Name = const_name<MaybeConvert>(PlainCaster::Name, MapCaster::Name);
269+
270+
template<typename T_> using Cast = RefType;
271+
template<typename T_> static constexpr bool can_cast() { return true; };
272+
273+
MapCaster caster;
274+
struct Empty {};
275+
std::conditional_t<MaybeConvert, PlainCaster, Empty> plain_caster;
276+
277+
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept {
278+
// no conversion for mutable Ref
279+
if constexpr (!std::is_const_v<T>)
280+
flags &= ~(uint8_t) cast_flags::convert;
281+
282+
// Try direct cast
283+
if (caster.from_python(src, flags, cleanup))
284+
return true;
285+
286+
// if const T, attempt leveraging PlainTensor conversion
287+
if constexpr (MaybeConvert) {
288+
// we create a new temporary tensor object, and
289+
// its lifetime is that of plain_caster.
290+
// for manual conversion, disable conversion.
291+
if ((flags & (uint8_t) cast_flags::manual))
292+
flags &= ~(uint8_t) cast_flags::convert;
293+
if (plain_caster.from_python(src, flags, cleanup))
294+
return true;
295+
}
296+
297+
return false;
298+
}
299+
300+
static handle from_cpp(const RefType &v, rv_policy policy, cleanup_list *cleanup) noexcept {
301+
size_t shape[NumIndices];
302+
303+
for (size_t i = 0; i < NumIndices; i++) {
304+
shape[i] = (size_t) v.dimension(i);
305+
}
306+
307+
return NDArrayCaster::from_cpp(
308+
NDArray((void *) v.data(), NumIndices, shape, handle()),
309+
(policy == rv_policy::automatic ||
310+
policy == rv_policy::automatic_reference)
311+
? rv_policy::reference
312+
: policy,
313+
cleanup);
314+
}
315+
316+
operator RefType() {
317+
if constexpr (MaybeConvert) {
318+
// if there's a value, return it
319+
if (plain_caster.caster.value.is_valid())
320+
return RefType(plain_caster.operator PlainTensor&());
321+
}
322+
return RefType(caster.operator MapType());
323+
}
324+
};
325+
326+
327+
NAMESPACE_END(detail)
328+
329+
NAMESPACE_END(NB_NAMESPACE)

tests/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ find_package (Eigen3 3.3.1 NO_MODULE)
134134
if (TARGET Eigen3::Eigen)
135135
nanobind_add_module(test_eigen_ext test_eigen.cpp ${NB_EXTRA_ARGS})
136136
target_link_libraries(test_eigen_ext PRIVATE Eigen3::Eigen)
137+
nanobind_add_module(test_eigen_tensor_ext test_eigen_tensor.cpp ${NB_EXTRA_ARGS})
138+
target_link_libraries(test_eigen_tensor_ext PRIVATE Eigen3::Eigen)
137139
endif()
138140

139141
add_library(
@@ -159,6 +161,7 @@ set(TEST_FILES
159161
test_callbacks.py
160162
test_classes.py
161163
test_eigen.py
164+
test_eigen_tensor.py
162165
test_enum.py
163166
test_eval.py
164167
test_exception.py

0 commit comments

Comments
 (0)