#pragma once

#include <array>
#include <functional>
#include <tuple>
#include <type_traits>
#include <utility>

#include <c10/util/TypeList.h>
#include <c10/util/Array.h>

namespace c10 { namespace guts {

/**
 * function_traits<F> describes a plain function type F, i.e. a type spelled
 * `Result (Args...)`: its return type, its parameter types (as a typelist) and
 * its arity.
 *
 * Example:
 *   using A = function_traits<int (float, double)>::return_type; // A == int
 *   using A = function_traits<int (float, double)>::parameter_types::tuple_type; // A == tuple<float, double>
 *
 * Any other Func (function pointers, functors, ...) hits the static_assert in the
 * primary template below; infer_function_traits handles those cases.
 */
template<class Func> struct function_traits {
  // The condition mentions Func, so it is only evaluated (and fails) when this
  // primary template is actually instantiated.
  static_assert(std::is_same<Func, Func>::value == false, "In function_traits<Func>, Func must be a plain function type.");
};

template<class Ret, class... Params>
struct function_traits<Ret (Params...)> {
  using func_type = Ret (Params...);
  using return_type = Ret;
  using parameter_types = typelist::typelist<Params...>;
  static constexpr size_t number_of_parameters = sizeof...(Params);
};

/**
 * infer_function_traits<T>::type is the `function_traits` type describing T.
 * T may be:
 *  - a plain function type `Result (Args...)`,
 *  - a function pointer type `Result (*)(Args...)`,
 *  - a functor type (lambda or struct) whose operator() can be named
 *    unambiguously, i.e. it is neither overloaded nor a template.
 * Class methods are currently not supported.
 */

// Functors: inspect the type of &Functor::operator(). strip_class_t (declared
// elsewhere in c10) is assumed to turn that member-function-pointer type into the
// matching plain function type -- TODO confirm against its definition.
template <typename Functor>
struct infer_function_traits {
  using type = function_traits<c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
};

// Plain function types map directly onto function_traits.
template <typename Result, typename... Args>
struct infer_function_traits<Result (Args...)> {
  using type = function_traits<Result (Args...)>;
};

// A function pointer has the same traits as the function type it points to.
template <typename Result, typename... Args>
struct infer_function_traits<Result (*)(Args...)> : infer_function_traits<Result (Args...)> {};

template <typename T>
using infer_function_traits_t = typename infer_function_traits<T>::type;

/**
 * Use extract_arg_by_filtered_index to return the i-th argument whose
 * type fulfills a given type trait. The argument itself is perfectly forwarded.
 *
 * Example:
 * std::string arg1 = "Hello";
 * std::string arg2 = "World";
 * std::string&& result = extract_arg_by_filtered_index<is_string, 1>(0, arg1, 2.0, std::move(arg2));
 *
 * Warning: The result is a reference to the original argument; ownership is not passed on. If that
 *          argument was a temporary, it dies at the end of the full expression, and a result bound to an
 *          rvalue reference (or any other reference) then dangles, which can cause segfaults.
 */
namespace detail {
// Implementation of extract_arg_by_filtered_index.
//
// The caller always instantiates this with `Enable = void`. Each specialization
// below places an enable_if_t in the Enable slot, so for a given (Condition, index,
// Head) exactly one of them yields `void` there and is selected:
//   - Condition<Head> is false            -> drop Head, keep index.
//   - Condition<Head> is true, index != 0 -> drop Head, decrement index.
//   - Condition<Head> is true, index == 0 -> return Head, perfectly forwarded.
//   - no arguments left                   -> static_assert (index out of range).
//
// Note: Condition is applied to the deduced forwarding type of each argument
// (e.g. `T&` for an lvalue argument), not to the decayed type.
//
// Primary template: declared only, never defined.
template<template <class> class Condition, size_t index, class Enable, class... Args> struct extract_arg_by_filtered_index_;
// Head does not satisfy Condition: skip it without counting it.
template<template <class> class Condition, size_t index, class Head, class... Tail>
struct extract_arg_by_filtered_index_<Condition, index, std::enable_if_t<!Condition<Head>::value>, Head, Tail...> {
  static decltype(auto) call(Head&& /*head*/, Tail&&... tail) {
    return extract_arg_by_filtered_index_<Condition, index, void, Tail...>::call(std::forward<Tail>(tail)...);
  }
};
// Head satisfies Condition but is not the one requested: skip it and count it.
template<template <class> class Condition, size_t index, class Head, class... Tail>
struct extract_arg_by_filtered_index_<Condition, index, std::enable_if_t<Condition<Head>::value && index != 0>, Head, Tail...> {
  static decltype(auto) call(Head&& /*head*/, Tail&&... tail) {
    return extract_arg_by_filtered_index_<Condition, index-1, void, Tail...>::call(std::forward<Tail>(tail)...);
  }
};
// Ran out of arguments before finding the index-th match.
template<template <class> class Condition, size_t index>
struct extract_arg_by_filtered_index_<Condition, index, void> {
  static void call() {
    // `index != index` is always false but depends on `index`, so this only fires
    // when call() is actually instantiated.
    static_assert(index != index, "extract_arg_by_filtered_index out of range.");
  }
};
// Head satisfies Condition and is the requested match: hand it back with its
// original value category.
template<template <class> class Condition, size_t index, class Head, class... Tail>
struct extract_arg_by_filtered_index_<Condition, index, std::enable_if_t<Condition<Head>::value && index == 0>, Head, Tail...> {
  static decltype(auto) call(Head&& head, Tail&&... /*tail*/) {
    return std::forward<Head>(head);
  }
};
} // namespace detail
template<template <class> class Condition, size_t index, class... Args>
decltype(auto) extract_arg_by_filtered_index(Args&&... args) {
  static_assert(is_type_condition<Condition>::value, "In extract_arg_by_filtered_index, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
  return detail::extract_arg_by_filtered_index_<Condition, index, void, Args...>::call(std::forward<Args>(args)...);
}



/**
 * Use filter_map to map a subset of the arguments to values.
 * The subset is defined by type traits, and will be evaluated at compile time.
 * At runtime, it will just loop over the pre-filtered arguments to create an std::array.
 *
 * Example:
 *  std::array<double, 2> result = filter_map<double, std::is_integral>([] (auto a) {return (double)a;}, 3, "bla", 4);
 *  // result == {3.0, 4.0}
 */
namespace detail {

// filter_map_<ResultType, num_results>::call builds the result array. For every
// INDEX in the index sequence, it picks the INDEX-th argument satisfying Condition
// and stores mapper(that argument).
template<class ResultType, size_t num_results> struct filter_map_ {
  template<template <class> class Condition, class Mapper, class... Args, size_t... INDEX>
  static guts::array<ResultType, num_results> call(const Mapper& mapper, std::index_sequence<INDEX...>, Args&&... args) {
    using result_array = guts::array<ResultType, num_results>;
    // All arguments are forwarded for each INDEX, but only the selected one reaches
    // mapper. With distinct INDEX values (filter_map passes 0..num_results-1), each
    // argument is handed to mapper at most once.
    return result_array { mapper(extract_arg_by_filtered_index<Condition, INDEX>(std::forward<Args>(args)...))... };
  }
};

// No argument satisfies Condition: the result is an empty array, and mapper and
// args are never touched.
template<class ResultType> struct filter_map_<ResultType, 0> {
  template<template <class> class Condition, class Mapper, class... Args, size_t... INDEX>
  static guts::array<ResultType, 0> call(const Mapper& /*mapper*/, std::index_sequence<INDEX...>, Args&&... /*args*/) {
    return guts::array<ResultType, 0>{};
  }
};
}

template<class ResultType, template <class> class Condition, class Mapper, class... Args>
decltype(auto) filter_map(const Mapper& mapper, Args&&... args) {
  static_assert(is_type_condition<Condition>::value, "In filter_map<Result, Condition>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");

  static constexpr size_t num_results = typelist::count_if<Condition, typelist::typelist<Args...>>::value;
  return detail::filter_map_<ResultType, num_results>::template call<Condition, Mapper, Args...>(mapper, std::make_index_sequence<num_results>(), std::forward<Args>(args)...);
}


/**
 * Use tuple_elements to extract a position-indexed subset of elements
 * from the argument tuple into a result tuple.
 *
 * Example:
 *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *  std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0, 2>());
 *
 * The selected elements are copied out of `t` (an lvalue inside this function),
 * so the same index may appear more than once in the sequence.
 */
template <class Tuple, size_t... ns>
constexpr auto tuple_elements(Tuple t, std::index_sequence<ns...>) {
  return std::tuple<std::tuple_element_t<ns, Tuple>...>(std::get<ns>(t)...);
}

/**
 * Use tuple_take to extract the first n elements from the argument tuple
 * into a result tuple.
 *
 * Example:
 *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *  std::tuple<int, const char*> result = tuple_take<decltype(t), 2>(t);
 *
 * n must not exceed the number of elements of Tuple (checked at compile time).
 */
template <class Tuple, size_t n>
constexpr auto tuple_take(Tuple t) {
  static_assert(n <= std::tuple_size<std::remove_reference_t<Tuple>>::value,
                "In tuple_take<Tuple, n>, n must not exceed the number of elements of Tuple.");
  // Unless Tuple was explicitly given as a reference type, `t` is this function's
  // own copy. Forward it rather than copying the whole tuple a second time into
  // tuple_elements' by-value parameter. For a reference Tuple this still copies,
  // exactly as before.
  return tuple_elements(std::forward<Tuple>(t), std::make_index_sequence<n>{});
}


/**
 * Use tuple_map to run a mapping function over a tuple to get a new tuple.
 *
 * Example 1:
 *   auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), [] (int32_t a) -> int16_t {return a+1;});
 *   // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
 *
 * Example 2:
 *   struct Mapper {
 *     std::string operator()(int32_t a) const {
 *       return std::to_string(a);
 *     }
 *     int64_t operator()(const std::string& a) const {
 *        return atoi(a.c_str());
 *     }
 *   };
 *   auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"), Mapper());
 *   // result == std::tuple<std::string, int64_t>("3", 4)
 *
 * Example 3:
 *   struct A final {
 *    int32_t func() {
 *      return 5;
 *    }
 *  };
 *  struct B final {
 *    std::string func() {
 *      return "5";
 *    }
 *  };
 *  auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return a.func(); });
 *  // result == std::tuple<int32_t, std::string>(5, "5");
 */
namespace detail {
  // Calls mapper once per element of `tuple`. Each element is forwarded with the
  // value category of its declared type: non-reference elements are passed as
  // rvalues, and reference elements keep their reference kind. The results are
  // packed into a std::tuple whose element types are exactly the mapper's return
  // types.
  template<class Mapper, class... Args, size_t... Indices>
  auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper, std::index_sequence<Indices...>) {
    using mapped_tuple = std::tuple<decltype(mapper(std::forward<Args>(std::get<Indices>(tuple))))...>;
    return mapped_tuple(mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
  }
}

template<class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
  using all_indices = std::index_sequence_for<Args...>;
  return detail::tuple_map(std::move(tuple), mapper, all_indices());
}


/**
 * tuple_concat concatenates several tuples into one.
 */

namespace detail {
  // extract_tuple_element_by_index is a helper that takes a list of tuples and extracts the i-th element
  // in a flattened view of the tuples.
  // Example: extract_tuple_element_by_index<3>(tuple(2,3), tuple(4,5), tuple(6,7)) == 5.
  //
  // The tuples may be lvalues or rvalues. std::tuple_size is only specialized for (cv-qualified)
  // tuple types and not for references to them, so the size is always queried on the
  // reference-stripped type. Querying it on HeadTuple directly (a reference for lvalue arguments)
  // would SFINAE away both overloads. The element is returned with the value category it has in
  // the forwarded tuple (see the static_assert below).

