/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.analytics;

import java.net.UnknownHostException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.cassandra.distributed.api.ICluster;
import org.apache.cassandra.distributed.api.IInstance;
import org.apache.cassandra.distributed.api.IInstanceConfig;
import org.apache.cassandra.distributed.shared.JMXUtil;
import org.apache.cassandra.sidecar.common.server.dns.DnsResolver;
import org.apache.cassandra.sidecar.testing.MtlsTestHelper;
import org.apache.cassandra.sidecar.testing.QualifiedName;
import org.apache.cassandra.spark.KryoRegister;
import org.apache.cassandra.spark.bulkwriter.BulkSparkConf;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.DataFrameWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.AbstractCollectionAssert;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.jetbrains.annotations.NotNull;

/**
 * Helpers shared by Spark-based analytics integration tests.
 *
 * <p>Provides builders for bulk-reader / bulk-writer Spark data sources pointed at the Sidecar
 * instances of a test cluster, a default {@link SparkConf}, and assertions over written data and
 * expected bulk-write failures.
 *
 * <p>{@link #initialize} must be called before any method that uses the cluster, DNS resolver,
 * Sidecar port or mTLS helper.
 */
public class SparkTestUtils {
    /** Fully-qualified name of the bulk-reader Spark data source. */
    private static final String BULK_READER_FORMAT = "org.apache.cassandra.spark.sparksql.CassandraDataSource";
    /** Fully-qualified name of the bulk-writer Spark data sink. */
    private static final String BULK_WRITER_FORMAT = "org.apache.cassandra.spark.sparksql.CassandraDataSink";
    /** Datacenter passed as {@code DC} to the reader and as {@code local_dc} to the default writer. */
    private static final String DEFAULT_DATACENTER = "datacenter1";

    /** Maps a queried row to {@code "col0:col1:col2"} using the first three array elements. */
    public static final Function<Object[], String> VALIDATION_DEFAULT_COLUMNS_MAPPER =
            columns -> String.format("%s:%s:%s", columns[0], columns[1], columns[2]);

    /** Maps a Spark {@link Row} to {@code "col0:col1:col2"}; column 2 is read via {@link Row#getInt(int)}. */
    public static final Function<Row, String> VALIDATION_DEFAULT_ROW_MAPPER =
            row -> String.format("%s:%s:%d", row.get(0), row.get(1), row.getInt(2));

    protected ICluster<? extends IInstance> cluster;
    protected DnsResolver dnsResolver;
    protected int sidecarPort;
    protected MtlsTestHelper mtlsTestHelper;

    /**
     * Binds this helper to a test cluster and its Sidecar configuration.
     *
     * @param cluster        the cluster whose live instances are used as Sidecar contact points
     * @param dnsResolver    resolver used to reverse-resolve instance addresses
     * @param sidecarPort    port passed to readers/writers as {@code sidecar_port} (not validated here)
     * @param mtlsTestHelper source of the mTLS options appended to every reader/writer
     * @throws NullPointerException if {@code cluster}, {@code dnsResolver} or {@code mtlsTestHelper} is null
     */
    public void initialize(ICluster<? extends IInstance> cluster, DnsResolver dnsResolver, int sidecarPort, MtlsTestHelper mtlsTestHelper) {
        this.cluster = Objects.requireNonNull(cluster, "cluster is required");
        this.dnsResolver = Objects.requireNonNull(dnsResolver, "dnsResolver is required");
        this.mtlsTestHelper = Objects.requireNonNull(mtlsTestHelper, "mtlsTestHelper is required");
        this.sidecarPort = sidecarPort;
    }

    /**
     * Replaces the mTLS helper used for reader/writer options.
     *
     * @throws NullPointerException if {@code mtlsTestHelper} is null
     */
    public void setMtlsTestHelper(MtlsTestHelper mtlsTestHelper) {
        this.mtlsTestHelper = Objects.requireNonNull(mtlsTestHelper, "mtlsTestHelper is required");
    }

    /** No-op in this class; subclasses may override it to release test resources. */
    public void tearDown() {
    }

