1717#define XTENSOR_FFTW_BASIC_HPP
1818
1919#include < xtensor/xarray.hpp>
20+ #include " xtensor/xcomplex.hpp"
21+ #include " xtensor/xeval.hpp"
2022#include < xtl/xcomplex.hpp>
2123#include < complex>
2224#include < tuple>
@@ -191,6 +193,8 @@ namespace xt {
191193
192194 // Callers for fftw_plan_dft, since they have different call signatures and the
193195 // way shape information is extracted from xtensor differs for different dimensionalities.
196+
197+ // REGULAR FFT N-dim
194198 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
195199 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
196200 -> std::enable_if_t<dimensional::is_n<dim, fftw_123dim>::value && (fftw_direction != 0 ), typename fftw_t<input_t>::plan> {
@@ -207,6 +211,7 @@ namespace xt {
207211 flags);
208212 };
209213
214+ // REGULAR FFT 1D
210215 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
211216 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
212217 -> std::enable_if_t<dimensional::is_1<dim, fftw_123dim>::value && (fftw_direction != 0 ), typename fftw_t<input_t>::plan> {
@@ -222,6 +227,7 @@ namespace xt {
222227 flags);
223228 };
224229
230+ // REGULAR FFT 2D
225231 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
226232 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
227233 -> std::enable_if_t<dimensional::is_2<dim, fftw_123dim>::value && (fftw_direction != 0 ), typename fftw_t<input_t>::plan> {
@@ -237,6 +243,7 @@ namespace xt {
237243 flags);
238244 };
239245
246+ // REGULAR FFT 3D
240247 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, fftw_direction, fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
241248 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
242249 -> std::enable_if_t<dimensional::is_3<dim, fftw_123dim>::value && (fftw_direction != 0 ), typename fftw_t<input_t>::plan> {
@@ -252,6 +259,7 @@ namespace xt {
252259 flags);
253260 };
254261
262+ // REAL FFT N-dim
255263 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, 0 , fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
256264 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
257265 -> std::enable_if_t<dimensional::is_n<dim, fftw_123dim>::value && (fftw_direction == 0 ), typename fftw_t<input_t>::plan> {
@@ -267,6 +275,7 @@ namespace xt {
267275 flags);
268276 };
269277
278+ // REAL FFT 1D
270279 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, 0 , fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
271280 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
272281 -> std::enable_if_t<dimensional::is_1<dim, fftw_123dim>::value && (fftw_direction == 0 ), typename fftw_t<input_t>::plan> {
@@ -281,6 +290,7 @@ namespace xt {
281290 flags);
282291 };
283292
293+ // REAL FFT 2D
284294 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, 0 , fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
285295 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
286296 -> std::enable_if_t<dimensional::is_2<dim, fftw_123dim>::value && (fftw_direction == 0 ), typename fftw_t<input_t>::plan> {
@@ -295,6 +305,7 @@ namespace xt {
295305 flags);
296306 };
297307
308+ // REAL FFT 3D
298309 template <std::size_t dim, int fftw_direction, bool fftw_123dim, typename input_t , typename output_t , typename fftw_plan_dft_signature<input_t , output_t , dim, 0 , fftw_123dim>::type fftw_plan_dft, bool half_plus_one_out, bool half_plus_one_in>
299310 inline auto fftw_plan_dft_caller (const xt::xarray<input_t , layout_type::row_major> &input, xt::xarray<output_t , layout_type::row_major> &output, unsigned int flags)
300311 -> std::enable_if_t<dimensional::is_3<dim, fftw_123dim>::value && (fftw_direction == 0 ), typename fftw_t<input_t>::plan> {
@@ -310,8 +321,6 @@ namespace xt {
310321 };
311322
312323
313-
314-
315324 // //
316325 // General: xarray templates
317326 // //
@@ -378,6 +387,63 @@ namespace xt {
378387 return output / N_dft;
379388 };
380389
390+ template <
391+ typename input_t , typename output_t , std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in,
392+ typename fftw_plan_dft_signature<input_t , output_t , dim, fftw_direction, fftw_123dim>::type fftw_plan_dft,
393+ void (&fftw_execute)(typename fftw_t <input_t >::plan), void (&fftw_destroy_plan)(typename fftw_t <input_t >::plan),
394+ typename = std::enable_if_t<
395+ std::is_same< prec_t<input_t>, prec_t<output_t> >::value // input and output precision must be the same
396+ && std::is_floating_point< prec_t<input_t> >::value // numbers must be float, double or long double
397+ && (dimensional::is_123<dim, fftw_123dim>::value // dimensionality must match fftw_123dim
398+ || dimensional::is_n<dim, fftw_123dim>::value)
399+ >
400+ >
401+ inline xt::xarray<output_t> _hfft_(const xt::xarray<input_t , layout_type::row_major> &input) {
402+ auto output_shape = output_shape_from_input (input, half_plus_one_out, half_plus_one_in);
403+ xt::xarray<output_t , layout_type::row_major> output (output_shape);
404+
405+ xt::xarray<input_t , layout_type::row_major> input_conj = xt::conj (input);
406+
407+ auto plan = fftw_plan_dft_caller<dim, fftw_direction, fftw_123dim, input_t , output_t , fftw_plan_dft, half_plus_one_out, half_plus_one_in>(input_conj, output, FFTW_ESTIMATE);
408+ if (plan == nullptr ) {
409+ throw std::runtime_error (" Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW)." );
410+ }
411+
412+ fftw_execute (plan);
413+ fftw_destroy_plan (plan);
414+ return output;
415+ };
416+
417+ template <
418+ typename input_t , typename output_t , std::size_t dim, int fftw_direction, bool fftw_123dim, bool half_plus_one_out, bool half_plus_one_in,
419+ typename fftw_plan_dft_signature<input_t , output_t , dim, fftw_direction, fftw_123dim>::type fftw_plan_dft,
420+ void (&fftw_execute)(typename fftw_t <input_t >::plan), void (&fftw_destroy_plan)(typename fftw_t <input_t >::plan),
421+ typename = std::enable_if_t<
422+ std::is_same< prec_t<input_t>, prec_t<output_t> >::value // input and output precision must be the same
423+ && std::is_floating_point< prec_t<input_t> >::value // numbers must be float, double or long double
424+ && (dimensional::is_123<dim, fftw_123dim>::value // dimensionality must match fftw_123dim
425+ || dimensional::is_n<dim, fftw_123dim>::value)
426+ >
427+ >
428+ inline xt::xarray<output_t> _ihfft_(const xt::xarray<input_t , layout_type::row_major> &input) {
429+ auto output_shape = output_shape_from_input (input, half_plus_one_out, half_plus_one_in);
430+ xt::xarray<output_t , layout_type::row_major> output (output_shape);
431+
432+ auto plan = fftw_plan_dft_caller<dim, fftw_direction, fftw_123dim, input_t , output_t , fftw_plan_dft, half_plus_one_out, half_plus_one_in>(input, output, FFTW_ESTIMATE);
433+ if (plan == nullptr ) {
434+ throw std::runtime_error (" Plan creation returned nullptr. This usually means FFTW cannot create a plan for the given arguments (e.g. a non-destructive multi-dimensional real FFT is impossible in FFTW)." );
435+ }
436+
437+ fftw_execute (plan);
438+ fftw_destroy_plan (plan);
439+
440+ output = xt::conj (output);
441+
442+ auto dft_dimensions = dft_dimensions_from_output (output, half_plus_one_out);
443+ auto N_dft = static_cast <prec_t <output_t > >(std::accumulate (dft_dimensions.begin (), dft_dimensions.end (), 1 , std::multiplies<std::size_t >()));
444+ return output / N_dft;
445+ };
446+
381447
382448 // //
383449 // General: xtensor templates
@@ -682,55 +748,55 @@ namespace xt {
682748 // //
683749
684750 inline xt::xarray<float > hfft (const xt::xarray<std::complex <float > > &input) {
685- return _fft_ <std::complex <float >, float , 1 , 0 , true , false , true , fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input);
751+ return _hfft_ <std::complex <float >, float , 1 , 0 , true , false , true , fftwf_plan_dft_c2r_1d, fftwf_execute, fftwf_destroy_plan> (input);
686752 }
687753
688754 inline xt::xarray<std::complex <float > > ihfft (const xt::xarray<float > &input) {
689- return _ifft_ <float , std::complex <float >, 1 , 0 , true , true , false , fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input);
755+ return _ihfft_ <float , std::complex <float >, 1 , 0 , true , true , false , fftwf_plan_dft_r2c_1d, fftwf_execute, fftwf_destroy_plan> (input);
690756 }
691757
692758 inline xt::xarray<double > hfft (const xt::xarray<std::complex <double > > &input) {
693- return _fft_ <std::complex <double >, double , 1 , 0 , true , false , true , fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input);
759+ return _hfft_ <std::complex <double >, double , 1 , 0 , true , false , true , fftw_plan_dft_c2r_1d, fftw_execute, fftw_destroy_plan> (input);
694760 }
695761
696762 inline xt::xarray<std::complex <double > > ihfft (const xt::xarray<double > &input) {
697- return _ifft_ <double , std::complex <double >, 1 , 0 , true , true , false , fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input);
763+ return _ihfft_ <double , std::complex <double >, 1 , 0 , true , true , false , fftw_plan_dft_r2c_1d, fftw_execute, fftw_destroy_plan> (input);
698764 }
699765
700766 inline xt::xarray<long double > hfft (const xt::xarray<std::complex <long double > > &input) {
701- return _fft_ <std::complex <long double >, long double , 1 , 0 , true , false , true , fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input);
767+ return _hfft_ <std::complex <long double >, long double , 1 , 0 , true , false , true , fftwl_plan_dft_c2r_1d, fftwl_execute, fftwl_destroy_plan> (input);
702768 }
703769
704770 inline xt::xarray<std::complex <long double > > ihfft (const xt::xarray<long double > &input) {
705- return _ifft_ <long double , std::complex <long double >, 1 , 0 , true , true , false , fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input);
771+ return _ihfft_ <long double , std::complex <long double >, 1 , 0 , true , true , false , fftwl_plan_dft_r2c_1d, fftwl_execute, fftwl_destroy_plan> (input);
706772 }
707773
708774 // //
709775 // Hermitian FFT: 2D
710776 // //
711777
712778 inline xt::xarray<float > hfft2 (const xt::xarray<std::complex <float > > &input) {
713- return _fft_ <std::complex <float >, float , 2 , 0 , true , false , true , fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input);
779+ return _hfft_ <std::complex <float >, float , 2 , 0 , true , false , true , fftwf_plan_dft_c2r_2d, fftwf_execute, fftwf_destroy_plan> (input);
714780 }
715781
716782 inline xt::xarray<std::complex <float > > ihfft2 (const xt::xarray<float > &input) {
717- return _ifft_ <float , std::complex <float >, 2 , 0 , true , true , false , fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input);
783+ return _ihfft_ <float , std::complex <float >, 2 , 0 , true , true , false , fftwf_plan_dft_r2c_2d, fftwf_execute, fftwf_destroy_plan> (input);
718784 }
719785
720786 inline xt::xarray<double > hfft2 (const xt::xarray<std::complex <double > > &input) {
721- return _fft_ <std::complex <double >, double , 2 , 0 , true , false , true , fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input);
787+ return _hfft_ <std::complex <double >, double , 2 , 0 , true , false , true , fftw_plan_dft_c2r_2d, fftw_execute, fftw_destroy_plan> (input);
722788 }
723789
724790 inline xt::xarray<std::complex <double > > ihfft2 (const xt::xarray<double > &input) {
725- return _ifft_ <double , std::complex <double >, 2 , 0 , true , true , false , fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input);
791+ return _ihfft_ <double , std::complex <double >, 2 , 0 , true , true , false , fftw_plan_dft_r2c_2d, fftw_execute, fftw_destroy_plan> (input);
726792 }
727793
728794 inline xt::xarray<long double > hfft2 (const xt::xarray<std::complex <long double > > &input) {
729- return _fft_ <std::complex <long double >, long double , 2 , 0 , true , false , true , fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input);
795+ return _hfft_ <std::complex <long double >, long double , 2 , 0 , true , false , true , fftwl_plan_dft_c2r_2d, fftwl_execute, fftwl_destroy_plan> (input);
730796 }
731797
732798 inline xt::xarray<std::complex <long double > > ihfft2 (const xt::xarray<long double > &input) {
733- return _ifft_ <long double , std::complex <long double >, 2 , 0 , true , true , false , fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input);
799+ return _ihfft_ <long double , std::complex <long double >, 2 , 0 , true , true , false , fftwl_plan_dft_r2c_2d, fftwl_execute, fftwl_destroy_plan> (input);
734800 }
735801
736802
@@ -739,27 +805,27 @@ namespace xt {
739805 // //
740806
741807 inline xt::xarray<float > hfft3 (const xt::xarray<std::complex <float > > &input) {
742- return _fft_ <std::complex <float >, float , 3 , 0 , true , false , true , fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input);
808+ return _hfft_ <std::complex <float >, float , 3 , 0 , true , false , true , fftwf_plan_dft_c2r_3d, fftwf_execute, fftwf_destroy_plan> (input);
743809 }
744810
745811 inline xt::xarray<std::complex <float > > ihfft3 (const xt::xarray<float > &input) {
746- return _ifft_ <float , std::complex <float >, 3 , 0 , true , true , false , fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input);
812+ return _ihfft_ <float , std::complex <float >, 3 , 0 , true , true , false , fftwf_plan_dft_r2c_3d, fftwf_execute, fftwf_destroy_plan> (input);
747813 }
748814
749815 inline xt::xarray<double > hfft3 (const xt::xarray<std::complex <double > > &input) {
750- return _fft_ <std::complex <double >, double , 3 , 0 , true , false , true , fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input);
816+ return _hfft_ <std::complex <double >, double , 3 , 0 , true , false , true , fftw_plan_dft_c2r_3d, fftw_execute, fftw_destroy_plan> (input);
751817 }
752818
753819 inline xt::xarray<std::complex <double > > ihfft3 (const xt::xarray<double > &input) {
754- return _ifft_ <double , std::complex <double >, 3 , 0 , true , true , false , fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input);
820+ return _ihfft_ <double , std::complex <double >, 3 , 0 , true , true , false , fftw_plan_dft_r2c_3d, fftw_execute, fftw_destroy_plan> (input);
755821 }
756822
757823 inline xt::xarray<long double > hfft3 (const xt::xarray<std::complex <long double > > &input) {
758- return _fft_ <std::complex <long double >, long double , 3 , 0 , true , false , true , fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input);
824+ return _hfft_ <std::complex <long double >, long double , 3 , 0 , true , false , true , fftwl_plan_dft_c2r_3d, fftwl_execute, fftwl_destroy_plan> (input);
759825 }
760826
761827 inline xt::xarray<std::complex <long double > > ihfft3 (const xt::xarray<long double > &input) {
762- return _ifft_ <long double , std::complex <long double >, 3 , 0 , true , true , false , fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input);
828+ return _ihfft_ <long double , std::complex <long double >, 3 , 0 , true , true , false , fftwl_plan_dft_r2c_3d, fftwl_execute, fftwl_destroy_plan> (input);
763829 }
764830
765831
@@ -769,32 +835,32 @@ namespace xt {
769835
770836 template <std::size_t dim>
771837 inline xt::xarray<float > hfftn (const xt::xarray<std::complex <float > > &input) {
772- return _fft_ <std::complex <float >, float , dim, 0 , false , false , true , fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input);
838+ return _hfft_ <std::complex <float >, float , dim, 0 , false , false , true , fftwf_plan_dft_c2r, fftwf_execute, fftwf_destroy_plan> (input);
773839 }
774840
775841 template <std::size_t dim>
776842 inline xt::xarray<std::complex <float > > ihfftn (const xt::xarray<float > &input) {
777- return _ifft_ <float , std::complex <float >, dim, 0 , false , true , false , fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input);
843+ return _ihfft_ <float , std::complex <float >, dim, 0 , false , true , false , fftwf_plan_dft_r2c, fftwf_execute, fftwf_destroy_plan> (input);
778844 }
779845
780846 template <std::size_t dim>
781847 inline xt::xarray<double > hfftn (const xt::xarray<std::complex <double > > &input) {
782- return _fft_ <std::complex <double >, double , dim, 0 , false , false , true , fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input);
848+ return _hfft_ <std::complex <double >, double , dim, 0 , false , false , true , fftw_plan_dft_c2r, fftw_execute, fftw_destroy_plan> (input);
783849 }
784850
785851 template <std::size_t dim>
786852 inline xt::xarray<std::complex <double > > ihfftn (const xt::xarray<double > &input) {
787- return _ifft_ <double , std::complex <double >, dim, 0 , false , true , false , fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input);
853+ return _ihfft_ <double , std::complex <double >, dim, 0 , false , true , false , fftw_plan_dft_r2c, fftw_execute, fftw_destroy_plan> (input);
788854 }
789855
790856 template <std::size_t dim>
791857 inline xt::xarray<long double > hfftn (const xt::xarray<std::complex <long double > > &input) {
792- return _fft_ <std::complex <long double >, long double , dim, 0 , false , false , true , fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input);
858+ return _hfft_ <std::complex <long double >, long double , dim, 0 , false , false , true , fftwl_plan_dft_c2r, fftwl_execute, fftwl_destroy_plan> (input);
793859 }
794860
795861 template <std::size_t dim>
796862 inline xt::xarray<std::complex <long double > > ihfftn (const xt::xarray<long double > &input) {
797- return _ifft_ <long double , std::complex <long double >, dim, 0 , false , true , false , fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input);
863+ return _ihfft_ <long double , std::complex <long double >, dim, 0 , false , true , false , fftwl_plan_dft_r2c, fftwl_execute, fftwl_destroy_plan> (input);
798864 }
799865
800866 }
0 commit comments