  // The requested element lives in the first tuple.
  template<size_t index, class HeadTuple, class... TailTuples, std::enable_if_t<(index < std::tuple_size<std::remove_reference_t<HeadTuple>>::value), int> = 0>
  decltype(auto) extract_tuple_element_by_index(HeadTuple&& head_tuple, TailTuples&&... /*tail_tuples*/) {
    // TODO if constexpr instead of enable_if
    return std::get<index>(std::forward<HeadTuple>(head_tuple));
  }

  // The requested element lies past the first tuple: drop it and shift the index by its size.
  template<size_t index, class HeadTuple, class... TailTuples, std::enable_if_t<(index >= std::tuple_size<std::remove_reference_t<HeadTuple>>::value), int> = 0>
  decltype(auto) extract_tuple_element_by_index(HeadTuple&& /*head_tuple*/, TailTuples&&... tail_tuples) {
    // TODO if constexpr instead of enable_if
    return extract_tuple_element_by_index<index - std::tuple_size<std::remove_reference_t<HeadTuple>>::value, TailTuples...>(std::forward<TailTuples>(tail_tuples)...);
  }

  static_assert(
    std::is_same<
      int&&,
      decltype(extract_tuple_element_by_index<2>(std::tuple<int32_t>(2), std::tuple<int32_t&&, int32_t>(std::declval<int32_t>(), 3)))
    >::value,
    "extract_tuple_element_by_index should return rvalue references if the tuple contains them. It should not move them into a value"
  );

  // Builds ConcatenatedTuple from the flattened elements selected by ElementIndices.
  // Tuples must be given explicitly: the pack is not the last parameter, so it cannot be deduced.
  template<class ConcatenatedTuple, class... Tuples, size_t... ElementIndices>
  auto tuple_concat(Tuples&&... tuples, std::index_sequence<ElementIndices...>) {
    return ConcatenatedTuple(extract_tuple_element_by_index<ElementIndices>(std::forward<Tuples>(tuples)...)...);
  }
}

template<class... Tuples>
auto tuple_concat(Tuples&&... tuples) {
  // std::tuple_cat returns std::tuple<CTypes...>, where CTypes are the element types of all
  // (decayed) argument tuples in order. It initializes every element from
  // std::get<j>(std::forward<Ti>(tuple_i)). Rvalue-reference elements therefore stay
  // references and are not moved into values, matching the flattened-typelist construction
  // used previously. The tuples may be passed as lvalues or rvalues, and an empty argument
  // list yields std::tuple<>.
  return std::tuple_cat(std::forward<Tuples>(tuples)...);
}


}}