    /**
     * Builds a {@link DataFrameReader} for the bulk-reader data source, configured for {@code tableName}.
     *
     * <p>{@code numCores} is computed as {@code spark.executor.cores} (default 1) multiplied by
     * {@code spark.dynamicAllocation.maxExecutors}. When that is unset, {@code spark.executor.instances}
     * (default 1) is used instead. A fresh random snapshot name is used on every call.
     *
     * <p>{@code additionalOptions} override the defaults, and the mTLS options are applied last.
     *
     * @param sparkConf         configuration used to size {@code numCores}
     * @param spark             active session providing the SQL and Spark contexts
     * @param tableName         keyspace and table to read
     * @param additionalOptions extra reader options (must not be null)
     * @return a reader ready for {@code load()}
     */
    public DataFrameReader defaultBulkReaderDataFrame(SparkConf sparkConf, SparkSession spark, QualifiedName tableName, Map<String, String> additionalOptions) {
        SQLContext sql = spark.sqlContext();
        SparkContext sc = spark.sparkContext();
        int coresPerExecutor = sparkConf.getInt("spark.executor.cores", 1);
        int numExecutors = sparkConf.getInt("spark.dynamicAllocation.maxExecutors",
                                            sparkConf.getInt("spark.executor.instances", 1));
        int numCores = coresPerExecutor * numExecutors;

        Map<String, String> options = new HashMap<>();
        options.put("sidecar_contact_points", sidecarInstancesOption(cluster, dnsResolver));
        options.put("keyspace", tableName.keyspace());
        options.put("table", tableName.table());
        options.put("DC", DEFAULT_DATACENTER);
        options.put("snapshotName", UUID.randomUUID().toString());
        options.put("createSnapshot", "true");
        options.put("clearSnapshotStrategy", "noop");
        options.put("defaultParallelism", String.valueOf(sc.defaultParallelism()));
        options.put("numCores", String.valueOf(numCores));
        options.put("sizing", "default");
        options.put("sidecar_port", String.valueOf(sidecarPort));
        options.putAll(additionalOptions);

        return sql.read()
                  .format(BULK_READER_FORMAT)
                  .options(options)
                  .options(mtlsTestHelper.mtlOptionMap());
    }

