// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#include <stdio.h>
#include <gtest/gtest.h>
#include <vector>
#include <numeric>

#include <hip/hip_runtime.h>
#include <rocrand/rocrand.h>

#include <rng/generator_type.hpp>
#include <rng/generators.hpp>

#include "test_common.hpp"
#include "test_rocrand_common.hpp"

// Generates unsigned ints into a device buffer and checks that their mean is
// close to the middle of the unsigned int range.
TEST(rocrand_philox_prng_tests, uniform_uint_test)
{
    // One extra element is allocated so that generation can start one element
    // past the beginning of the allocation.
    const size_t  count = 1313;
    unsigned int* device_buffer;
    HIP_CHECK(hipMallocHelper(&device_buffer, sizeof(unsigned int) * (count + 1)));
    unsigned int* const device_output = device_buffer + 1;

    rocrand_philox4x32_10 generator;
    ROCRAND_CHECK(generator.generate(device_output, count));
    HIP_CHECK(hipDeviceSynchronize());

    unsigned int host_values[count];
    HIP_CHECK(hipMemcpy(host_values, device_output, sizeof(host_values), hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // The running total is kept in 64 bits so that adding up the 32-bit
    // values cannot wrap around.
    const unsigned long long total = std::accumulate(host_values, host_values + count, 0ULL);
    const unsigned int       mean  = total / count;
    ASSERT_NEAR(mean, UINT_MAX / 2, UINT_MAX / 20);

    HIP_CHECK(hipFree(device_buffer));
}

// Generates floats into a device buffer and checks that every value lies in
// (0, 1] and that the mean is close to 0.5.
TEST(rocrand_philox_prng_tests, uniform_float_test)
{
    const size_t count = 1313;
    float*       device_values;
    HIP_CHECK(hipMallocHelper(&device_values, sizeof(float) * count));

    rocrand_philox4x32_10 generator;
    ROCRAND_CHECK(generator.generate(device_values, count));
    HIP_CHECK(hipDeviceSynchronize());

    float host_values[count];
    HIP_CHECK(hipMemcpy(host_values, device_values, sizeof(host_values), hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // Range-check each value while summing in double precision.
    double total = 0;
    for(const float value : host_values)
    {
        ASSERT_GT(value, 0.0f);
        ASSERT_LE(value, 1.0f);
        total += value;
    }
    const float mean = total / count;
    ASSERT_NEAR(mean, 0.5f, 0.05f);

    HIP_CHECK(hipFree(device_values));
}

// Two consecutive generate() calls on the same generator must not repeat
// the same numbers: the second call has to continue from a new state.
TEST(rocrand_philox_prng_tests, state_progress_test)
{
    // Device buffer, reused by both generate() calls
    const size_t  count = 1025;
    unsigned int* device_values;
    HIP_CHECK(hipMallocHelper(&device_values, sizeof(unsigned int) * count));

    rocrand_philox4x32_10 generator;

    // First batch
    ROCRAND_CHECK(generator.generate(device_values, count));
    HIP_CHECK(hipDeviceSynchronize());

    std::vector<unsigned int> first_batch(count);
    HIP_CHECK(hipMemcpy(first_batch.data(),
                        device_values,
                        sizeof(unsigned int) * count,
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // Second batch, produced by the very same generator object
    ROCRAND_CHECK(generator.generate(device_values, count));
    HIP_CHECK(hipDeviceSynchronize());

    std::vector<unsigned int> second_batch(count);
    HIP_CHECK(hipMemcpy(second_batch.data(),
                        device_values,
                        sizeof(unsigned int) * count,
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // Count the positions where both batches hold the same value.
    size_t matches = 0;
    for(size_t idx = 0; idx < count; ++idx)
    {
        matches += (first_batch[idx] == second_batch[idx]) ? 1 : 0;
    }
    // A few coincidental matches are tolerated; only require that
    // fewer than 1% of the positions agree.
    EXPECT_LT(matches, static_cast<size_t>(0.01f * count));
    HIP_CHECK(hipFree(device_values));
}

// Two freshly constructed generators that are given the same seed must
// produce identical output.
TEST(rocrand_philox_prng_tests, same_seed_test)
{
    const unsigned long long shared_seed = 0xdeadbeefdeadbeefULL;

    // Device buffer, reused by both generators
    const size_t  count = 1024;
    unsigned int* device_values;
    HIP_CHECK(hipMallocHelper(&device_values, sizeof(unsigned int) * count));

    // Both generators get the same seed
    rocrand_philox4x32_10 first_gen, second_gen;
    first_gen.set_seed(shared_seed);
    second_gen.set_seed(shared_seed);

    // Output of the first generator
    ROCRAND_CHECK(first_gen.generate(device_values, count));
    HIP_CHECK(hipDeviceSynchronize());

    std::vector<unsigned int> first_output(count);
    HIP_CHECK(hipMemcpy(first_output.data(),
                        device_values,
                        sizeof(unsigned int) * count,
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // Output of the second generator
    ROCRAND_CHECK(second_gen.generate(device_values, count));
    HIP_CHECK(hipDeviceSynchronize());

    std::vector<unsigned int> second_output(count);
    HIP_CHECK(hipMemcpy(second_output.data(),
                        device_values,
                        sizeof(unsigned int) * count,
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // Every position must match exactly
    for(size_t idx = 0; idx < count; ++idx)
    {
        ASSERT_EQ(first_output[idx], second_output[idx]);
    }
    HIP_CHECK(hipFree(device_values));
}

// Checks if generators with different seeds generate
// different numbers
TEST(rocrand_philox_prng_tests, different_seed_test)
{
    const unsigned long long first_seed  = 0xdeadbeefdeadbeefULL;
    const unsigned long long second_seed = 0xbeefdeadbeefdeadULL;

    // Device buffer, reused by both generators
    const size_t  count = 1024;
    unsigned int* device_values;
    HIP_CHECK(hipMallocHelper(&device_values, sizeof(unsigned int) * count));

    // Each generator gets its own seed; make sure the seeds really differ
    rocrand_philox4x32_10 first_gen, second_gen;
    first_gen.set_seed(first_seed);
    second_gen.set_seed(second_seed);
    ASSERT_NE(first_gen.get_seed(), second_gen.get_seed());

    // Output of the first generator
    ROCRAND_CHECK(first_gen.generate(device_values, count));
    HIP_CHECK(hipDeviceSynchronize());

    std::vector<unsigned int> first_output(count);
    HIP_CHECK(hipMemcpy(first_output.data(),
                        device_values,
                        sizeof(unsigned int) * count,
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // Output of the second generator
    ROCRAND_CHECK(second_gen.generate(device_values, count));
    HIP_CHECK(hipDeviceSynchronize());

    std::vector<unsigned int> second_output(count);
    HIP_CHECK(hipMemcpy(second_output.data(),
                        device_values,
                        sizeof(unsigned int) * count,
                        hipMemcpyDeviceToHost));
    HIP_CHECK(hipDeviceSynchronize());

    // Count the positions where both outputs hold the same value.
    size_t matches = 0;
    for(size_t idx = 0; idx < count; ++idx)
    {
        matches += (second_output[idx] == first_output[idx]) ? 1 : 0;
    }
    // A few coincidental matches are tolerated; only require that
    // fewer than 1% of the positions agree.
    EXPECT_LT(matches, static_cast<size_t>(0.01f * count));
    HIP_CHECK(hipFree(device_values));
}

///
/// rocrand_philox_prng_state_tests TEST GROUP
///

// Just get access to internal state
class rocrand_philox4x32_10_engine_type_test : public rocrand_philox4x32_10::engine_type
{
public:

    __host__ rocrand_philox4x32_10_engine_type_test()
        : rocrand_philox4x32_10::engine_type(0, 0, 0) {}

    __host__ state_type& internal_state_ref()
    {
        return m_state;
    }
};

// Checks the counter after construction, after discard() and after seed()
// is called with an offset argument.
TEST(rocrand_philox_prng_state_tests, seed_test)
{
    rocrand_philox4x32_10_engine_type_test philox;
    auto& ctr = philox.internal_state_ref().counter;

    // The expectations below assume the counter advances by one for every
    // four values skipped.
    const unsigned long long values_per_step = 4ULL;

    // Freshly constructed engine: all counter words are zero
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);

    // Skipping four values moves the lowest counter word to 1
    philox.discard(1 * values_per_step);
    EXPECT_EQ(ctr.x, 1U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);

    // Re-seeding with an offset of 5 * 4 values is expected to leave the
    // counter at 5, regardless of the earlier discard
    philox.seed(3331, 0, 5 * values_per_step);
    EXPECT_EQ(ctr.x, 5U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);
}

// Verifies how discard() advances the four-word Philox counter, including
// carries from one counter word into the next.
TEST(rocrand_philox_prng_state_tests, discard_test)
{
    rocrand_philox4x32_10_engine_type_test philox;
    auto& ctr = philox.internal_state_ref().counter;

    // The expectations below assume the counter advances by one for every
    // four values discarded.
    const unsigned long long values_per_step = 4ULL;

    // Freshly constructed engine: all counter words are zero
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);

    // UINT_MAX steps fill the lowest word without a carry
    philox.discard(UINT_MAX * values_per_step);
    EXPECT_EQ(ctr.x, UINT_MAX);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);

    // Another UINT_MAX steps wrap x and carry one into y
    philox.discard(UINT_MAX * values_per_step);
    EXPECT_EQ(ctr.x, UINT_MAX - 1);
    EXPECT_EQ(ctr.y, 1U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);

    // Two more steps wrap x again and carry a second time into y
    philox.discard(2 * values_per_step);
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 2U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);

    // With x, y and z saturated, one step must ripple the carry up to w
    ctr.x = UINT_MAX;
    ctr.y = UINT_MAX;
    ctr.z = UINT_MAX;
    philox.discard(1 * values_per_step);
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 1U);

    // Same again; w is left untouched by the setup, so it goes from 1 to 2
    ctr.x = UINT_MAX;
    ctr.y = UINT_MAX;
    ctr.z = UINT_MAX;
    philox.discard(1 * values_per_step);
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 2U);

    // A single step without carry only changes x
    ctr.x = 123;
    ctr.y = 456;
    ctr.z = 789;
    ctr.w = 999;
    philox.discard(1 * values_per_step);
    EXPECT_EQ(ctr.x, 124U);
    EXPECT_EQ(ctr.y, 456U);
    EXPECT_EQ(ctr.z, 789U);
    EXPECT_EQ(ctr.w, 999U);

    // Same check with the upper words at zero
    ctr.x = 123;
    ctr.y = 0;
    ctr.z = 0;
    ctr.w = 0;
    philox.discard(1 * values_per_step);
    EXPECT_EQ(ctr.x, 124U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 0U);

    // A step count of 2^32 + 2: the 2^32 part adds one to y directly and
    // the remaining 2 steps wrap x from UINT_MAX - 1 to 0, carrying once more
    ctr.x = UINT_MAX - 1;
    ctr.y = 2;
    ctr.z = 3;
    ctr.w = 4;
    philox.discard(((1ull << 32) + 2ull) * values_per_step);
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 4U);
    EXPECT_EQ(ctr.z, 3U);
    EXPECT_EQ(ctr.w, 4U);
}

// Verifies that discard_subsequence() advances only the two upper counter
// words (z and w) and carries from z into w.
TEST(rocrand_philox_prng_state_tests, discard_sequence_test)
{
    rocrand_philox4x32_10_engine_type_test philox;
    auto& ctr = philox.internal_state_ref().counter;

    // UINT_MAX subsequences fill z without a carry
    philox.discard_subsequence(UINT_MAX);
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, UINT_MAX);
    EXPECT_EQ(ctr.w, 0U);

    // Another UINT_MAX subsequences wrap z and carry one into w
    philox.discard_subsequence(UINT_MAX);
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, UINT_MAX - 1);
    EXPECT_EQ(ctr.w, 1U);

    // Two more subsequences wrap z again and carry a second time into w
    philox.discard_subsequence(2);
    EXPECT_EQ(ctr.x, 0U);
    EXPECT_EQ(ctr.y, 0U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 2U);

    // A single subsequence without carry only changes z
    ctr.x = 123;
    ctr.y = 456;
    ctr.z = 789;
    ctr.w = 999;
    philox.discard_subsequence(1);
    EXPECT_EQ(ctr.x, 123U);
    EXPECT_EQ(ctr.y, 456U);
    EXPECT_EQ(ctr.z, 790U);
    EXPECT_EQ(ctr.w, 999U);

    // A count of 2^32 + 2: the 2^32 part adds one to w directly and the
    // remaining 2 wrap z from UINT_MAX - 1 to 0, carrying once more
    ctr.x = 1;
    ctr.y = 2;
    ctr.z = UINT_MAX - 1;
    ctr.w = 4;
    philox.discard_subsequence((1ull << 32) + 2ull);
    EXPECT_EQ(ctr.x, 1U);
    EXPECT_EQ(ctr.y, 2U);
    EXPECT_EQ(ctr.z, 0U);
    EXPECT_EQ(ctr.w, 6U);
}

// Typed-test fixture: T is the element type the generator is asked to
// produce, exposed to the test body as output_type.
template <typename T>
struct rocrand_philox_prng_offset : ::testing::Test
{
    using output_type = T;
};

// Element types the offset test is instantiated for
using RocrandPhiloxPrngOffsetTypes = ::testing::Types<unsigned int, float>;
TYPED_TEST_SUITE(rocrand_philox_prng_offset, RocrandPhiloxPrngOffsetTypes);

// Check that a generator whose offset is set to N produces the same values
// as a default generator produces after its first N values.
TYPED_TEST(rocrand_philox_prng_offset, offsets_test)
{
    using T = typename TestFixture::output_type;
    const size_t size = 131313;

    constexpr size_t offsets[] = { 0, 1, 4, 11, 65536, 112233 };

    for(const auto offset : offsets)
    {
        SCOPED_TRACE(::testing::Message() << "with offset=" << offset);

        // data0 holds the output of the offset generator, data1 the output
        // of the default generator including the leading `offset` values.
        const size_t size0 = size;
        const size_t size1 = (size + offset);
        T* data0;
        T* data1;
        HIP_CHECK(hipMalloc(&data0, sizeof(T) * size0));
        HIP_CHECK(hipMalloc(&data1, sizeof(T) * size1));

        rocrand_philox4x32_10 g0;
        g0.set_offset(offset);
        ROCRAND_CHECK(g0.generate(data0, size0));

        rocrand_philox4x32_10 g1;
        ROCRAND_CHECK(g1.generate(data1, size1));

        std::vector<T> host_data0(size0);
        std::vector<T> host_data1(size1);
        HIP_CHECK(hipMemcpy(host_data0.data(), data0, sizeof(T) * size0, hipMemcpyDeviceToHost));
        HIP_CHECK(hipMemcpy(host_data1.data(), data1, sizeof(T) * size1, hipMemcpyDeviceToHost));
        HIP_CHECK(hipDeviceSynchronize());

        // The results are on the host now; release the device buffers before
        // comparing so that a failing ASSERT_EQ below does not leak them.
        HIP_CHECK(hipFree(data0));
        HIP_CHECK(hipFree(data1));

        for(size_t i = 0; i < size; ++i)
        {
            ASSERT_EQ(host_data0[i], host_data1[i + offset]);
        }
    }
}

// Check that subsequent generations of different sizes produce one
// sequence without gaps, no matter how many values are generated per call.
//
// Two generators are driven with two different lists of per-call sizes; the
// concatenated outputs must agree over their common prefix.
//
// generate_func is called as generate_func(generator, device_pointer, count)
//               and its return value, if any, is not inspected here.
// divisor       when greater than 1, every per-call size is rounded up to
//               the next multiple of divisor.
template<typename T, typename GenerateFunc>
void continuity_test(GenerateFunc generate_func, unsigned int divisor = 1)
{
    std::vector<size_t> sizes0({ 100, 1, 24783, 3, 2, 776543, 1048576 });
    std::vector<size_t> sizes1({ 1024, 55, 65536, 623456, 30, 1048576, 111331 });
    if (divisor > 1)
    {
        // Round up by division rather than bit masking so that the result is
        // correct for any divisor, not only for powers of two.
        const auto round_up = [divisor](size_t s) -> size_t
        {
            return ((s + divisor - 1) / divisor) * divisor;
        };
        for (size_t& s : sizes0) s = round_up(s);
        for (size_t& s : sizes1) s = round_up(s);
    }

    const size_t size0 = std::accumulate(sizes0.cbegin(), sizes0.cend(), std::size_t{0});
    const size_t size1 = std::accumulate(sizes1.cbegin(), sizes1.cend(), std::size_t{0});

    T * data0;
    T * data1;
    HIP_CHECK(hipMalloc(&data0, sizeof(T) * size0));
    HIP_CHECK(hipMalloc(&data1, sizeof(T) * size1));

    rocrand_philox4x32_10 g0;
    rocrand_philox4x32_10 g1;

    std::vector<T> host_data0(size0);
    std::vector<T> host_data1(size1);

    // Each call writes to the start of the device buffer; the results are
    // appended one after another on the host.
    size_t current0 = 0;
    for (size_t s : sizes0)
    {
        generate_func(g0, data0, s);
        HIP_CHECK(hipMemcpy(
            host_data0.data() + current0,
            data0,
            sizeof(T) * s, hipMemcpyDefault));
        current0 += s;
    }
    size_t current1 = 0;
    for (size_t s : sizes1)
    {
        generate_func(g1, data1, s);
        HIP_CHECK(hipMemcpy(
            host_data1.data() + current1,
            data1,
            sizeof(T) * s, hipMemcpyDefault));
        current1 += s;
    }

    // The results are on the host now; release the device buffers before
    // comparing so that a failing ASSERT_EQ below does not leak them.
    HIP_CHECK(hipFree(data0));
    HIP_CHECK(hipFree(data1));

    const size_t common_size = std::min(size0, size1);
    for(size_t i = 0; i < common_size; i++)
    {
        ASSERT_EQ(host_data0[i], host_data1[i]);
    }
}

// Continuity of unsigned int output from generate().
TEST(rocrand_philox_prng_tests, continuity_uniform_uint_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<unsigned int>(
        [](rocrand_philox4x32_10& g, unsigned int* data, size_t s)
        {
            ROCRAND_CHECK(g.generate(data, s));
        });
}

// Continuity of unsigned char output from generate(); per-call sizes are
// rounded up to a multiple of 4.
TEST(rocrand_philox_prng_tests, continuity_uniform_char_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<unsigned char>(
        [](rocrand_philox4x32_10& g, unsigned char* data, size_t s)
        {
            ROCRAND_CHECK(g.generate(data, s));
        },
        4);
}

// Continuity of float output from generate_uniform().
TEST(rocrand_philox_prng_tests, continuity_uniform_float_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<float>(
        [](rocrand_philox4x32_10& g, float* data, size_t s)
        {
            ROCRAND_CHECK(g.generate_uniform(data, s));
        });
}

// Continuity of double output from generate_uniform().
TEST(rocrand_philox_prng_tests, continuity_uniform_double_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<double>(
        [](rocrand_philox4x32_10& g, double* data, size_t s)
        {
            ROCRAND_CHECK(g.generate_uniform(data, s));
        });
}

// Continuity of float output from generate_normal() with arguments 0.0f and
// 1.0f; per-call sizes are rounded up to a multiple of 2.
TEST(rocrand_philox_prng_tests, continuity_normal_float_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<float>(
        [](rocrand_philox4x32_10& g, float* data, size_t s)
        {
            ROCRAND_CHECK(g.generate_normal(data, s, 0.0f, 1.0f));
        },
        2);
}

// Continuity of double output from generate_normal() with arguments 0.0 and
// 1.0; per-call sizes are rounded up to a multiple of 2.
TEST(rocrand_philox_prng_tests, continuity_normal_double_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<double>(
        [](rocrand_philox4x32_10& g, double* data, size_t s)
        {
            ROCRAND_CHECK(g.generate_normal(data, s, 0.0, 1.0));
        },
        2);
}

// Continuity of float output from generate_log_normal() with arguments 0.0f
// and 1.0f; per-call sizes are rounded up to a multiple of 2.
TEST(rocrand_philox_prng_tests, continuity_log_normal_float_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<float>(
        [](rocrand_philox4x32_10& g, float* data, size_t s)
        {
            ROCRAND_CHECK(g.generate_log_normal(data, s, 0.0f, 1.0f));
        },
        2);
}

// Continuity of double output from generate_log_normal() with arguments 0.0
// and 1.0; per-call sizes are rounded up to a multiple of 2.
TEST(rocrand_philox_prng_tests, continuity_log_normal_double_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<double>(
        [](rocrand_philox4x32_10& g, double* data, size_t s)
        {
            ROCRAND_CHECK(g.generate_log_normal(data, s, 0.0, 1.0));
        },
        2);
}

// Continuity of unsigned int output from generate_poisson() with argument 100.0.
TEST(rocrand_philox_prng_tests, continuity_poisson_test)
{
    // ROCRAND_CHECK reports a failed generation instead of silently dropping the status.
    continuity_test<unsigned int>(
        [](rocrand_philox4x32_10& g, unsigned int* data, size_t s)
        {
            ROCRAND_CHECK(g.generate_poisson(data, s, 100.0));
        });
}
