/*
Copyright (c) 2020, Michael Kazhdan
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of
conditions and the following disclaimer. Redistributions in binary form must reproduce
the above copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the distribution. 

Neither the name of the Johns Hopkins University nor the names of its contributors
may be used to endorse or promote products derived from this software without specific
prior written permission. 

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO THE IMPLIED WARRANTIES 
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE  GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.
*/

#ifndef TENSORS_INCLUDED
#define TENSORS_INCLUDED

#define NEW_TENSOR_CODE

#include <cstring>
#include <iostream>
#include <type_traits>
#ifdef NEW_TENSOR_CODE
#else // !NEW_TENSOR_CODE
#include <random>
#endif // NEW_TENSOR_CODE
#include "Misha/Algebra.h"
#include "Misha/ParameterPack.h"
#include "Misha/MultiDimensionalArray.h"

namespace MishaK
{
	namespace AutoDiff
	{
		// Primary template, intentionally left undefined: only the ParameterPack::UIntPack specializations below exist.
		template< class DimensionPack > struct Tensor;

		// Rank-0 specialization: a tensor with no indices holds a single scalar value.
		// It converts explicitly to double& and participates in the InnerProductSpace operations.
		template<>
		struct Tensor< ParameterPack::UIntPack<> > : public InnerProductSpace< double , Tensor< ParameterPack::UIntPack<> > >
		{
			using Pack = ParameterPack::UIntPack<>;
			// NOTE: Size is 1 here (one stored value), whereas the general tensor sets Size to its number of indices.
			static const unsigned int Size = 1;

			// The stored scalar
			double data;

			// Implicit construction from a scalar (zero by default)
			Tensor( double value=0 ) : data( value ) {}
			explicit operator double &( void ){ return data; }
			explicit operator const double &( void ) const { return data; }

#ifdef NEW_AUTO_DIFF_CODE
			// Every index refers to the single stored value
			double &operator[]( unsigned int ){ return data; }
			const double &operator[]( unsigned int ) const { return data; }
#endif // NEW_AUTO_DIFF_CODE

			// Inner-product space methods
			void Add( const Tensor &other ){ data = data + other.data; }
			void Scale( double scale ){ data = data * scale; }
			double InnerProduct( const Tensor &other ) const { return data * other.data; }

			// Multiplying by a scalar tensor scales the other tensor
			template< unsigned int ... _Dims >
			Tensor< ParameterPack::UIntPack< _Dims ... > > operator * ( const Tensor< ParameterPack::UIntPack< _Dims ... > > &other ) const { return other * data; }

			// With no indices to contract, the contracted outer product reduces to scaling
			template< unsigned int I , unsigned int ... _Dims >
			Tensor< ParameterPack::UIntPack< _Dims ... > > contractedOuterProduct( const Tensor< ParameterPack::UIntPack< _Dims ... > > &other ) const 
			{
				static_assert( I==0 , "[ERROR] Contraction suffix/prefix don't match" );
				return ( *this ) * other;
			}

			// Permuting the (non-existent) indices returns the scalar unchanged
			template< unsigned int ... PermutationValues >
			Tensor< ParameterPack::Permutation< Pack , ParameterPack::UIntPack< PermutationValues ... > > > permute( ParameterPack::UIntPack< PermutationValues ... > ) const
			{
#ifdef NEW_AUTO_DIFF_CODE
				static_assert( sizeof ... ( PermutationValues ) == Size || ( sizeof ... ( PermutationValues ) == 0 && Size==1 ), "[ERROR] Permutation size doesn't match dimension" );
#else // !NEW_AUTO_DIFF_CODE
				static_assert( sizeof ... ( PermutationValues ) == Size , "[ERROR] Permutation size doesn't match dimension" );
#endif // NEW_AUTO_DIFF_CODE
				return *this;
			}

#ifdef NEW_TENSOR_CODE
#else // !NEW_TENSOR_CODE
			// Draws the scalar uniformly from [0,1)
			static Tensor Random( std::default_random_engine &generator )
			{
				// From https://www.cplusplus.com/reference/random/uniform_real_distribution/
				std::uniform_real_distribution< double > uniform( 0.0 , 1.0 );
				const double sample = uniform( generator );
				return Tensor( sample );
			}
#endif // NEW_TENSOR_CODE

			// The multiplicative identity
			static Tensor Identity( void ){ return Tensor( 1.0 ); }

			friend std::ostream &operator << ( std::ostream &stream , const Tensor &tensor ){ return stream << tensor.data; }
		};

