/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.msq.exec;

import com.google.common.collect.Iterables;
import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.ints.IntCollection;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nullable;
import org.apache.druid.msq.exec.MemoryIntrospector;
import org.apache.druid.msq.exec.OutputChannelMode;
import org.apache.druid.msq.indexing.error.MSQException;
import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSlices;
import org.apache.druid.msq.input.stage.ReadablePartition;
import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
import org.apache.druid.msq.kernel.ShuffleSpec;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.util.MultiStageQueryContext;

/**
 * Memory-derived sizing parameters for one stage running on a worker.
 *
 * <p>Task memory is divided evenly among {@code maxConcurrentStages} "bundles" (see {@link #computeBundleMemory}).
 * From one bundle, {@link #createInstance} subtracts an estimate of the memory used by the stage's work processors
 * and, when the stage has broadcast inputs, a broadcast buffer (including an overhead factor). What remains is the
 * "bundle free memory". It is split between partition statistics and the SuperSorter, and it is also the basis for
 * the appenderator budgets.
 *
 * <p>All fields are final and there are no setters, so instances are immutable.
 */
public class WorkerMemoryParameters {
    /** Default frame size, in bytes. Not referenced within this class. */
    public static final int DEFAULT_FRAME_SIZE = 1000000;

    /** Fixed amount added to each work processor's memory estimate, on top of its frames. */
    private static final long EXTRA_MEMORY_PER_PROCESSOR = 25000000L;

    /** Fraction of bundle free memory made available to the appenderator. */
    private static final double APPENDERATOR_BUNDLE_FREE_MEMORY_FRACTION = 0.67;

    /**
     * Divisor applied to half of the appenderator memory to derive the max number of columns to merge, i.e. an
     * approximate per-column memory budget in bytes.
     */
    private static final long APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN = 3000L;

    /** Fraction of bundle free memory given to partition statistics, before clamping. */
    private static final double PARTITION_STATS_MAX_BUNDLE_FREE_MEMORY_FRACTION = 0.1;

    /** Upper clamp for partition-statistics memory per bundle. */
    private static final long PARTITION_STATS_MAX_MEMORY_PER_BUNDLE = 300000000L;

    /** Lower clamp for partition-statistics memory per bundle; also part of the minimum bundle free memory. */
    private static final long PARTITION_STATS_MIN_MEMORY_PER_BUNDLE = 10000000L;

    /** Fraction of bundle memory reserved for broadcast input buffers. */
    private static final double BROADCAST_BUFFER_TOTAL_MEMORY_FRACTION = 0.2;

    /** Multiplier applied to the broadcast buffer when deducting it from bundle memory. */
    private static final double BROADCAST_BUFFER_OVERHEAD_RATIO = 1.5;

    /** Value returned by {@link #getSortMergeJoinMemory()}. */
    private static final long SORT_MERGE_JOIN_MEMORY_PER_PROCESSOR = 22500000L;

    private final long bundleFreeMemory;
    private final int frameSize;
    private final int superSorterConcurrentProcessors;
    private final int superSorterMaxChannelsPerMerger;
    private final int partitionStatisticsMaxRetainedBytes;
    private final long broadcastBufferMemory;

    public WorkerMemoryParameters(long bundleFreeMemory, int frameSize, int superSorterConcurrentProcessors, int superSorterMaxChannelsPerMerger, int partitionStatisticsMaxRetainedBytes, long broadcastBufferMemory) {
        this.bundleFreeMemory = bundleFreeMemory;
        this.frameSize = frameSize;
        this.superSorterConcurrentProcessors = superSorterConcurrentProcessors;
        this.superSorterMaxChannelsPerMerger = superSorterMaxChannelsPerMerger;
        this.partitionStatisticsMaxRetainedBytes = partitionStatisticsMaxRetainedBytes;
        this.broadcastBufferMemory = broadcastBufferMemory;
    }

