|
| 1 | +/* |
| 2 | + nanobind/eigen/tensor.h: type casters for Eigen tensors |
| 3 | +
|
| 4 | + Copyright (c) 2026 INRIA |
| 5 | +
|
| 6 | + Author(s): Wilson Jallet |
| 7 | +
|
| 8 | + All rights reserved. Use of this source code is governed by a |
| 9 | + BSD-style license that can be found in the LICENSE file. |
| 10 | +*/ |
| 11 | + |
| 12 | +#pragma once |
| 13 | + |
| 14 | +#include <nanobind/ndarray.h> |
| 15 | +#include <unsupported/Eigen/CXX11/Tensor> |
| 16 | + |
| 17 | +NAMESPACE_BEGIN(NB_NAMESPACE) |
| 18 | +NAMESPACE_BEGIN(detail) |
| 19 | + |
| 20 | +/// As of April 2026, Eigen::Tensor types support 16-byte alignment or no alignment. |
| 21 | +inline bool is_tensor_aligned(const void *data, std::size_t align = Eigen::Aligned) { |
| 22 | + return (reinterpret_cast<std::size_t>(data) % align) == 0; |
| 23 | +} |
| 24 | + |
/// Type trait for inheriting from Eigen::TensorBase, i.e. "is T any kind of
/// Eigen tensor expression?". All TensorBase specializations ultimately
/// inherit from TensorBase<T, ReadOnlyAccessors>, so checking that single
/// base class covers every case.
template<typename T> constexpr bool is_eigen_tensor_v =
    std::is_base_of_v<Eigen::TensorBase<T, Eigen::ReadOnlyAccessors>, T>;
| 29 | + |
/// True when the tensor type T stores coefficients in row-major (C) order.
template<typename T> constexpr bool eigen_tensor_is_row_major_v = T::Layout == Eigen::RowMajor;
/// True when the tensor type T stores coefficients in column-major (Fortran) order.
template<typename T> constexpr bool eigen_tensor_is_col_major_v = T::Layout == Eigen::ColMajor;
| 32 | + |
/// Detects Eigen::TensorMap specializations (primary template: false).
template<typename T>
constexpr bool is_eigen_tensor_map_v = false;

// Matches any TensorMap, including maps over a const-qualified tensor type T.
template<typename T, int Options, template<class> class MakePointer>
constexpr bool is_eigen_tensor_map_v<Eigen::TensorMap<T, Options, MakePointer>> = true;
| 39 | + |
/// Detects Eigen::TensorRef specializations (primary template: false).
template<typename T>
constexpr bool is_eigen_tensor_ref_v = false;

// Matches any TensorRef, including refs over a const-qualified tensor type T.
template<typename T>
constexpr bool is_eigen_tensor_ref_v<Eigen::TensorRef<T>> = true;
| 46 | + |
/// Detects "plain" (storage-owning) tensor types: Eigen::Tensor and
/// Eigen::TensorFixedSize (primary template: false).
template<typename T>
constexpr bool is_eigen_tensor_plain_v = false;

// Dynamically-sized owning tensor.
template<typename Scalar, int NumIndices, int Options, typename IndexType>
constexpr bool is_eigen_tensor_plain_v<Eigen::Tensor<Scalar, NumIndices, Options, IndexType>> = true;