		// A general tensor with sizeof...(Dims) indices, stored as a dense MultiDimensionalArray of doubles
		// whose k-th index ranges over [0,Dims[k]).
		template< unsigned int ... Dims >
		struct Tensor< ParameterPack::UIntPack< Dims ... > > : public MultiDimensionalArray::Array< double , Dims ... > , public InnerProductSpace< double , Tensor< ParameterPack::UIntPack< Dims ... > > >
		{
			typedef ParameterPack::UIntPack< Dims ... > Pack;
			// The number of indices of the tensor
			static const unsigned int Size = Pack::Size;

			// Zero-initializes all entries
			Tensor( void ){ memset( MultiDimensionalArray::Array< double , Dims ... >::data , 0 , sizeof( double ) * MultiDimensionalArray::ArraySize< Dims ... >() ); }

			// Entry access with one index per dimension.
			// The trailing indices are cast explicitly so that callers may pass (non-constant) int or other
			// integral variables without the brace-initializer becoming an ill-formed narrowing conversion.
			template< typename ... UInts >
			double &operator()( unsigned int index , UInts ... indices )
			{
				static_assert( sizeof...(indices)==Pack::Size-1 , "[ERROR] Wrong number of indices" );
				unsigned int idx[] = { index , static_cast< unsigned int >( indices ) ... };
				return MultiDimensionalArray::Array< double , Dims ... >::operator()( idx );
			}

			template< typename ... UInts >
			const double &operator()( unsigned int index , UInts ... indices ) const
			{
				static_assert( sizeof...(indices)==Pack::Size-1 , "[ERROR] Wrong number of indices" );
				unsigned int idx[] = { index , static_cast< unsigned int >( indices ) ... };
				return MultiDimensionalArray::Array< double , Dims ... >::operator()( idx );
			}

#ifdef NEW_AUTO_DIFF_CODE
			double &operator[]( unsigned int idx ){ return MultiDimensionalArray::Array< double , Dims ... >::operator[](idx); }
			const double &operator[]( unsigned int idx ) const { return MultiDimensionalArray::Array< double , Dims ... >::operator[](idx); }
#endif // NEW_AUTO_DIFF_CODE

			// Entry access through an array holding Size indices.
			// NOTE(review): for a tensor with a single index, t(0) with a plain int literal appears to select this
			// overload (0 is a null pointer constant and non-templates win ties), dereferencing a null pointer;
			// prefer t(0u) -- confirm.
			double &operator()( const unsigned int indices[] ){ return MultiDimensionalArray::Array< double , Dims ... >::operator()( indices ); }
			const double &operator()( const unsigned int indices[] ) const { return MultiDimensionalArray::Array< double , Dims ... >::operator()( indices ); }

			// Inner-product space methods
			void Add( const Tensor &t )
			{
				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , Pack::Values ,
					[]( int d , int i ){} ,
					[]( double &v1 , const double &v2 ){ v1 += v2; } ,
					*this , t
				);
			}
			void Scale( double s )
			{
				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , Pack::Values ,
					[]( int d , int i ){} ,
					[&]( double &v ){ v *= s; } ,
					*this
				);
			}
			double InnerProduct( const Tensor &t ) const
			{
				double innerProduct = 0;
				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , Pack::Values ,
					[]( int d , int i ){} ,
					[&]( double v1 , double v2 ){ innerProduct += v1*v2; } ,
					*this , t
				);
				return innerProduct;
			}

#ifdef NEW_TENSOR_CODE
#else // !NEW_TENSOR_CODE
			static Tensor Random( std::default_random_engine &generator )
			{
				// From https://www.cplusplus.com/reference/random/uniform_real_distribution/
				std::uniform_real_distribution< double > distribution( 0.0 , 1.0 );

				Tensor t;

				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , Pack::Values ,
					[&]( int , int ){} ,
					[&]( double &v ){ v = distribution( generator ); } ,
					t
				);
				return t;
			}