    /**
     * Computes parameters for the stage described by {@code workOrder}.
     *
     * <p>The frame size is read from the worker context, and the shuffle spec is passed only if the stage shuffles.
     * Frames per output channel are derived from the work order's output channel mode.
     *
     * @throws MSQException wrapping a NotEnoughMemoryFault, see {@link #createInstance}
     */
    public static WorkerMemoryParameters createProductionInstance(WorkOrder workOrder, MemoryIntrospector memoryIntrospector, int maxConcurrentStages) {
        final StageDefinition stageDef = workOrder.getStageDefinition();
        return createInstance(
            memoryIntrospector,
            MultiStageQueryContext.getFrameSize(workOrder.getWorkerContext()),
            workOrder.getInputs(),
            stageDef.getBroadcastInputNumbers(),
            stageDef.doesShuffle() ? stageDef.getShuffleSpec() : null,
            maxConcurrentStages,
            computeFramesPerOutputChannel(workOrder.getOutputChannelMode())
        );
    }

    /**
     * Computes memory parameters for a stage from explicit inputs.
     *
     * @param memoryIntrospector        supplies per-task memory, total JVM memory, task count and thread count
     * @param frameSize                 frame size in bytes
     * @param inputSlices               the stage's input slices; only stage-input slices are considered
     * @param broadcastInputNumbers     indexes (into the stage-input slices) of broadcast inputs
     * @param shuffleSpec               the stage's shuffle spec, or null if the stage does not shuffle
     * @param maxConcurrentStages       number of stages sharing the task's memory; must be nonzero
     * @param numFramesPerOutputChannel frames reserved per SuperSorter output channel
     * @return the computed parameters
     * @throws MSQException wrapping a NotEnoughMemoryFault in two cases: the bundle free memory is below
     *                      {@link #computeMinimumBundleFreeMemory}, or no positive SuperSorter concurrency fits
     */
    public static WorkerMemoryParameters createInstance(MemoryIntrospector memoryIntrospector, int frameSize, List<InputSlice> inputSlices, IntSet broadcastInputNumbers, @Nullable ShuffleSpec shuffleSpec, int maxConcurrentStages, int numFramesPerOutputChannel) {
        final long bundleMemory = computeBundleMemory(memoryIntrospector.memoryPerTask(), maxConcurrentStages);
        final long processorMemory = computeProcessorMemory(
            computeMaxSimultaneousInputChannelsPerProcessor(inputSlices, broadcastInputNumbers),
            frameSize
        );
        final boolean hasBroadcastInputs = !broadcastInputNumbers.isEmpty();
        final long broadcastBufferMemoryWithOverhead =
            hasBroadcastInputs ? computeBroadcastBufferMemoryIncludingOverhead(bundleMemory) : 0L;
        final int numProcessingThreads = memoryIntrospector.numProcessingThreads();
        final int maxSimultaneousWorkProcessors =
            Math.min(numProcessingThreads, computeNumInputPartitions(inputSlices));
        final long bundleFreeMemory =
            bundleMemory - (long) maxSimultaneousWorkProcessors * processorMemory - broadcastBufferMemoryWithOverhead;

        final long minimumBundleFreeMemory = computeMinimumBundleFreeMemory(frameSize, numFramesPerOutputChannel);
        if (bundleFreeMemory < minimumBundleFreeMemory) {
            // Bundle memory is task memory / maxConcurrentStages, so scale the per-bundle shortfall back up to the
            // task level before asking the introspector for the corresponding JVM memory.
            final long requiredTaskMemory =
                (bundleMemory - bundleFreeMemory + minimumBundleFreeMemory) * (long) maxConcurrentStages;
            throw notEnoughMemory(
                memoryIntrospector.computeJvmMemoryRequiredForTaskMemory(requiredTaskMemory),
                memoryIntrospector,
                inputSlices,
                maxConcurrentStages
            );
        }

        // Split bundle free memory between partition statistics (only when needed) and the SuperSorter.
        final int partitionStatsMemory =
            StageDefinition.mustGatherResultKeyStatistics(shuffleSpec) ? computePartitionStatsMemory(bundleFreeMemory) : 0;
        final long superSorterMemory = bundleFreeMemory - (long) partitionStatsMemory;
        final int maxOutputPartitions = computeMaxOutputPartitions(shuffleSpec);

        // A maxOutputPartitions of 0 means the shuffle spec gives no bound; use the thread count alone.
        int superSorterConcurrentProcessors =
            maxOutputPartitions == 0 ? numProcessingThreads : Math.min(maxOutputPartitions, numProcessingThreads);
        int superSorterMaxChannelsPerMerger = -1;

        // Each SuperSorter processor needs at least two input frames plus numFramesPerOutputChannel output frames.
        // Use long arithmetic, as computeMinimumBundleFreeMemory does, so large frame sizes cannot overflow int.
        final long memoryForOutputsPerProcessor = (long) numFramesPerOutputChannel * frameSize;
        final long minMemoryPerProcessor = 2L * frameSize + memoryForOutputsPerProcessor;

        // Step down from the highest candidate concurrency until the per-processor minimum fits.
        for (; superSorterConcurrentProcessors > 0; superSorterConcurrentProcessors--) {
            final long memoryPerProcessor = superSorterMemory / superSorterConcurrentProcessors;
            if (memoryPerProcessor >= minMemoryPerProcessor) {
                final long memoryForInputsPerProcessor = memoryPerProcessor - memoryForOutputsPerProcessor;
                superSorterMaxChannelsPerMerger = Ints.checkedCast(memoryForInputsPerProcessor / frameSize);
                break;
            }
        }

        if (superSorterConcurrentProcessors <= 0) {
            // No positive concurrency level fits. No required memory amount is computed here, so 0 is reported.
            throw notEnoughMemory(0L, memoryIntrospector, inputSlices, maxConcurrentStages);
        }

        // Store the broadcast buffer size itself, not the overhead-inflated amount deducted above.
        return new WorkerMemoryParameters(
            bundleFreeMemory,
            frameSize,
            superSorterConcurrentProcessors,
            superSorterMaxChannelsPerMerger,
            partitionStatsMemory,
            hasBroadcastInputs ? computeBroadcastBufferMemory(bundleMemory) : 0L
        );
    }