// Fixed-size owning tensor (dimensions encoded in the type).
template<typename Scalar, std::ptrdiff_t... Indices, int Options, typename IndexType>
constexpr bool is_eigen_tensor_plain_v<Eigen::TensorFixedSize<Scalar, Eigen::Sizes<Indices...>, Options, IndexType>> = true;
| 55 | + |
| 56 | +template<typename T> |
| 57 | +constexpr bool is_eigen_tensor_xpr_v = |
| 58 | + is_eigen_tensor_v<T> && |
| 59 | + !is_eigen_tensor_plain_v<T> && |
| 60 | + !is_eigen_tensor_map_v<T> && |
| 61 | + !is_eigen_tensor_ref_v<T>; |
| 62 | + |
| 63 | +template<typename T, typename Scalar = typename T::Scalar> |
| 64 | +using ndarray_for_eigen_tensor_t = ndarray< |
| 65 | + Scalar, |
| 66 | + numpy, |
| 67 | + ndim<T::NumDimensions>, |
| 68 | + std::conditional_t< |
| 69 | + eigen_tensor_is_row_major_v<T>, |
| 70 | + c_contig, |
| 71 | + f_contig>>; |
| 72 | + |
| 73 | +/** \brief Type caster for ``Eigen::TensorMap<T>`` |
| 74 | + */ |
| 75 | +template<typename T, int MapOptions, template<class> class MakePointer> |
| 76 | +struct type_caster< |
| 77 | + Eigen::TensorMap<T, MapOptions, MakePointer>, |
| 78 | + enable_if_t<is_ndarray_scalar_v<typename T::Scalar>>> { |
| 79 | + |
| 80 | + using Scalar = typename T::Scalar; |
| 81 | + using IndexType = typename T::Index; |
| 82 | + static constexpr int NumIndices = T::NumIndices; |
| 83 | + static constexpr int Options = T::Options; |
| 84 | + using PlainTensor = Eigen::Tensor<Scalar, NumIndices, Options, IndexType>; |
| 85 | + using Dimensions = typename T::Dimensions; |
| 86 | + using MapType = Eigen::TensorMap<T, MapOptions, MakePointer>; |
| 87 | + static constexpr bool IsAligned = MapType::IsAligned; |
| 88 | + |
| 89 | + // Only partial specification. Dimensions not known at compile time... |
| 90 | + using NDArray = |
| 91 | + ndarray_for_eigen_tensor_t<T, std::conditional_t<std::is_const_v<T>, |
| 92 | + const Scalar, |
| 93 | + Scalar>>; |
| 94 | + using NDArrayCaster = make_caster<NDArray>; |
| 95 | + static constexpr auto Name = NDArrayCaster::Name; |
| 96 | + template<typename T_> using Cast = MapType; |
| 97 | + template<typename T_> static constexpr bool can_cast() { return true; }; |
| 98 | + |
| 99 | + NDArrayCaster caster; |
| 100 | + |
| 101 | + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { |
| 102 | + // Disable implicit conversions |
| 103 | + flags &= ~(uint8_t)cast_flags::convert; |
| 104 | + // Do not accept None |
| 105 | + flags &= ~(uint8_t)cast_flags::accepts_none; |
| 106 | + |
| 107 | + if (!caster.from_python(src, flags, cleanup)) |
| 108 | + return false; |
| 109 | + if(IsAligned && !is_tensor_aligned(caster.value.data())) |
| 110 | + return false; |
| 111 | + |
| 112 | + return true; |
| 113 | + } |
| 114 | + |
| 115 | + static handle from_cpp(const MapType &v, rv_policy policy, cleanup_list *cleanup) noexcept { |
| 116 | + size_t shape[NumIndices]; |
| 117 | + for (size_t i = 0 ; i < NumIndices; i++) { |
| 118 | + shape[i] = (size_t) v.dimension(i); |
| 119 | + } |
| 120 | + |
| 121 | + void* ptr = (void *)v.data(); |
| 122 | + if (policy == rv_policy::automatic || policy == rv_policy::automatic_reference) |
| 123 | + policy = rv_policy::reference; |
| 124 | + return NDArrayCaster::from_cpp( |
| 125 | + NDArray {ptr, NumIndices, shape, handle()}, |
| 126 | + policy, |
| 127 | + cleanup); |
| 128 | + } |
| 129 | + |
| 130 | + operator MapType() { |
| 131 | + NDArray &t = caster.value; |
| 132 | + std::array<long, NumIndices> shape; |
| 133 | + for (size_t i = 0 ; i < NumIndices; i++) { |
| 134 | + shape[i] = t.shape(i); |
| 135 | + } |
| 136 | + return MapType(t.data(), shape); |
| 137 | + } |
| 138 | +}; |
| 139 | + |
| 140 | + |
| 141 | +/** \brief Type caster for plain ``Eigen::Tensor<T>`` types. |
| 142 | + */ |
| 143 | +template<typename Scalar, int NumIndices, int Options, typename IndexType> |
| 144 | +struct type_caster< |
| 145 | + Eigen::Tensor<Scalar, NumIndices, Options, IndexType>, |
| 146 | + enable_if_t<is_ndarray_scalar_v<Scalar>>> { |
| 147 | + |
| 148 | + using PlainTensor = Eigen::Tensor<Scalar, NumIndices, Options, IndexType>; |
| 149 | + using Dimensions = typename PlainTensor::Dimensions; |
| 150 | + using Coeffs = typename PlainTensor::CoeffReturnType; |
| 151 | + static constexpr bool IsRowMajor = bool(Options & Eigen::RowMajorBit); |
| 152 | + using NDArray = ndarray_for_eigen_tensor_t<PlainTensor>; |
| 153 | + using NDArrayCaster = make_caster<NDArray>; |
| 154 | + |
| 155 | + // PlainTensor value; |
| 156 | + NB_TYPE_CASTER(PlainTensor, NDArrayCaster::Name); |
| 157 | + |
| 158 | + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { |
| 159 | + using NDArrayConst = ndarray_for_eigen_tensor_t<PlainTensor, const Scalar>; |
| 160 | + make_caster<NDArrayConst> caster; |
| 161 | + // Do not accept None |
| 162 | + if (!caster.from_python(src, flags & ~(uint8_t)cast_flags::accepts_none, cleanup)) |
| 163 | + return false; |
| 164 | + |
| 165 | + const NDArrayConst &array = caster.value; |
| 166 | + // copy tensor dims |
| 167 | + std::array<long, NumIndices> out_dims; |
| 168 | + for(size_t i = 0; i < NumIndices; i++) { |
| 169 | + out_dims[i] = array.shape(i); |
| 170 | + } |
| 171 | + value.resize(out_dims); |
| 172 | + |
| 173 | + memcpy(value.data(), array.data(), array.size() * sizeof(Scalar)); |
| 174 | + |
| 175 | + return true; |
| 176 | + } |
| 177 | + |
| 178 | + template<typename T2> |
| 179 | + static handle from_cpp(T2 &&v, rv_policy policy, cleanup_list *cleanup) noexcept { |
| 180 | + policy = infer_policy<T2>(policy); |
| 181 | + if constexpr (std::is_pointer_v<T2>) |
| 182 | + return from_cpp_internal((const PlainTensor &) *v, policy, cleanup); |
| 183 | + else |
| 184 | + return from_cpp_internal((const PlainTensor &) v, policy, cleanup); |
| 185 | + } |
| 186 | + |
| 187 | + static handle from_cpp_internal(const PlainTensor &v, rv_policy policy, cleanup_list *cleanup) noexcept { |
| 188 | + size_t shape[NumIndices]; |
| 189 | + |
| 190 | + for (size_t i = 0 ; i < NumIndices; i++) { |
| 191 | + shape[i] = (size_t) v.dimension(i); |
| 192 | + } |
| 193 | + |
| 194 | + void *ptr = (void *)v.data(); |
| 195 | + |
| 196 | + object owner; |
| 197 | + if (policy == rv_policy::move) { |
| 198 | + PlainTensor *tmp = new PlainTensor((PlainTensor&&)v); |
| 199 | + owner = capsule(tmp, [](void* p) noexcept { |
| 200 | + delete (PlainTensor*) p; |
| 201 | + }); |
| 202 | + ptr = tmp->data(); |
| 203 | + policy = rv_policy::reference; |
| 204 | + } else if (policy == rv_policy::reference_internal && cleanup->self()) { |
| 205 | + owner = borrow(cleanup->self()); |
| 206 | + policy = rv_policy::reference; |
| 207 | + } |
| 208 | + return NDArrayCaster::from_cpp( |
| 209 | + NDArray {ptr, NumIndices, shape, owner}, |
| 210 | + policy, cleanup); |
| 211 | + } |
| 212 | +}; |
| 213 | + |
/** \brief Type caster for Tensor expressions. From-cpp conversion just converts the expression to a plain Tensor object.
 */