    /**
     * Builds an append-mode {@link DataFrameWriter} for the bulk-writer sink.
     *
     * <p>The writer targets the live Sidecar instances of the bound cluster, uses {@code LOCAL_QUORUM}
     * and {@code number_splits=-1}. {@code additionalOptions} override those defaults, and the mTLS
     * options are applied last.
     *
     * @param df                data to write
     * @param tableName         destination keyspace and table
     * @param additionalOptions extra writer options (must not be null)
     * @return a writer ready for {@code save()}
     */
    public DataFrameWriter<Row> defaultBulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName, Map<String, String> additionalOptions) {
        return df.write()
                 .format(BULK_WRITER_FORMAT)
                 .option("sidecar_contact_points", sidecarInstancesOption(cluster, dnsResolver))
                 .option("keyspace", tableName.keyspace())
                 .option("table", tableName.table())
                 .option("local_dc", DEFAULT_DATACENTER)
                 .option("bulk_writer_cl", "LOCAL_QUORUM")
                 .option("number_splits", "-1")
                 .option("sidecar_port", (long) sidecarPort)
                 .options(additionalOptions)
                 .options(mtlsTestHelper.mtlOptionMap())
                 .mode("append");
    }

    /**
     * Builds an append-mode {@link DataFrameWriter} for the bulk-writer sink without Sidecar
     * contact points, local DC or Sidecar port.
     *
     * <p>NOTE(review): presumably those settings arrive through {@code additionalOptions} (e.g. a
     * coordinated-write configuration). Confirm this against the callers.
     *
     * @param df                data to write
     * @param tableName         destination keyspace and table
     * @param additionalOptions extra writer options (must not be null)
     * @return a writer ready for {@code save()}
     */
    public DataFrameWriter<Row> coordinatedBulkWriterDataFrameWriter(Dataset<Row> df, QualifiedName tableName, Map<String, String> additionalOptions) {
        return df.write()
                 .format(BULK_WRITER_FORMAT)
                 .option("keyspace", tableName.keyspace())
                 .option("table", tableName.table())
                 .option("bulk_writer_cl", "LOCAL_QUORUM")
                 .option("number_splits", "-1")
                 .options(additionalOptions)
                 .options(mtlsTestHelper.mtlOptionMap())
                 .mode("append");
    }

    /**
     * Creates the default Spark configuration for integration tests.
     *
     * <p>The configuration uses a local master with Kryo serialization and case-sensitive SQL. It
     * sets Cassandra version 5.0.0 and 5 Sidecar request retries at 500 ms. It then applies
     * {@link BulkSparkConf#setupSparkConf} and {@link KryoRegister#setup}.
     *
     * @return a new, fully configured {@link SparkConf}
     */
    public SparkConf defaultSparkConf() {
        SparkConf sparkConf = new SparkConf()
                .setAppName("Integration test Spark Cassandra Bulk Analytics Job")
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                .set("spark.sql.caseSensitive", "True")
                .set("spark.driver.bindAddress", "127.0.0.1")
                .set("spark.master", "local[8,4]")
                .set("spark.cassandra_analytics.cassandra.version", "5.0.0")
                .set("spark.cassandra_analytics.sidecar.request.retries", "5")
                .set("spark.cassandra_analytics.sidecar.request.retries.delay.milliseconds", "500")
                .set("spark.cassandra_analytics.sidecar.request.retries.max.delay.milliseconds", "500");
        BulkSparkConf.setupSparkConf(sparkConf, true);
        KryoRegister.setup(sparkConf);
        return sparkConf;
    }

    /**
     * Asserts that {@code queriedData} contains exactly the rows of {@code sourceData}.
     * It uses {@link #VALIDATION_DEFAULT_COLUMNS_MAPPER} and {@link #VALIDATION_DEFAULT_ROW_MAPPER}.
     */
    public void validateWrites(List<Row> sourceData, Object[][] queriedData) {
        validateWrites(sourceData, queriedData, VALIDATION_DEFAULT_COLUMNS_MAPPER, VALIDATION_DEFAULT_ROW_MAPPER);
    }

    /**
     * Asserts that the queried rows and the source rows map to the same set of keys.
     *
     * <p>Queried rows are collapsed into a set of distinct keys. That set must have the same size as
     * {@code sourceData}, and every source row's key must be present in it.
     *
     * @param sourceData    rows that were written
     * @param queriedData   rows read back from the database
     * @param columnsMapper maps a queried row to its comparison key
     * @param rowMapper     maps a source {@link Row} to its comparison key
     */
    public void validateWrites(List<Row> sourceData, @NotNull Object[][] queriedData, @NotNull Function<Object[], String> columnsMapper, @NotNull Function<Row, String> rowMapper) {
        Set<String> actualEntries = Arrays.stream(queriedData)
                                          .map(columnsMapper)
                                          .collect(Collectors.toSet());
        Assertions.assertThat(actualEntries.size()).isEqualTo(sourceData.size());
        for (Row row : sourceData) {
            String key = rowMapper.apply(row);
            // remove() both checks membership and consumes the entry, so leftovers can be detected below
            Assertions.assertThat(actualEntries.remove(key))
                      .as(key + " is expected to exist in the actual entries")
                      .isTrue();
        }
        Assertions.assertThat(actualEntries)
                  .as("All entries are expected to be read from database")
                  .isEmpty();
    }

    /**
     * Runs {@code dfWriter.save()} and asserts that it fails as a bulk-write failure at {@code writeCL}.
     *
     * <p>The thrown exception must be a {@link RuntimeException} whose message contains
     * "Bulk Write to Cassandra has failed". Its cause chain must contain a throwable whose message
     * contains "Failed to write" and matches the expected "Failed to write N ranges with CL ..." pattern.
     *
     * @param writeCL  consistency level expected in the failure message; matched literally
     * @param dfWriter writer expected to fail on save
     */
    public void assertExpectedBulkWriteFailure(String writeCL, DataFrameWriter<Row> dfWriter) {
        Throwable thrown = Assertions.catchThrowable(() -> dfWriter.save());
        Assertions.assertThat(thrown)
                  .isInstanceOf(RuntimeException.class)
                  .hasMessageContaining("Bulk Write to Cassandra has failed");

        // Walk the cause chain to the first throwable that mentions the range-write failure
        Throwable cause = thrown;
        while (cause != null && !StringUtils.contains(cause.getMessage(), "Failed to write")) {
            cause = cause.getCause();
        }

        // Quote writeCL so it is matched literally rather than interpreted as regex syntax
        String expectedPattern = "Failed to write (\\d+) ranges with " + Pattern.quote(writeCL)
                                 + " for job ([a-zA-Z0-9-]+) in phase .*";
        Assertions.assertThat(cause)
                  .isNotNull()
                  .hasMessageFindingMatch(expectedPattern);
    }

    /**
     * Returns the comma-separated Sidecar contact points for the live instances of {@code cluster}.
     *
     * @see #sidecarInstancesOptionStream(ICluster, DnsResolver)
     */
    protected String sidecarInstancesOption(ICluster<? extends IInstance> cluster, DnsResolver dnsResolver) {
        return sidecarInstancesOptionStream(cluster, dnsResolver).collect(Collectors.joining(","));
    }

    /**
     * Streams one host string per instance of {@code cluster} that is not shut down.
     *
     * <p>Instances are visited in index order 1..{@code cluster.size()}. For each, the JMX host of
     * its config is reverse-resolved through {@code dnsResolver}. If resolution throws
     * {@link UnknownHostException}, the raw address is used instead.
     *
     * @param cluster     cluster whose instances are enumerated
     * @param dnsResolver resolver used for reverse lookups
     * @return a lazily evaluated stream of host names or addresses
     */
    public static Stream<String> sidecarInstancesOptionStream(ICluster<? extends IInstance> cluster, DnsResolver dnsResolver) {
        return IntStream.rangeClosed(1, cluster.size())
                        .filter(i -> !cluster.get(i).isShutdown())
                        .mapToObj(i -> {
                            String ipAddress = JMXUtil.getJmxHost(cluster.get(i).config());
                            try {
                                return dnsResolver.reverseResolve(ipAddress);
                            } catch (UnknownHostException e) {
                                // Deliberate best-effort: fall back to the unresolved address
                                return ipAddress;
                            }
                        });
    }
}