    /**
     * Builds the MSQException/NotEnoughMemoryFault thrown by {@link #createInstance}. All fields except the first
     * come from {@code memoryIntrospector}, {@code inputSlices} and {@code maxConcurrentStages}.
     */
    private static MSQException notEnoughMemory(long suggestedJvmMemory, MemoryIntrospector memoryIntrospector, List<InputSlice> inputSlices, int maxConcurrentStages) {
        return new MSQException(new NotEnoughMemoryFault(
            suggestedJvmMemory,
            memoryIntrospector.totalMemoryInJvm(),
            memoryIntrospector.memoryPerTask(),
            memoryIntrospector.numTasksInJvm(),
            memoryIntrospector.numProcessingThreads(),
            computeNumInputWorkers(inputSlices),
            maxConcurrentStages
        ));
    }

    /** Number of processors the SuperSorter may run concurrently. */
    public int getSuperSorterConcurrentProcessors() {
        return this.superSorterConcurrentProcessors;
    }

    /** Max number of input channels per SuperSorter merger. */
    public int getSuperSorterMaxChannelsPerMerger() {
        return this.superSorterMaxChannelsPerMerger;
    }

    /** Half of the appenderator memory budget, but at least 1. */
    public long getAppenderatorMaxBytesInMemory() {
        return Math.max(1L, this.getAppenderatorMemory() / 2L);
    }

    /**
     * The other half of the appenderator memory budget divided by
     * {@link #APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN}, but at least 2.
     */
    public int getAppenderatorMaxColumnsToMerge() {
        final long calculatedMaxColumnsToMerge =
            this.getAppenderatorMemory() / 2L / APPENDERATOR_MERGE_ROUGH_MEMORY_PER_COLUMN;
        return Ints.checkedCast(Math.max(2L, calculatedMaxColumnsToMerge));
    }

    /** Frame size, in bytes. */
    public int getFrameSize() {
        return this.frameSize;
    }

    /** Broadcast buffer memory, excluding the overhead factor; 0 when the stage has no broadcast inputs. */
    public long getBroadcastBufferMemory() {
        return this.broadcastBufferMemory;
    }

