Skip to content

Commit 67fd6de

Browse files
committed
HFFT family implemented for xarray (#6)
1 parent c778fbb commit 67fd6de

2 files changed

Lines changed: 102 additions & 41 deletions

File tree

include/xtensor-fftw/basic.hpp

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#define XTENSOR_FFTW_BASIC_HPP
1818

1919
#include <xtensor/xarray.hpp>
20+
#include "xtensor/xcomplex.hpp"
21+
#include "xtensor/xeval.hpp"
2022
#include <xtl/xcomplex.hpp>
2123
#include <complex>
2224
#include <tuple>
@@ -191,6 +193,8 @@ namespace xt {
191193

192194
// Callers for fftw_plan_dft, since they have different call signatures and the
193195
// way shape information is extracted from xtensor differs for different dimensionalities.
196+
197+
// REGULAR FFT N-dim
194198
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
195199
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
196200
-> std::enable_if_t<dimensional::is_n<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
@@ -207,6 +211,7 @@ namespace xt {
207211
flags);
208212
};
209213

214+
// REGULAR FFT 1D
210215
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
211216
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
212217
-> std::enable_if_t<dimensional::is_1<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
@@ -222,6 +227,7 @@ namespace xt {
222227
flags);
223228
};
224229

230+
// REGULAR FFT 2D
225231
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
226232
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
227233
-> std::enable_if_t<dimensional::is_2<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
@@ -237,6 +243,7 @@ namespace xt {
237243
flags);
238244
};
239245

246+
// REGULAR FFT 3D
240247
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
241248
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
242249
-> std::enable_if_t<dimensional::is_3<dim, fftw_123dim>::value && (fftw_direction != 0), typename fftw_t<input_t>::plan> {
@@ -252,6 +259,7 @@ namespace xt {
252259
flags);
253260
};
254261

262+
// REAL FFT N-dim
255263
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
256264
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
257265
-> std::enable_if_t<dimensional::is_n<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
@@ -267,6 +275,7 @@ namespace xt {
267275
flags);
268276
};
269277

278+
// REAL FFT 1D
270279
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
271280
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
272281
-> std::enable_if_t<dimensional::is_1<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
@@ -281,6 +290,7 @@ namespace xt {
281290
flags);
282291
};
283292

293+
// REAL FFT 2D
284294
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
285295
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
286296
-> std::enable_if_t<dimensional::is_2<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
@@ -295,6 +305,7 @@ namespace xt {
295305
flags);
296306
};
297307

308+
// REAL FFT 3D
298309
template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t, typename output_t, typename fftw_plan_dft_signature<input_t, output_t, dim, 0, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
299310
inline auto fftw_plan_dft_caller(const xt::xarray<input_t, layout_type::row_major> &input, xt::xarray<output_t, layout_type::row_major> &output, unsigned int flags)
300311
-> std::enable_if_t<dimensional::is_3<dim, fftw_123dim>::value && (fftw_direction == 0), typename fftw_t<input_t>::plan> {
@@ -310,8 +321,6 @@ namespace xt {
310321
};
311322

312323

313-
314-
315324
////
316325
// General: xarray templates
317326
////
@@ -378,6 +387,63 @@ namespace xt {
378387
return output / N_dft;
379388
};
380389

390+
template <
391+
typename input_t, typename output_t, std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in,
392+
typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft,
393+
void (&fftw_execute)(typename fftw_t<input_t>::plan), void (&fftw_destroy_plan)(typename fftw_t<input_t>::plan),
394+
typename = std::enable_if_t<
395+
std::is_same< prec_t<input_t>, prec_t<output_t> >::value // input and output precision must be the same
396+
&& std::is_floating_point< prec_t<input_t> >::value // numbers must be float, double or long double
397+
&& (dimensional::is_123<dim, fftw_123dim>::value // dimensionality must match fftw_123dim
398+
|| dimensional::is_n<dim, fftw_123dim>::value)
399+
>
400+
>
401+
inline xt::xarray<output_t> _hfft_(const xt::xarray<input_t, layout_type::row_major> &input) {
402+
auto output_shape = output_shape_from_input(input, half_plus_one_out, half_plus_one_in);
403+
xt::xarray<output_t, layout_type::row_major> output(output_shape);
404+
405+
xt::xarray<input_t, layout_type::row_major> input_conj = xt::conj(input);
406+
407+
auto plan = fftw_plan_dft_caller<dim, fftw_direction, fftw_123dim, input_t, output_t, fftw_plan_dft, half_plus_one_out, half_plus_one_in>(input_conj, output, FFTW_ESTIMATE);
408+
if (plan == nullptr) {
409+
throw std::runtime_error("Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW).");
410+
}
411+
412+
fftw_execute(plan);
413+
fftw_destroy_plan(plan);
414+
return output;
415+
};
416+
417+
template <
418+
typename input_t, typename output_t, std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in,
419+
typename fftw_plan_dft_signature<input_t, output_t, dim, fftw_direction, fftw_123dim>::type fftw_plan_dft,
420+
void (&fftw_execute)(typename fftw_t<input_t>::plan), void (&fftw_destroy_plan)(typename fftw_t<input_t>::plan),
421+
typename = std::enable_if_t<
422+
std::is_same< prec_t<input_t>, prec_t<output_t> >::value // input and output precision must be the same
423+
&& std::is_floating_point< prec_t<input_t> >::value // numbers must be float, double or long double
424+
&& (dimensional::is_123<dim, fftw_123dim>::value // dimensionality must match fftw_123dim
425+
|| dimensional::is_n<dim, fftw_123dim>::value)
426+
>
427+
>
428+
inline xt::xarray<output_t> _ihfft_(const xt::xarray<input_t, layout_type::row_major> &input) {
429+
auto output_shape = output_shape_from_input(input, half_plus_one_out, half_plus_one_in);
430+
xt::xarray<output_t, layout_type::row_major> output(output_shape);
431+
432+
auto plan = fftw_plan_dft_caller<dim, fftw_direction, fftw_123dim, input_t, output_t, fftw_plan_dft, half_plus_one_out, half_plus_one_in>(input, output, FFTW_ESTIMATE);
433+
if (plan == nullptr) {
434+
throw std::runtime_error("Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW).");
435+
}
436+
437+
fftw_execute(plan);
438+
fftw_destroy_plan(plan);
439+
440+
output = xt::conj(output);
441+
442+
auto dft_dimensions = dft_dimensions_from_output(output, half_plus_one_out);
443+
auto N_dft = static_cast<prec_t<output_t> >(std::accumulate(dft_dimensions.begin(), dft_dimensions.end(), 1, std::multiplies<std::size_t>()));
444+
return output / N_dft;
445+
};
446+
381447

382448
////
383449
// General: xtensor templates
@@ -682,55 +748,55 @@ namespace xt {
682748
////
683749

684750
inline xt::xarray<float> hfft (const xt::xarray<std::complex<float> > &input) {
685-
return _fft_<std::complex<float>, float, 1, 0, true, false, true, fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input);
751+
return _hfft_<std::complex<float>, float, 1, 0, true, false, true, fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input);
686752
}
687753