template<typename T>
struct type_caster<T, enable_if_t<is_eigen_tensor_xpr_v<T> && is_ndarray_scalar_v<typename T::Scalar>>> {
    static constexpr int NumDimensions = T::NumDimensions;
    static constexpr int Options = T::Options;
    using IndexType = typename T::Index;
    using XprTraits = typename Eigen::internal::traits<T>;
    static constexpr int Layout = XprTraits::Layout;
    // Evaluating the expression yields a plain tensor with the same scalar,
    // rank, and memory layout as the expression itself.
    using PlainTensor = Eigen::Tensor<typename T::Scalar, NumDimensions, Layout, IndexType>;
    using Caster = make_caster<PlainTensor>;
    static constexpr auto Name = Caster::Name;
    template<typename T_> using Cast = T;
    template<typename T_> static constexpr bool can_cast() { return true; }

    /// Generating an expression template from a Python object is impossible.
    bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept = delete;

    // Delegates to the plain-tensor caster, which evaluates the expression
    // (via conversion to PlainTensor) before exporting it to Python.
    template <typename T2>
    static handle from_cpp(T2 &&v, rv_policy policy, cleanup_list *cleanup) noexcept {
        return Caster::from_cpp(std::forward<T2>(v), policy, cleanup);
    }
};
| 237 | + |
| 238 | +/** \brief Type caster for ``Eigen::TensorRef<T>`` |
| 239 | + */ |
| 240 | +template<typename T> |
| 241 | +struct type_caster< |
| 242 | + Eigen::TensorRef<T>, |
| 243 | + enable_if_t<is_ndarray_scalar_v<typename T::Scalar>>> { |
| 244 | + |
| 245 | + using Scalar = typename T::Scalar; |
| 246 | + using IndexType = typename T::Index; |
| 247 | + static constexpr int NumIndices = T::NumIndices; |
| 248 | + static constexpr int Options = T::Options; |
| 249 | + using PlainTensor = Eigen::Tensor<Scalar, NumIndices, Options, IndexType>; |
| 250 | + using Dimensions = typename T::Dimensions; |
| 251 | + |
| 252 | + // Only partial specification. Dimensions not known at compile time... |
| 253 | + using NDArray = |
| 254 | + ndarray_for_eigen_tensor_t<T, std::conditional_t<std::is_const_v<T>, |
| 255 | + const Scalar, |
| 256 | + Scalar>>; |
| 257 | + using NDArrayCaster = make_caster<NDArray>; |
| 258 | + |
| 259 | + using MapType = Eigen::TensorMap<T>; |
| 260 | + using MapCaster = make_caster<MapType>; |
| 261 | + |
| 262 | + using RefType = Eigen::TensorRef<T>; |
| 263 | + |
| 264 | + |
| 265 | + static constexpr bool MaybeConvert = std::is_const_v<T>; |
| 266 | + using PlainCaster = make_caster<PlainTensor>; |
| 267 | + |
| 268 | + static constexpr auto Name = const_name<MaybeConvert>(PlainCaster::Name, MapCaster::Name); |
| 269 | + |
| 270 | + template<typename T_> using Cast = RefType; |
| 271 | + template<typename T_> static constexpr bool can_cast() { return true; }; |
| 272 | + |
| 273 | + MapCaster caster; |
| 274 | + struct Empty {}; |
| 275 | + std::conditional_t<MaybeConvert, PlainCaster, Empty> plain_caster; |
| 276 | + |
| 277 | + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { |
| 278 | + // no conversion for mutable Ref |
| 279 | + if constexpr (!std::is_const_v<T>) |
| 280 | + flags &= ~(uint8_t) cast_flags::convert; |
| 281 | + |
| 282 | + // Try direct cast |
| 283 | + if (caster.from_python(src, flags, cleanup)) |
| 284 | + return true; |
| 285 | + |
| 286 | + // if const T, attempt leveraging PlainTensor conversion |
| 287 | + if constexpr (MaybeConvert) { |
| 288 | + // we create a new temporary tensor object, and |
| 289 | + // its lifetime is that of plain_caster. |
| 290 | + // for manual conversion, disable conversion. |
| 291 | + if ((flags & (uint8_t) cast_flags::manual)) |
| 292 | + flags &= ~(uint8_t) cast_flags::convert; |
| 293 | + if (plain_caster.from_python(src, flags, cleanup)) |
| 294 | + return true; |
| 295 | + } |
| 296 | + |
| 297 | + return false; |
| 298 | + } |
| 299 | + |
| 300 | + static handle from_cpp(const RefType &v, rv_policy policy, cleanup_list *cleanup) noexcept { |
| 301 | + size_t shape[NumIndices]; |
| 302 | + |
| 303 | + for (size_t i = 0; i < NumIndices; i++) { |
| 304 | + shape[i] = (size_t) v.dimension(i); |
| 305 | + } |
| 306 | + |
| 307 | + return NDArrayCaster::from_cpp( |
| 308 | + NDArray((void *) v.data(), NumIndices, shape, handle()), |
| 309 | + (policy == rv_policy::automatic || |
| 310 | + policy == rv_policy::automatic_reference) |
| 311 | + ? rv_policy::reference |
| 312 | + : policy, |
| 313 | + cleanup); |
| 314 | + } |
| 315 | + |
| 316 | + operator RefType() { |
| 317 | + if constexpr (MaybeConvert) { |
| 318 | + // if there's a value, return it |
| 319 | + if (plain_caster.caster.value.is_valid()) |
| 320 | + return RefType(plain_caster.operator PlainTensor&()); |
| 321 | + } |
| 322 | + return RefType(caster.operator MapType()); |
| 323 | + } |
| 324 | +}; |
| 325 | + |
| 326 | + |
| 327 | +NAMESPACE_END(detail) |
| 328 | + |
| 329 | +NAMESPACE_END(NB_NAMESPACE) |
0 commit comments