    /** Fixed memory amount for sort-merge join, independent of this instance's fields. */
    public long getSortMergeJoinMemory() {
        return SORT_MERGE_JOIN_MEMORY_PER_PROCESSOR;
    }

    /** Memory for partition statistics; 0 when the stage does not gather result key statistics. */
    public int getPartitionStatisticsMaxRetainedBytes() {
        return this.partitionStatisticsMaxRetainedBytes;
    }

    /** Appenderator memory budget: a fixed fraction of bundle free memory. */
    private long getAppenderatorMemory() {
        return (long) (this.bundleFreeMemory * APPENDERATOR_BUNDLE_FREE_MEMORY_FRACTION);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        WorkerMemoryParameters that = (WorkerMemoryParameters) o;
        return this.bundleFreeMemory == that.bundleFreeMemory
               && this.frameSize == that.frameSize
               && this.superSorterConcurrentProcessors == that.superSorterConcurrentProcessors
               && this.superSorterMaxChannelsPerMerger == that.superSorterMaxChannelsPerMerger
               && this.partitionStatisticsMaxRetainedBytes == that.partitionStatisticsMaxRetainedBytes
               && this.broadcastBufferMemory == that.broadcastBufferMemory;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.bundleFreeMemory, this.frameSize, this.superSorterConcurrentProcessors, this.superSorterMaxChannelsPerMerger, this.partitionStatisticsMaxRetainedBytes, this.broadcastBufferMemory);
    }

    @Override
    public String toString() {
        return "WorkerMemoryParameters{bundleFreeMemory=" + this.bundleFreeMemory + ", frameSize=" + this.frameSize + ", superSorterConcurrentProcessors=" + this.superSorterConcurrentProcessors + ", superSorterMaxChannelsPerMerger=" + this.superSorterMaxChannelsPerMerger + ", partitionStatisticsMaxRetainedBytes=" + this.partitionStatisticsMaxRetainedBytes + ", broadcastBufferMemory=" + this.broadcastBufferMemory + "}";
    }

    /** Memory available to one bundle: per-worker memory divided evenly among concurrent stages. */
    static long computeBundleMemory(long memoryPerWorker, int maxConcurrentStages) {
        return memoryPerWorker / (long) maxConcurrentStages;
    }

    /**
     * Partition-statistics memory: a fraction of bundle free memory, clamped to
     * [{@link #PARTITION_STATS_MIN_MEMORY_PER_BUNDLE}, {@link #PARTITION_STATS_MAX_MEMORY_PER_BUNDLE}].
     */
    static int computePartitionStatsMemory(long bundleFreeMemory) {
        final long cappedAtMax = (long) Math.min(
            bundleFreeMemory * PARTITION_STATS_MAX_BUNDLE_FREE_MEMORY_FRACTION,
            PARTITION_STATS_MAX_MEMORY_PER_BUNDLE
        );
        return Ints.checkedCast(Math.max(cappedAtMax, PARTITION_STATS_MIN_MEMORY_PER_BUNDLE));
    }

    /** Broadcast buffer size: a fixed fraction of bundle memory. */
    static long computeBroadcastBufferMemory(long bundleMemory) {
        return (long) (bundleMemory * BROADCAST_BUFFER_TOTAL_MEMORY_FRACTION);
    }

    /** Broadcast buffer size multiplied by {@link #BROADCAST_BUFFER_OVERHEAD_RATIO}; deducted from bundle memory. */
    static long computeBroadcastBufferMemoryIncludingOverhead(long bundleMemory) {
        return (long) (computeBroadcastBufferMemory(bundleMemory) * BROADCAST_BUFFER_OVERHEAD_RATIO);
    }

    /** Estimated memory for one work processor: fixed extra, one frame per input channel, plus one more frame. */
    static long computeProcessorMemory(int maxSimultaneousInputChannelsPerProcessor, int frameSize) {
        return EXTRA_MEMORY_PER_PROCESSOR
               + computeProcessorMemoryForInputChannels(maxSimultaneousInputChannelsPerProcessor, frameSize)
               + (long) frameSize;
    }

    /** One frame per simultaneously-open input channel, computed in long. */
    static long computeProcessorMemoryForInputChannels(int maxSimultaneousInputChannelsPerProcessor, int frameSize) {
        return (long) maxSimultaneousInputChannelsPerProcessor * (long) frameSize;
    }

    /** Total number of partitions across all stage-input slices. */
    static int computeNumInputPartitions(List<InputSlice> inputSlices) {
        int retVal = 0;
        for (StageInputSlice slice : InputSlices.allStageSlices(inputSlices)) {
            retVal += Iterables.size(slice.getPartitions());
        }
        return retVal;
    }

    /**
     * For each stage input, takes the max number of workers across its partitions. Sums those maxima separately for
     * broadcast and non-broadcast inputs and returns the larger of the two sums.
     */
    static int computeMaxSimultaneousInputChannelsPerProcessor(List<InputSlice> inputSlices, IntSet broadcastInputNumbers) {
        long totalNonBroadcastInputChannels = 0L;
        long totalBroadcastInputChannels = 0L;
        final List<StageInputSlice> allStageSlices = InputSlices.allStageSlices(inputSlices);
        for (int inputNumber = 0; inputNumber < allStageSlices.size(); ++inputNumber) {
            final StageInputSlice slice = allStageSlices.get(inputNumber);
            int maxWorkers = 0;
            for (ReadablePartition partition : slice.getPartitions()) {
                maxWorkers = Math.max(maxWorkers, partition.getWorkerNumbers().size());
            }
            if (broadcastInputNumbers.contains(inputNumber)) {
                totalBroadcastInputChannels += maxWorkers;
            } else {
                totalNonBroadcastInputChannels += maxWorkers;
            }
        }
        return Ints.checkedCast(Math.max(totalBroadcastInputChannels, totalNonBroadcastInputChannels));
    }

    /** Number of distinct worker numbers appearing in any partition of any stage-input slice. */
    static int computeNumInputWorkers(List<InputSlice> inputSlices) {
        final IntOpenHashSet workerNumbers = new IntOpenHashSet();
        for (StageInputSlice slice : InputSlices.allStageSlices(inputSlices)) {
            for (ReadablePartition partition : slice.getPartitions()) {
                workerNumbers.addAll((IntCollection) partition.getWorkerNumbers());
            }
        }
        return workerNumbers.size();
    }

    /**
     * Upper bound on output partitions implied by the shuffle spec. Returns the partition count for HASH,
     * HASH_LOCAL_SORT and MIX, and the max partitions for a GLOBAL_SORT GlobalSortMaxCountShuffleSpec. Returns 0
     * otherwise, including for a null spec; {@link #createInstance} treats 0 as "no bound".
     */
    static int computeMaxOutputPartitions(@Nullable ShuffleSpec shuffleSpec) {
        if (shuffleSpec == null) {
            return 0;
        }
        switch (shuffleSpec.kind()) {
            case HASH:
            case HASH_LOCAL_SORT:
            case MIX:
                return shuffleSpec.partitionCount();
            case GLOBAL_SORT:
                if (shuffleSpec instanceof GlobalSortMaxCountShuffleSpec) {
                    return ((GlobalSortMaxCountShuffleSpec) shuffleSpec).getMaxPartitions();
                }
                return 0;
            default:
                return 0;
        }
    }

    /** Frames per output channel: 2 for durable output channel modes, 1 otherwise. */
    static int computeFramesPerOutputChannel(OutputChannelMode outputChannelMode) {
        return outputChannelMode.isDurable() ? 2 : 1;
    }

    /**
     * Minimum bundle free memory: the minimum partition-statistics budget plus enough for a minimal SuperSorter
     * (two input frames plus {@code numFramesPerOutputChannel} output frames).
     */
    static long computeMinimumBundleFreeMemory(int frameSize, int numFramesPerOutputChannel) {
        final long minimalSuperSorterMemory = (long) (2 + numFramesPerOutputChannel) * (long) frameSize;
        return PARTITION_STATS_MIN_MEMORY_PER_BUNDLE + minimalSuperSorterMemory;
    }
}