688754
inline xt::xarray<std::complex<float> > ihfft (const xt::xarray<float> &input) {
689-
return _ifft_<float, std::complex<float>, 1, 0, true, true, false, fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input);
755+
return _ihfft_<float, std::complex<float>, 1, 0, true, true, false, fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input);
690756
}
691757

692758
inline xt::xarray<double> hfft (const xt::xarray<std::complex<double> > &input) {
693-
return _fft_<std::complex<double>, double, 1, 0, true, false, true, fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input);
759+
return _hfft_<std::complex<double>, double, 1, 0, true, false, true, fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input);
694760
}
695761

696762
inline xt::xarray<std::complex<double> > ihfft (const xt::xarray<double> &input) {
697-
return _ifft_<double, std::complex<double>, 1, 0, true, true, false, fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input);
763+
return _ihfft_<double, std::complex<double>, 1, 0, true, true, false, fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input);
698764
}
699765

700766
inline xt::xarray<long double> hfft (const xt::xarray<std::complex<long double> > &input) {
701-
return _fft_<std::complex<long double>, long double, 1, 0, true, false, true, fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input);
767+
return _hfft_<std::complex<long double>, long double, 1, 0, true, false, true, fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input);
702768
}
703769

704770
inline xt::xarray<std::complex<long double> > ihfft (const xt::xarray<long double> &input) {
705-
return _ifft_<long double, std::complex<long double>, 1, 0, true, true, false, fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input);
771+
return _ihfft_<long double, std::complex<long double>, 1, 0, true, true, false, fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input);
706772
}
707773

708774
////
709775
// Hermitian FFT: 2D
710776
////
711777

712778
inline xt::xarray<float> hfft2 (const xt::xarray<std::complex<float> > &input) {
713-
return _fft_<std::complex<float>, float, 2, 0, true, false, true, fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input);
779+
return _hfft_<std::complex<float>, float, 2, 0, true, false, true, fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input);
714780
}
715781

716782
inline xt::xarray<std::complex<float> > ihfft2 (const xt::xarray<float> &input) {
717-
return _ifft_<float, std::complex<float>, 2, 0, true, true, false, fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input);
783+
return _ihfft_<float, std::complex<float>, 2, 0, true, true, false, fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input);
718784
}
719785

720786
inline xt::xarray<double> hfft2 (const xt::xarray<std::complex<double> > &input) {
721-
return _fft_<std::complex<double>, double, 2, 0, true, false, true, fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input);
787+
return _hfft_<std::complex<double>, double, 2, 0, true, false, true, fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input);
722788
}
723789

724790
inline xt::xarray<std::complex<double> > ihfft2 (const xt::xarray<double> &input) {
725-
return _ifft_<double, std::complex<double>, 2, 0, true, true, false, fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input);
791+
return _ihfft_<double, std::complex<double>, 2, 0, true, true, false, fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input);
726792
}
727793

