/*
 * Decompiled with CFR 0.152.
 */
package com.databricks.jdbc.api.impl.arrow;

import com.databricks.internal.google.common.annotations.VisibleForTesting;
import com.databricks.jdbc.api.impl.arrow.AbstractArrowResultChunk;
import com.databricks.jdbc.api.impl.arrow.ArrowResultChunk;
import com.databricks.jdbc.api.impl.arrow.ChunkStatus;
import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.common.DatabricksClientType;
import com.databricks.jdbc.common.util.DriverUtil;
import com.databricks.jdbc.dbclient.impl.common.StatementId;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.exception.DatabricksValidationException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.core.ChunkLinkFetchResult;
import com.databricks.jdbc.model.core.ExternalLink;
import java.time.Instant;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Fetches {@link ExternalLink}s for the result chunks of a single statement in batches and hands
 * them out as one {@link CompletableFuture} per chunk index.
 *
 * <p>A future is created for every chunk index in {@code [0, totalChunks)} at construction time.
 * Links already attached to chunks below the initial {@code nextBatchStartIndex} are completed
 * immediately. The remaining links are requested through
 * {@code session.getDatabricksClient().getResultChunks(...)} one batch at a time. Each batch starts
 * right after the highest chunk index returned by the previous one. For the SEA client type the
 * batch chain starts in the constructor; otherwise it starts on the first call to
 * {@link #getLinkForChunk(long)}.
 *
 * <p>If a requested chunk's link has expired and that chunk has not been downloaded yet, the link
 * futures are reset from the smallest such index and the batch chain is restarted from there.
 *
 * <p>Batch requests run via {@link CompletableFuture#runAsync(Runnable)}, i.e. on the default async
 * executor. Mutable state lives in atomic/volatile fields, and the expired-link reset is serialized
 * on an internal lock.
 *
 * @param <T> concrete chunk type stored in the shared chunk map
 */
public class ChunkLinkDownloadService<T extends AbstractArrowResultChunk> {
    private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(ChunkLinkDownloadService.class);
    private final IDatabricksSession session;
    private final StatementId statementId;
    private final long totalChunks;
    /** One link future per chunk index in {@code [0, totalChunks)}; entries are replaced on reset. */
    private final Map<Long, CompletableFuture<ExternalLink>> chunkIndexToLinkFuture;
    /** Chunk index at which the next batch request starts. */
    private final AtomicLong nextBatchStartIndex;
    /** CAS-guarded flag that is set while a batch fetch runs, so concurrent triggers are skipped. */
    private final AtomicBoolean isDownloadInProgress;
    /** Set once the batch chain has been kicked off; cleared again by an expired-link reset. */
    private final AtomicBoolean isDownloadChainStarted;
    private volatile boolean isShutdown;
    private volatile CompletableFuture<Void> currentDownloadTask;
    /** Serializes expired-link detection and the subsequent reset. */
    private final Object resetLock = new Object();
    /** Chunks shared with the caller; read for start-row offsets, cached links and download status. */
    private final ConcurrentMap<Long, T> chunkIndexToChunksMap;

    /**
     * Creates the service and, for the SEA client type, immediately starts the batch chain.
     *
     * @param session session whose client is used to fetch chunk links
     * @param statementId statement whose result chunks are being fetched
     * @param totalChunks number of chunks in the result; one link future is created per index
     * @param chunkIndexToChunksMap chunks keyed by index; chunks below {@code nextBatchStartIndex}
     *     that already carry a link have their futures completed right away
     * @param nextBatchStartIndex index of the first chunk whose link still has to be fetched
     */
    public ChunkLinkDownloadService(IDatabricksSession session, StatementId statementId, long totalChunks, ConcurrentMap<Long, T> chunkIndexToChunksMap, long nextBatchStartIndex) {
        LOGGER.info("Initializing ChunkLinkDownloadService for statement {} with total chunks: {}, starting at index: {}", statementId, totalChunks, nextBatchStartIndex);
        this.session = session;
        this.statementId = statementId;
        this.totalChunks = totalChunks;
        this.nextBatchStartIndex = new AtomicLong(nextBatchStartIndex);
        this.isDownloadInProgress = new AtomicBoolean(false);
        this.isDownloadChainStarted = new AtomicBoolean(false);
        this.isShutdown = false;
        this.chunkIndexToChunksMap = chunkIndexToChunksMap;
        this.chunkIndexToLinkFuture = new ConcurrentHashMap<>();
        for (long i = 0L; i < totalChunks; ++i) {
            this.chunkIndexToLinkFuture.put(i, new CompletableFuture<>());
        }
        if (nextBatchStartIndex > 0L) {
            LOGGER.info("Completing futures for {} upfront-fetched links", nextBatchStartIndex);
            int completedCount = 0;
            long upfrontLimit = Math.min(nextBatchStartIndex, totalChunks);
            for (long i = 0L; i < upfrontLimit; ++i) {
                T chunk = chunkIndexToChunksMap.get(i);
                ExternalLink link = chunk == null ? null : chunk.getChunkLink();
                if (link == null) {
                    continue;
                }
                LOGGER.debug("Completing link future for chunk {} in constructor", i);
                this.chunkIndexToLinkFuture.get(i).complete(link);
                ++completedCount;
            }
            LOGGER.info("Completed {} futures for upfront-fetched links", completedCount);
        }
        // All fields are assigned before the chain starts, because the async batch task reads them.
        if (session.getConnectionContext().getClientType() == DatabricksClientType.SEA && this.isDownloadChainStarted.compareAndSet(false, true)) {
            LOGGER.info("Auto-triggering download chain for SEA client type");
            this.triggerNextBatchDownload();
        }
    }

    /**
     * Returns the future that will be completed with the link for {@code chunkIndex}.
     *
     * <p>Before returning, checks whether this chunk's link has expired while the chunk is still
     * pending download. If so, the link futures are reset so fresh links are fetched. The batch chain
     * is started here if it has not been started yet.
     *
     * @param chunkIndex zero-based chunk index
     * @return the link future for the chunk, or an already-failed future if the service is shut
     *     down or the index is outside {@code [0, totalChunks)}
     * @throws ExecutionException declared for backward compatibility; failures are reported through
     *     the returned future instead
     * @throws InterruptedException declared for backward compatibility; not thrown by this
     *     implementation
     * @throws IllegalStateException if a batch has to be started and the chunk at its start index is
     *     missing from the chunk map
     */
    public CompletableFuture<ExternalLink> getLinkForChunk(long chunkIndex) throws ExecutionException, InterruptedException {
        if (this.isShutdown) {
            LOGGER.warn("Attempt to get link for chunk {} while chunk download service is shutdown", chunkIndex);
            return this.createExceptionalFuture(new DatabricksValidationException("Chunk Link Download Service is shutdown"));
        }
        if (chunkIndex < 0L) {
            LOGGER.error("Requested chunk index {} is negative", chunkIndex);
            return this.createExceptionalFuture(new DatabricksValidationException("Chunk index must be non-negative"));
        }
        if (chunkIndex >= this.totalChunks) {
            LOGGER.error("Requested chunk index {} exceeds total chunks {}", chunkIndex, this.totalChunks);
            return this.createExceptionalFuture(new DatabricksValidationException("Chunk index exceeds total chunks"));
        }
        LOGGER.debug("Getting link for chunk {}", chunkIndex);
        this.handleExpiredLinksAndReset(chunkIndex);
        if (this.isDownloadChainStarted.compareAndSet(false, true)) {
            LOGGER.info("Initiating first download chain for chunk {}", chunkIndex);
            this.triggerNextBatchDownload();
        }
        return this.chunkIndexToLinkFuture.get(chunkIndex);
    }

    /**
     * Marks the service as shut down and fails every link future that is not complete yet.
     * A batch fetch that is already running is not cancelled.
     */
    public void shutdown() {
        LOGGER.info("Shutting down ChunkLinkDownloadService for statement {}", this.statementId);
        this.isShutdown = true;
        this.completePendingFuturesExceptionally(new DatabricksValidationException("Service was shut down"), "shutdown");
    }

    /**
     * Starts an asynchronous fetch of the next batch, unless the service is shut down, a fetch is
     * already in progress, or all chunk indices have been covered.
     *
     * @throws IllegalStateException if the chunk at the batch start index is missing from the chunk
     *     map; the in-progress flag is released before rethrowing
     */
    private void triggerNextBatchDownload() {
        if (this.isShutdown || !this.isDownloadInProgress.compareAndSet(false, true)) {
            LOGGER.debug("Skipping batch download - Service shutdown: {}, Download in progress: {}", this.isShutdown, this.isDownloadInProgress.get());
            return;
        }
        long batchStartIndex = this.nextBatchStartIndex.get();
        if (batchStartIndex >= this.totalChunks) {
            LOGGER.info("No more chunks to download. Current index: {}, Total chunks: {}", batchStartIndex, this.totalChunks);
            this.isDownloadInProgress.set(false);
            return;
        }
        long batchStartRowOffset;
        try {
            batchStartRowOffset = this.getChunkStartRowOffset(batchStartIndex);
        } catch (RuntimeException e) {
            // Release the flag acquired above; otherwise every later trigger would be skipped forever.
            this.isDownloadInProgress.set(false);
            throw e;
        }
        LOGGER.info("Starting batch download from index {}", batchStartIndex);
        this.currentDownloadTask = CompletableFuture.runAsync(() -> this.downloadBatch(batchStartIndex, batchStartRowOffset));
    }

    /**
     * Fetches one batch of links, completes the matching futures and chains the next batch.
     * Any failure, including an empty batch, fails the pending futures and releases the
     * in-progress flag, so callers never wait on a chain that has stopped.
     *
     * @param batchStartIndex chunk index the batch starts at
     * @param batchStartRowOffset start-row offset of the chunk at {@code batchStartIndex}
     */
    private void downloadBatch(long batchStartIndex, long batchStartRowOffset) {
        try {
            ChunkLinkFetchResult result = this.session.getDatabricksClient().getResultChunks(this.statementId, batchStartIndex, batchStartRowOffset);
            Collection<ExternalLink> links = result.getChunkLinks();
            LOGGER.info("Retrieved {} links for batch starting at {} for statement id {}", links.size(), batchStartIndex, this.statementId);
            if (links.isEmpty()) {
                // No new link means no further batch would ever be triggered; fail waiters instead of
                // leaving them blocked and the in-progress flag stuck.
                this.handleBatchDownloadError(batchStartIndex, new DatabricksValidationException("No chunk links returned for batch starting at index " + batchStartIndex));
                return;
            }
            for (ExternalLink link : links) {
                CompletableFuture<ExternalLink> future = this.chunkIndexToLinkFuture.get(link.getChunkIndex());
                if (future != null) {
                    LOGGER.debug("Completing future for chunk {} for statement id {}", link.getChunkIndex(), this.statementId);
                    future.complete(link);
                }
            }
            long maxChunkIndex = links.stream().mapToLong(ExternalLink::getChunkIndex).max().getAsLong();
            this.nextBatchStartIndex.set(maxChunkIndex + 1L);
            LOGGER.debug("Updated next batch start index to {}", maxChunkIndex + 1L);
            this.isDownloadInProgress.set(false);
            if (maxChunkIndex + 1L < this.totalChunks) {
                LOGGER.debug("Triggering next batch download");
                this.triggerNextBatchDownload();
            }
        } catch (DatabricksSQLException | RuntimeException e) {
            // RuntimeExceptions (e.g. from the chained trigger) previously escaped silently and left
            // every pending future hanging.
            this.handleBatchDownloadError(batchStartIndex, e);
        }
    }

    /**
     * Logs a batch failure, fails every pending link future with {@code e}, and releases the
     * in-progress flag.
     */
    private void handleBatchDownloadError(long batchStartIndex, Exception e) {
        LOGGER.error(e, "Failed to download links for batch starting at {} : {}", batchStartIndex, e.getMessage());
        this.completePendingFuturesExceptionally(e, "batch download error");
        this.isDownloadInProgress.set(false);
    }

    /** Completes every not-yet-done link future exceptionally with {@code cause}. */
    private void completePendingFuturesExceptionally(Throwable cause, String reason) {
        this.chunkIndexToLinkFuture.forEach((index, future) -> {
            if (!future.isDone()) {
                LOGGER.debug("Completing future for chunk {} exceptionally due to {}", index, reason);
                future.completeExceptionally(cause);
            }
        });
    }

    /** Returns a future that is already completed exceptionally with {@code e}. */
    private CompletableFuture<ExternalLink> createExceptionalFuture(Exception e) {
        CompletableFuture<ExternalLink> future = new CompletableFuture<>();
        future.completeExceptionally(e);
        return future;
    }

    /**
     * If the link for {@code chunkIndex} has expired while that chunk is still pending download,
     * finds the smallest index in that state and resets the futures and batch state from there.
     * The batch chain is then restarted by the caller ({@link #getLinkForChunk(long)}).
     */
    private void handleExpiredLinksAndReset(long chunkIndex) {
        synchronized (this.resetLock) {
            if (!this.isChunkLinkExpiredForPendingDownload(chunkIndex)) {
                return;
            }
            LOGGER.info("Detected expired link for chunk {}, re-triggering batch download from the smallest index with the expired link", chunkIndex);
            for (long i = 0L; i < this.totalChunks; ++i) {
                if (this.isChunkLinkExpiredForPendingDownload(i)) {
                    LOGGER.info("Found the smallest index {} with the expired link, initiating reset", i);
                    this.cancelCurrentDownloadTask();
                    this.resetFuturesFromIndex(i);
                    this.prepareNewBatchDownload(i);
                    return;
                }
            }
        }
    }

    /**
     * Returns true when the chunk's link future completed normally with an expired link and the
     * chunk has not been downloaded successfully yet. Futures that are pending or failed, and chunks
     * missing from the chunk map, are reported as not expired.
     */
    private boolean isChunkLinkExpiredForPendingDownload(long chunkIndex) {
        CompletableFuture<ExternalLink> chunkFuture = this.chunkIndexToLinkFuture.get(chunkIndex);
        // Failed/cancelled futures carry no link; calling get() on them would throw out of getLinkForChunk.
        if (chunkFuture == null || !chunkFuture.isDone() || chunkFuture.isCompletedExceptionally()) {
            return false;
        }
        if (!this.isChunkLinkExpired(chunkFuture.getNow(null))) {
            return false;
        }
        T chunk = this.chunkIndexToChunksMap.get(chunkIndex);
        return chunk != null && chunk.getStatus() != ChunkStatus.DOWNLOAD_SUCCEEDED;
    }

    /**
     * Cancels the current batch task if it is still running.
     *
     * <p>NOTE(review): {@link CompletableFuture#cancel(boolean)} ignores {@code mayInterruptIfRunning}.
     * It only completes the future with a {@code CancellationException}; the running fetch body
     * keeps executing. It may still complete the new futures, update {@code nextBatchStartIndex}
     * and trigger another batch.
     */
    private void cancelCurrentDownloadTask() {
        CompletableFuture<Void> task = this.currentDownloadTask;
        if (task == null || task.isDone()) {
            return;
        }
        LOGGER.debug("Cancelling current download task");
        task.cancel(true);
        try {
            task.get(100L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOGGER.trace("Expected exception while cancelling download task: {}", e.getMessage());
        } catch (Exception e) {
            LOGGER.trace("Expected exception while cancelling download task: {}", e.getMessage());
        }
        this.currentDownloadTask = null;
    }

    /** Cancels pending futures from {@code startIndex} on and replaces all of them with fresh ones. */
    private void resetFuturesFromIndex(long startIndex) {
        LOGGER.info("Resetting futures from index {}", startIndex);
        for (long j = startIndex; j < this.totalChunks; ++j) {
            CompletableFuture<ExternalLink> future = this.chunkIndexToLinkFuture.get(j);
            if (future != null && !future.isDone()) {
                LOGGER.debug("Cancelling future for chunk {}", j);
                future.cancel(true);
            }
            this.chunkIndexToLinkFuture.put(j, new CompletableFuture<>());
        }
    }

    /** Rewinds the batch cursor to {@code startIndex} and clears the in-progress and chain-started flags. */
    private void prepareNewBatchDownload(long startIndex) {
        LOGGER.info("Preparing new batch download from index {}", startIndex);
        this.nextBatchStartIndex.set(startIndex);
        this.isDownloadInProgress.set(false);
        this.isDownloadChainStarted.set(false);
    }

    /**
     * Returns the start-row offset of the chunk at {@code chunkIndex}.
     *
     * @throws IllegalStateException if the chunk is not in the chunk map
     */
    private long getChunkStartRowOffset(long chunkIndex) {
        T chunk = this.chunkIndexToChunksMap.get(chunkIndex);
        if (chunk == null) {
            throw new IllegalStateException("Chunk not found in map for index " + chunkIndex + ". Total chunks: " + this.totalChunks + ", StatementId: " + String.valueOf(this.statementId));
        }
        return chunk.getStartRowOffset();
    }

    /**
     * Returns whether {@code link} should be treated as expired.
     *
     * <p>A null link or expiration counts as expired. When running against the fake service, links
     * never expire. Otherwise the expiration is parsed with {@link Instant#parse}, and the link is
     * expired once {@code expiration - SECONDS_BUFFER_FOR_EXPIRY} is before now.
     */
    private boolean isChunkLinkExpired(ExternalLink link) {
        if (link == null || link.getExpiration() == null) {
            LOGGER.warn("Link or expiration is null, assuming link is expired");
            return true;
        }
        if (DriverUtil.isRunningAgainstFake()) {
            return false;
        }
        Instant expirationWithBuffer = Instant.parse(link.getExpiration()).minusSeconds(ArrowResultChunk.SECONDS_BUFFER_FOR_EXPIRY.intValue());
        return expirationWithBuffer.isBefore(Instant.now());
    }

    /** Test hook: returns the current link future for {@code chunkIndex}. */
    @VisibleForTesting
    CompletableFuture<ExternalLink> getLinkFutureForTest(long chunkIndex) {
        return this.chunkIndexToLinkFuture.get(chunkIndex);
    }
}