#endif // NEW_TENSOR_CODE

			// Returns the tensor with indices (Dims...,Dims...) whose entry (i_1,...,i_Size,i_1,...,i_Size) is 1
			// and whose other entries are 0.
			static Tensor< ParameterPack::UIntPack< Dims ... , Dims ... > > Identity( void )
			{
				Tensor< ParameterPack::UIntPack< Dims ... , Dims ... > > id;
				// Both halves of the index are written, so the buffer must hold 2*Size entries
				unsigned int indices[ 2*Size ];
				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , ParameterPack::UIntPack< Dims ... >::Values ,
					[&]( int d , int i ){ indices[d] = indices[ d+Size ] = i;} ,
					[&]( void ){ id( indices ) = 1.; }
				);
				return id;
			}

			template< unsigned int ... PermutationValues >
			static auto PermutationTensor( ParameterPack::UIntPack< PermutationValues ... > )
			{
#pragma message( "[WARNING] Should avoid using PermutationTensor" )
				MK_WARN_ONCE( "Invoking PermutationTensor" );
				Tensor< ParameterPack::Concatenation< ParameterPack::Permutation< Pack , ParameterPack::UIntPack< PermutationValues ... > > , Pack > > t;
				const unsigned int permutation[] = { PermutationValues ... };
				unsigned int idx[ 2*Size ];
				// NOTE(review): this places input index d at output position permutation[d], whereas permute()
				// places it at the position p with permutation[p]==d. The two agree only for self-inverse
				// permutations; confirm which convention ParameterPack::Permutation follows.
				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , Pack::Values ,
					[&]( int d , int i ){ idx[ permutation[d] ] = idx[ Size+d ] = i; } ,
					[&]( void ){ t( idx ) = 1; }
				);
				return t;
			}

			// Permute indices: output index p takes the value of input index PermutationValues[p]
			template< unsigned int ... PermutationValues >
			Tensor< ParameterPack::Permutation< Pack , ParameterPack::UIntPack< PermutationValues ... > > > permute( ParameterPack::UIntPack< PermutationValues ... > ) const
			{
				static_assert( sizeof ... ( PermutationValues ) == Size , "[ERROR] Permutation size doesn't match dimension" );
				typedef ParameterPack::UIntPack< PermutationValues ... > PPack;
				const unsigned int PValues[] = { PermutationValues ... };
				// Inverse permutation
				unsigned int IPValues[ Size ];
				for( unsigned int i=0 ; i<Size ; i++ ) IPValues[ PValues[i] ] = i;

				Tensor< ParameterPack::Permutation< Pack , PPack > > t;
				unsigned int idx[ Size ];
				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , Pack::Values ,
					[&]( int d , int i ){ idx[ IPValues[d] ] = i; } ,
					[&]( const double &v ){ t( idx ) = v; } ,
					*this
				);
				return t;
			}

			// Extract slice: fixes the first I indices to indices[0..I-1] and returns the tensor over the remaining indices
			template< unsigned int I >
			auto extract( const unsigned int indices[/*I*/] ) const
			{
				typedef typename ParameterPack::Partition< I , Pack >::Second Remainder;
				Tensor< Remainder> t;

				if constexpr( Remainder::Size!=0 )
				{
					unsigned int _indices[ Pack::Size ];
					for( unsigned int i=0 ; i<I ; i++ ) _indices[i] = indices[i];

					MultiDimensionalArray::Loop< Remainder::Size >::Run
					(
						ParameterPack::IsotropicUIntPack< Remainder::Size >::Values , Remainder::Values ,
						[&]( int d , int i ){ _indices[d+I] = i; } ,
						[&]( double &_t ){ _t = operator()( _indices ); } ,
						t
					);
				}
				else static_cast< double & >( t ) = operator()( indices );
				return t;
			}

			static auto TransposeTensor( void )
			{
#pragma message( "[WARNING] Should avoid using TransposeTensor" )
				MK_WARN_ONCE( "Invoking TransposeTensor" );
				Tensor< ParameterPack::Concatenation< typename Pack::Transpose , Pack > > t;
				unsigned int idx[ 2*Size ];
				MultiDimensionalArray::Loop< Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Size >::Values , Pack::Transpose::Values ,
					[&]( int d , int i ){ idx[d] = idx[ 2*Size - 1 - d ] = i; } ,
					[&]( void ){ t( idx ) = 1; }
				);
				return t;
			}

			// Transpose operator
			Tensor< typename Pack::Transpose > transpose( void ) const
			{
				return permute( ParameterPack::SequentialPack< unsigned int , Pack::Size >::Transpose() );
			}

			// Outer product: the result's indices are this tensor's indices followed by t's
			template< unsigned int ... _Dims >
			Tensor< ParameterPack::Concatenation< Pack , ParameterPack::UIntPack< _Dims ... > > > operator * ( const Tensor< ParameterPack::UIntPack< _Dims ... > > &t ) const 
			{
				typedef ParameterPack::UIntPack< _Dims ... > _Pack;
				Tensor< ParameterPack::Concatenation< Pack , _Pack > > _t;

				MultiDimensionalArray::Loop< Pack::Size >::Run
				(
					ParameterPack::IsotropicUIntPack< Pack::Size >::Values , Pack::Values ,
					[]( int d , int i ){} ,
					[&]( MultiDimensionalArray::ArrayWrapper< double , _Dims ... > outSlice , const double &v )
					{
						MultiDimensionalArray::Loop< _Pack::Size >::Run
						(
							ParameterPack::IsotropicUIntPack< _Pack::Size >::Values , _Pack::Values ,
							[]( int d , int i ){} ,
							[&]( double &v1 , const double &v2 ){ v1 += v*v2; } ,
							outSlice , t
						);
					} ,
					_t , *this
				);
				return _t;
			}

			// Product with a scalar tensor
			Tensor< Pack > operator * ( const Tensor< ParameterPack::UIntPack<> > &t ) const { return *this * t.data; }

		protected:
			// Tensor with indices (OutPack,Pack) that, contracted against a tensor of shape Pack, traces indices D1 and D2 (D1<D2)
			template< unsigned int D1 , unsigned int D2 >
			static auto _ContractionTensor( void )
			{
				static_assert( D1<D2 , "[ERROR] Contraction indices are the same" );
				static_assert( D1<Pack::Size , "[ERROR] First contraction index too large" );
				static_assert( D2<Pack::Size , "[ERROR] Second contraction index too large" );
				static_assert( ParameterPack::Selection< D1 , Pack >::Value==ParameterPack::Selection< D2 , Pack >::Value , "[ERROR] Contraction dimensions differ" );
				typedef typename ParameterPack::Selection< D1 , typename ParameterPack::Selection< D2 , Pack >::Complement >::Complement OutPack;

				Tensor< ParameterPack::Concatenation< OutPack , Pack > > t;

				// index[0..OutPack::Size) addresses the output part, index[OutPack::Size..) the input (Pack) part
				unsigned int index[ Pack::Size+OutPack::Size ];
				if constexpr( OutPack::Size==0 )
				{
					for( unsigned int i=0 ; i<Pack::template Get<D1>() ; i++ )
					{
						index[ OutPack::Size+D1 ] = index[ OutPack::Size+D2 ] = i;
						t( index ) = 1;
					}
				}
				else
				{
					// Maps each output index to the input index it copies
					unsigned int out2in[ OutPack::Size ];
					{
						unsigned int count = 0;
						for( unsigned int i=0 ; i<Pack::Size ; i++ ) if( i!=D1 && i!=D2 ) out2in[ count++ ] = i;
					}

					MultiDimensionalArray::Loop< OutPack::Size >::Run
					(
						ParameterPack::IsotropicUIntPack< OutPack::Size >::Values , OutPack::Values ,
						// The input-side copy must be offset past the output indices
						[&]( int d , int i ){ index[d] = index[ OutPack::Size+out2in[d] ] = i; } ,
						[&]( void )
						{
							for( unsigned int i=0 ; i<Pack::template Get<D1>() ; i++ )
							{
								index[ OutPack::Size+D1 ] = i;
								index[ OutPack::Size+D2 ] = i;
								t( index ) = 1;
							}
						}
					);
				}

				return t;
			}

		public:
			template< unsigned int D1 , unsigned int D2 >
			static auto ContractionTensor( void )
			{
#pragma message( "[WARNING] Should avoid using ContractionTensor" )
				MK_WARN_ONCE( "Invoking ContractionTensor" );
				if constexpr( D1<D2 ) return _ContractionTensor< D1 , D2 >();
				else                  return _ContractionTensor< D2 , D1 >();
			}

			// Tensor contraction: sums over equal values of indices I1 and I2, returning the tensor over the remaining indices
			template< unsigned int I1 , unsigned int I2 >
			auto contract( void ) const
			{
				static_assert( I1!=I2 , "[ERROR] Contraction indices must differ" );
				// Check bounds before querying the dimensions at I1/I2
				static_assert( I1<Pack::Size && I2<Pack::Size , "[ERROR] Contraction indices out of bounds" );
				static_assert( Pack::template Get< I1 >()==Pack::template Get< I2 >() , "[ERROR] Contraction dimensions don't match" );
				// The body below assumes I1<I2; the else is required so it is discarded (not compiled) in the swapped case
				if constexpr( I2<I1 ) return this->template contract< I2 , I1 >();
				else
				{
					typedef typename ParameterPack::Selection< I1 , typename ParameterPack::Selection< I2 , Pack >::Complement >::Complement OutPack;
					Tensor< OutPack > out;
					if constexpr( Pack::Size>2 )
					{
						unsigned int indices[ OutPack::Size ];
						MultiDimensionalArray::Loop< OutPack::Size >::Run
						(
							ParameterPack::IsotropicUIntPack< OutPack::Size >::Values , OutPack::Values ,
							[&]( int d , int i ){ indices[d] = i; } ,
							[&]( double &_out )
							{
								unsigned int _indices[ Pack::Size ];
								unsigned int idx=0;
								for( unsigned int i=0 ; i<Pack::Size ; i++ ) if( i!=I1 && i!=I2 ) _indices[i] = indices[idx++];
								_out = 0;
								for( unsigned int i=0 ; i<Pack::template Get< I1 >() ; i++ )
								{
									_indices[I1] = _indices[I2] = i;
									_out += operator()( _indices );
								}
							} ,
							out
						);
					}
					else
					{
						// Two indices contracted to a scalar (the trace)
						double &_out = static_cast< double & >( out );
						unsigned int _indices[2];
						for( unsigned int i=0 ; i<Pack::template Get<I1>() ; i++ )
						{
							_indices[I1] = _indices[I2] = i;
							_out += operator()( _indices );
						}
					}
					return out;
				}
			}

			// In1 := [ N{1} , ... , N{I} , N{I+1} , ... , N{K} ]
			// In2 :=                     [ N{I+1} , ... , N{K} , N{K+1} , ... N{M} ]
			// Out := [ N{1} , ... , N{I}             ,           N{K+1} , ... N{M} ]
			template< unsigned int I , unsigned int ... _Dims >
			Tensor< ParameterPack::Concatenation< typename ParameterPack::Partition< Size-I , Pack >::First , typename ParameterPack::Partition< I , ParameterPack::UIntPack< _Dims ... > >::Second > > contractedOuterProduct( const Tensor< ParameterPack::UIntPack< _Dims ... > > &t ) const 
			{
				static_assert( ParameterPack::Comparison< typename ParameterPack::Partition< Size-I , Pack >::Second , typename ParameterPack::Partition< I , ParameterPack::UIntPack< _Dims ... > >::First >::Equal , "[ERROR] Contraction suffix/prefix don't match" );
				typedef ParameterPack::UIntPack< _Dims ... > _Pack;
				typedef typename ParameterPack::Partition< Size-I ,  Pack >:: First P1;
				typedef typename ParameterPack::Partition< Size-I ,  Pack >::Second P2;
				typedef typename ParameterPack::Partition<      I , _Pack >::Second P3;

				typedef typename MultiDimensionalArray::SliceType< P1::Size , double ,  Dims ... >::const_type In1SliceType;
				typedef typename MultiDimensionalArray::SliceType< P2::Size , double , _Dims ... >::const_type In2SliceType;
				// In the case that we are collapsing completely, out is of type Tensor< ParameterPack::UIntPack<> >
				// -- Then the first and last loops are trivial and we never access the contents of out using operator[]
				typedef typename std::conditional< ParameterPack::Concatenation< P1 , P3 >::Size!=0 , double , Tensor< ParameterPack::UIntPack<> > >::type OutBaseType;
				typedef typename std::conditional< P3::Size!=0 , typename MultiDimensionalArray::SliceType< P2::Size , double , _Dims ... >::type , OutBaseType & >::type OutSliceType;

				const Tensor<  Pack > &in1 = *this;
				const Tensor< _Pack > &in2 = t;
				Tensor< ParameterPack::Concatenation< P1 , P3 > > out;

				// Iterate over {1,...,I} of in1 and out
				MultiDimensionalArray::Loop< P1::Size >::Run
				(
					ParameterPack::IsotropicUIntPack< P1::Size >::Values , P1::Values ,
					[]( int d , int i ){} ,
					[&]( In1SliceType _in1 , OutSliceType _out )
					{
						// Iterate over {I,...,K} of in1 and in2
						MultiDimensionalArray::Loop< P2::Size >::Run
						(
							ParameterPack::IsotropicUIntPack< P2::Size >::Values , P2::Values ,
							[]( int d , int i ){} ,
							[&]( double in1Value , In2SliceType _in2 )
							{
								// Iterate over {K+1,...,M} of in2 and out
								MultiDimensionalArray::Loop< P3::Size >::Run
								(
									ParameterPack::IsotropicUIntPack< P3::Size >::Values , P3::Values ,
									[]( int d , int i ){} ,
									[&]( double in2Value , OutBaseType &_out_ ){ _out_ += in1Value * in2Value; } ,
									_in2 , _out
								);
							} ,
							_in1 , in2
						);
					} ,
					in1 , out
				);
				return out;
			}

			// Contracted outer product with a scalar tensor: there are no indices to contract, so I must be 0
			template< unsigned int I >
			Tensor< Pack > contractedOuterProduct( const Tensor< ParameterPack::UIntPack<> > &t ) const
			{
				static_assert( I==0 , "[ERROR] Contraction suffix/prefix don't match" );
				return *this * t;
			}
		};
	}
}
#endif // TENSORS_INCLUDED