728794
inline xt::xarray<long double> hfft2 (const xt::xarray<std::complex<long double> > &input) {
729-
return _fft_<std::complex<long double>, long double, 2, 0, true, false, true, fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input);
795+
return _hfft_<std::complex<long double>, long double, 2, 0, true, false, true, fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input);
730796
}
731797

732798
inline xt::xarray<std::complex<long double> > ihfft2 (const xt::xarray<long double> &input) {
733-
return _ifft_<long double, std::complex<long double>, 2, 0, true, true, false, fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input);
799+
return _ihfft_<long double, std::complex<long double>, 2, 0, true, true, false, fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input);
734800
}
735801

736802

@@ -739,27 +805,27 @@ namespace xt {
739805
////
740806

741807
inline xt::xarray<float> hfft3 (const xt::xarray<std::complex<float> > &input) {
742-
return _fft_<std::complex<float>, float, 3, 0, true, false, true, fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input);
808+
return _hfft_<std::complex<float>, float, 3, 0, true, false, true, fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input);
743809
}
744810

745811
inline xt::xarray<std::complex<float> > ihfft3 (const xt::xarray<float> &input) {
746-
return _ifft_<float, std::complex<float>, 3, 0, true, true, false, fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input);
812+
return _ihfft_<float, std::complex<float>, 3, 0, true, true, false, fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input);
747813
}
748814

749815
inline xt::xarray<double> hfft3 (const xt::xarray<std::complex<double> > &input) {
750-
return _fft_<std::complex<double>, double, 3, 0, true, false, true, fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input);
816+
return _hfft_<std::complex<double>, double, 3, 0, true, false, true, fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input);
751817
}
752818

753819
inline xt::xarray<std::complex<double> > ihfft3 (const xt::xarray<double> &input) {
754-
return _ifft_<double, std::complex<double>, 3, 0, true, true, false, fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input);
820+
return _ihfft_<double, std::complex<double>, 3, 0, true, true, false, fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input);
755821
}
756822

757823
inline xt::xarray<long double> hfft3 (const xt::xarray<std::complex<long double> > &input) {
758-
return _fft_<std::complex<long double>, long double, 3, 0, true, false, true, fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input);
824+
return _hfft_<std::complex<long double>, long double, 3, 0, true, false, true, fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input);
759825
}
760826

761827
inline xt::xarray<std::complex<long double> > ihfft3 (const xt::xarray<long double> &input) {
762-
return _ifft_<long double, std::complex<long double>, 3, 0, true, true, false, fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input);
828+
return _ihfft_<long double, std::complex<long double>, 3, 0, true, true, false, fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input);
763829
}
764830

765831

@@ -769,32 +835,32 @@ namespace xt {
769835

770836
template <std::size_t dim>
771837
inline xt::xarray<float> hfftn (const xt::xarray<std::complex<float> > &input) {
772-
return _fft_<std::complex<float>, float, dim, 0, false, false, true, fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input);
838+
return _hfft_<std::complex<float>, float, dim, 0, false, false, true, fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input);
773839
}
774840

775841
template <std::size_t dim>
776842
inline xt::xarray<std::complex<float> > ihfftn (const xt::xarray<float> &input) {
777-
return _ifft_<float, std::complex<float>, dim, 0, false, true, false, fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input);
843+
return _ihfft_<float, std::complex<float>, dim, 0, false, true, false, fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input);
778844
}
779845

780846
template <std::size_t dim>
781847
inline xt::xarray<double> hfftn (const xt::xarray<std::complex<double> > &input) {
782-
return _fft_<std::complex<double>, double, dim, 0, false, false, true, fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input);
848+
return _hfft_<std::complex<double>, double, dim, 0, false, false, true, fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input);
783849
}
784850

785851
template <std::size_t dim>
786852
inline xt::xarray<std::complex<double> > ihfftn (const xt::xarray<double> &input) {
787-
return _ifft_<double, std::complex<double>, dim, 0, false, true, false, fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input);
853+
return _ihfft_<double, std::complex<double>, dim, 0, false, true, false, fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input);
788854
}
789855

790856
template <std::size_t dim>
791857
inline xt::xarray<long double> hfftn (const xt::xarray<std::complex<long double> > &input) {
792-
return _fft_<std::complex<long double>, long double, dim, 0, false, false, true, fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input);
858+
return _hfft_<std::complex<long double>, long double, dim, 0, false, false, true, fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input);
793859
}
794860

795861
template <std::size_t dim>
796862
inline xt::xarray<std::complex<long double> > ihfftn (const xt::xarray<long double> &input) {
797-
return _ifft_<long double, std::complex<long double>, dim, 0, false, true, false, fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input);
863+
return _ihfft_<long double, std::complex<long double>, dim, 0, false, true, false, fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input);
798864
}
799865

800866
}

0 commit comments

Comments
 (0)