/*
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hive.storage.jdbc.dao;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.serde.serdeConstants;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hive.storage.jdbc.conf.JdbcStorageConfig;
import org.apache.hive.storage.jdbc.exception.HiveJdbcDatabaseAccessException;

import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertThat;

/**
 * Tests for the {@link DatabaseAccessor} returned by {@link DatabaseAccessorFactory} when the
 * storage configuration points at an in-memory H2 database (MySQL compatibility mode).
 *
 * <p>The database is initialised from the {@code test_script.sql} classpath resource. Judging by
 * the assertions below, that script is expected to create a {@code test_strategy} table with 7
 * columns and 5 rows (with {@code strategy_id} values starting at 1), and an
 * {@code all_types_table} covering a wide range of SQL column types.
 */
public class TestGenericJdbcDatabaseAccessor {

  @Test
  public void testGetColumnNames_starQuery() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    List<String> columnNames = accessor.getColumnNames(conf);

    assertThat(columnNames, is(notNullValue()));
    assertThat(columnNames.size(), is(equalTo(7)));
    assertThat(columnNames.get(0), is(equalToIgnoringCase("strategy_id")));
  }

  /**
   * Verifies the Hive type mapping for every column of {@code all_types_table}, in column order.
   * Columns whose JDBC type has no Hive equivalent are expected to map to
   * {@link TypeInfoFactory#unknownTypeInfo}.
   */
  @Test
  public void testGetColumnTypes_starQuery_allTypes() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from all_types_table");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);

    List<TypeInfo> expectedTypes = new ArrayList<>();
    expectedTypes.add(TypeInfoFactory.getCharTypeInfo(1));
    expectedTypes.add(TypeInfoFactory.getCharTypeInfo(20));
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.varcharTypeInfo);
    expectedTypes.add(TypeInfoFactory.getVarcharTypeInfo(1024));
    expectedTypes.add(TypeInfoFactory.varcharTypeInfo);
    expectedTypes.add(TypeInfoFactory.booleanTypeInfo);
    expectedTypes.add(TypeInfoFactory.byteTypeInfo);
    expectedTypes.add(TypeInfoFactory.shortTypeInfo);
    expectedTypes.add(TypeInfoFactory.intTypeInfo);
    expectedTypes.add(TypeInfoFactory.longTypeInfo);
    expectedTypes.add(TypeInfoFactory.getDecimalTypeInfo(38, 0));
    expectedTypes.add(TypeInfoFactory.getDecimalTypeInfo(9, 3));
    expectedTypes.add(TypeInfoFactory.floatTypeInfo);
    expectedTypes.add(TypeInfoFactory.doubleTypeInfo);
    expectedTypes.add(TypeInfoFactory.getDecimalTypeInfo(38, 0));
    expectedTypes.add(TypeInfoFactory.dateTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.timestampTypeInfo);
    expectedTypes.add(TypeInfoFactory.timestampTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    expectedTypes.add(TypeInfoFactory.getListTypeInfo(TypeInfoFactory.unknownTypeInfo));
    expectedTypes.add(TypeInfoFactory.unknownTypeInfo);
    Assert.assertEquals(expectedTypes, accessor.getColumnTypes(conf));
  }

  @Test
  public void testGetColumnNames_fieldListQuery() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select name,referrer from test_strategy");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    List<String> columnNames = accessor.getColumnNames(conf);

    assertThat(columnNames, is(notNullValue()));
    assertThat(columnNames.size(), is(equalTo(2)));
    assertThat(columnNames.get(0), is(equalToIgnoringCase("name")));
  }

  @Test
  public void testGetColumnTypes_fieldListQuery() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select name,referrer from test_strategy");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);

    List<TypeInfo> expectedTypes = new ArrayList<>(2);
    expectedTypes.add(TypeInfoFactory.getVarcharTypeInfo(50));
    expectedTypes.add(TypeInfoFactory.getVarcharTypeInfo(1024));
    Assert.assertEquals(expectedTypes, accessor.getColumnTypes(conf));
  }


  @Test(expected = HiveJdbcDatabaseAccessException.class)
  public void testGetColumnNames_invalidQuery() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from invalid_strategy");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    @SuppressWarnings("unused")
    List<String> columnNames = accessor.getColumnNames(conf);
  }


  @Test
  public void testGetTotalNumberOfRecords() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    int numRecords = accessor.getTotalNumberOfRecords(conf);

    assertThat(numRecords, is(equalTo(5)));
  }


  @Test
  public void testGetTotalNumberOfRecords_whereClause() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from test_strategy where strategy_id = '5'");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    int numRecords = accessor.getTotalNumberOfRecords(conf);

    assertThat(numRecords, is(equalTo(1)));
  }


  @Test
  public void testGetTotalNumberOfRecords_noRecords() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from test_strategy where strategy_id = '25'");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    int numRecords = accessor.getTotalNumberOfRecords(conf);

    assertThat(numRecords, is(equalTo(0)));
  }


  @Test(expected = HiveJdbcDatabaseAccessException.class)
  public void testGetTotalNumberOfRecords_invalidQuery() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from strategyx where strategy_id = '5'");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    @SuppressWarnings("unused")
    int numRecords = accessor.getTotalNumberOfRecords(conf);
  }


  /**
   * Reads the first page of {@code test_strategy}. Per the assertions, the final two
   * {@code getRecordIterator} arguments act as limit (2) and offset (0), so ids 1 and 2 come back.
   */
  @Test
  public void testGetRecordIterator() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    JdbcRecordIterator iterator = accessor.getRecordIterator(conf, null, null, null, 2, 0);

    assertThat(iterator, is(notNullValue()));
    // Close in finally so a failing assertion does not leak the underlying JDBC resources.
    try {
      int count = 0;
      while (iterator.hasNext()) {
        Map<String, Object> record = iterator.next();
        count++;

        assertThat(record, is(notNullValue()));
        assertThat(record.size(), is(equalTo(7)));
        assertThat(record.get("strategy_id"), is(equalTo(count)));
      }

      assertThat(count, is(equalTo(2)));
    } finally {
      iterator.close();
    }
  }


  /**
   * Reads a page of two records starting at offset 2; ids 3 and 4 are expected.
   */
  @Test
  public void testGetRecordIterator_offsets() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    JdbcRecordIterator iterator = accessor.getRecordIterator(conf, null, null, null, 2, 2);

    assertThat(iterator, is(notNullValue()));
    try {
      int count = 0;
      while (iterator.hasNext()) {
        Map<String, Object> record = iterator.next();
        count++;

        assertThat(record, is(notNullValue()));
        assertThat(record.size(), is(equalTo(7)));
        // The offset of 2 shifts the expected ids by two.
        assertThat(record.get("strategy_id"), is(equalTo(count + 2)));
      }

      assertThat(count, is(equalTo(2)));
    } finally {
      iterator.close();
    }
  }


  @Test
  public void testGetRecordIterator_emptyResultSet() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from test_strategy where strategy_id = '25'");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    JdbcRecordIterator iterator = accessor.getRecordIterator(conf, null, null, null, 0, 2);

    assertThat(iterator, is(notNullValue()));
    try {
      assertThat(iterator.hasNext(), is(false));
    } finally {
      iterator.close();
    }
  }


  /**
   * An offset past the end of the 5-row table should yield an empty (but non-null) iterator.
   */
  @Test
  public void testGetRecordIterator_largeOffset() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    JdbcRecordIterator iterator = accessor.getRecordIterator(conf, null, null, null, 10, 25);

    assertThat(iterator, is(notNullValue()));
    try {
      assertThat(iterator.hasNext(), is(false));
    } finally {
      iterator.close();
    }
  }


  @Test(expected = HiveJdbcDatabaseAccessException.class)
  public void testGetRecordIterator_invalidQuery() throws HiveJdbcDatabaseAccessException {
    Configuration conf = buildConfiguration();
    conf.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from strategyx");
    DatabaseAccessor accessor = DatabaseAccessorFactory.getAccessor(conf);
    @SuppressWarnings("unused")
    JdbcRecordIterator iterator = accessor.getRecordIterator(conf, null, null, null, 0, 2);
  }


  /**
   * Builds a storage configuration for an in-memory H2 database in MySQL mode. The database is
   * initialised by running {@code test_script.sql} from the test classpath. The default query is
   * {@code select * from test_strategy}, with matching Hive column names and types.
   *
   * @return a fresh, mutable configuration that callers may adjust (e.g. override the query)
   */
  private Configuration buildConfiguration() {
    java.net.URL scriptUrl =
        TestGenericJdbcDatabaseAccessor.class.getClassLoader().getResource("test_script.sql");
    // Fail with a clear message rather than an NPE if the fixture script is missing.
    Assert.assertNotNull("test_script.sql not found on the test classpath", scriptUrl);
    String scriptPath = scriptUrl.getPath();

    Configuration config = new Configuration();
    config.set(JdbcStorageConfig.DATABASE_TYPE.getPropertyName(), "H2");
    config.set(JdbcStorageConfig.JDBC_DRIVER_CLASS.getPropertyName(), "org.h2.Driver");
    config.set(JdbcStorageConfig.JDBC_URL.getPropertyName(), "jdbc:h2:mem:test;MODE=MySQL;INIT=runscript from '"
        + scriptPath + "'");
    config.set(JdbcStorageConfig.QUERY.getPropertyName(), "select * from test_strategy");
    config.set(serdeConstants.LIST_COLUMNS, "strategy_id,name,referrer,landing,priority,implementation,last_modified");
    config.set(serdeConstants.LIST_COLUMN_TYPES, "int,string,string,string,int,string,timestamp");
    return config;
  }

}
