/*
 * Decompiled with CFR 0.152.
 */
package com.databricks.jdbc.auth;

import com.databricks.internal.apache.http.client.entity.UrlEncodedFormEntity;
import com.databricks.internal.apache.http.client.methods.CloseableHttpResponse;
import com.databricks.internal.apache.http.client.methods.HttpPost;
import com.databricks.internal.apache.http.client.utils.URIBuilder;
import com.databricks.internal.apache.http.message.BasicNameValuePair;
import com.databricks.internal.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import com.databricks.internal.bouncycastle.jce.provider.BouncyCastleProvider;
import com.databricks.internal.bouncycastle.openssl.PEMException;
import com.databricks.internal.bouncycastle.openssl.PEMKeyPair;
import com.databricks.internal.bouncycastle.openssl.PEMParser;
import com.databricks.internal.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import com.databricks.internal.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder;
import com.databricks.internal.bouncycastle.operator.InputDecryptorProvider;
import com.databricks.internal.bouncycastle.operator.OperatorCreationException;
import com.databricks.internal.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo;
import com.databricks.internal.bouncycastle.pkcs.PKCSException;
import com.databricks.internal.google.common.annotations.VisibleForTesting;
import com.databricks.internal.nimbusds.jose.JOSEException;
import com.databricks.internal.nimbusds.jose.JWSAlgorithm;
import com.databricks.internal.nimbusds.jose.JWSHeader;
import com.databricks.internal.nimbusds.jose.JWSSigner;
import com.databricks.internal.nimbusds.jose.crypto.ECDSASigner;
import com.databricks.internal.nimbusds.jose.crypto.RSASSASigner;
import com.databricks.internal.nimbusds.jose.crypto.impl.BaseJWSProvider;
import com.databricks.internal.nimbusds.jwt.JWTClaimsSet;
import com.databricks.internal.nimbusds.jwt.SignedJWT;
import com.databricks.internal.sdk.core.DatabricksException;
import com.databricks.internal.sdk.core.oauth.OAuthResponse;
import com.databricks.internal.sdk.core.oauth.Token;
import com.databricks.internal.sdk.core.oauth.TokenSource;
import com.databricks.jdbc.common.util.DriverUtil;
import com.databricks.jdbc.common.util.JsonUtil;
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
import com.databricks.jdbc.exception.DatabricksHttpException;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.exception.DatabricksSQLException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
import java.io.FileReader;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.RSAPrivateKey;
import java.sql.Timestamp;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;

/**
 * {@link TokenSource} that obtains OAuth tokens with the {@code client_credentials} grant,
 * authenticating the client with a JWT assertion
 * ({@code urn:ietf:params:oauth:client-assertion-type:jwt-bearer}) signed by a private key read
 * from a PEM file.
 *
 * <p>On every {@link #getToken()} call the key file is re-read, a fresh JWT is signed, and the
 * assertion is POSTed as a URL-encoded form to the configured token URL. Instances are created
 * through {@link Builder}.
 */
public class JwtPrivateKeyClientCredentials
implements TokenSource {
    private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(JwtPrivateKeyClientCredentials.class);
    private static final BouncyCastleProvider bouncyCastleProvider = new BouncyCastleProvider();
    /**
     * Validity window of a signed client assertion, measured from its issue time. RFC 7523 requires
     * the server to reject an assertion whose {@code exp} has passed, so {@code exp} must lie
     * strictly after {@code iat}; a short window limits replay of a leaked assertion.
     */
    private static final Duration JWT_ASSERTION_LIFETIME = Duration.ofMinutes(5L);
    private final IDatabricksHttpClient hc;
    private final String clientId;
    private final String tokenUrl;
    private final List<String> scopes;
    private final String jwtKeyFile;
    private final String jwtKid;
    private final String jwtKeyPassphrase;
    private final JWSAlgorithm jwtAlgorithm;

    /**
     * Creates the token source; use {@link Builder} instead of calling this directly.
     *
     * @param hc HTTP client used to call the token endpoint
     * @param clientId OAuth client id, used as the JWT {@code sub} and {@code iss}
     * @param jwtKeyFile path of the PEM file holding the signing key
     * @param jwtKid key id written to the JWS {@code kid} header
     * @param jwtKeyPassphrase passphrase of an encrypted PKCS#8 key, or {@code null} if unencrypted
     * @param jwtAlgorithm signature algorithm name; unsupported or {@code null} values fall back to RS256
     * @param tokenUrl token endpoint URL, also used as the JWT {@code aud}
     * @param scopes scopes joined with spaces into the {@code scope} parameter; {@code null} omits it
     */
    private JwtPrivateKeyClientCredentials(IDatabricksHttpClient hc, String clientId, String jwtKeyFile, String jwtKid, String jwtKeyPassphrase, String jwtAlgorithm, String tokenUrl, List<String> scopes) {
        this.hc = hc;
        this.clientId = clientId;
        this.jwtKeyFile = jwtKeyFile;
        this.jwtKid = jwtKid;
        this.jwtKeyPassphrase = jwtKeyPassphrase;
        this.jwtAlgorithm = this.determineSignatureAlgorithm(jwtAlgorithm);
        this.tokenUrl = tokenUrl;
        this.scopes = scopes;
    }

    /**
     * Requests a new access token from the token endpoint using a freshly signed client assertion.
     *
     * @return the token returned by the endpoint
     * @throws DatabricksException if the key cannot be loaded, the JWT cannot be signed, or the
     *     token request fails
     */
    @Override
    public Token getToken() {
        Map<String, String> params = new HashMap<>();
        params.put("grant_type", "client_credentials");
        if (this.scopes != null) {
            params.put("scope", String.join(" ", this.scopes));
        }
        params.put("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
        params.put("client_assertion", this.getSerialisedSignedJWT());
        if (DriverUtil.isRunningAgainstFake()) {
            // NOTE(review): presumably the fake test server expects this fixed placeholder instead of
            // a real signature — confirm against the fake-service setup.
            params.put("client_assertion", "my-private-key");
        }
        return JwtPrivateKeyClientCredentials.retrieveToken(this.hc, this.tokenUrl, params, new HashMap<String, String>());
    }

    /**
     * POSTs {@code params} as a UTF-8 URL-encoded form to {@code tokenUrl} and parses the JSON
     * response as an {@link OAuthResponse}.
     *
     * <p>The HTTP response is always closed so the underlying connection is released.
     *
     * @param hc HTTP client that executes the request
     * @param tokenUrl token endpoint URL
     * @param params form fields to send
     * @param headers extra request headers to set
     * @return the parsed token; its expiry is "now" plus the response's {@code expires_in} seconds
     * @throws DatabricksException wrapping any HTTP, I/O or URI error
     */
    @VisibleForTesting
    protected static Token retrieveToken(IDatabricksHttpClient hc, String tokenUrl, Map<String, String> params, Map<String, String> headers) {
        try {
            HttpPost postRequest = new HttpPost(new URIBuilder(tokenUrl).build());
            List<BasicNameValuePair> formFields = params.entrySet().stream()
                    .map(entry -> new BasicNameValuePair(entry.getKey(), entry.getValue()))
                    .collect(Collectors.toList());
            postRequest.setEntity(new UrlEncodedFormEntity(formFields, StandardCharsets.UTF_8));
            headers.forEach(postRequest::setHeader);
            try (CloseableHttpResponse response = hc.execute(postRequest)) {
                OAuthResponse resp = JsonUtil.getMapper().readValue(response.getEntity().getContent(), OAuthResponse.class);
                Instant expiry = Instant.now().plus((long)resp.getExpiresIn(), ChronoUnit.SECONDS);
                return new Token(resp.getAccessToken(), resp.getTokenType(), resp.getRefreshToken(), expiry);
            }
        }
        catch (DatabricksHttpException | IOException | URISyntaxException e2) {
            String errorMessage = "Failed to retrieve custom M2M token: " + e2.getMessage();
            LOGGER.error(errorMessage);
            throw new DatabricksException(errorMessage, e2);
        }
    }

    /** Loads the private key, signs a client-assertion JWT with it, and returns its compact form. */
    private String getSerialisedSignedJWT() {
        PrivateKey privateKey = this.getPrivateKey();
        return this.fetchSignedJWT(privateKey).serialize();
    }

    /** Returns the configured token endpoint URL. */
    @VisibleForTesting
    String getTokenEndpoint() {
        return this.tokenUrl;
    }

    /**
     * Maps an algorithm name to the matching {@link JWSAlgorithm}.
     *
     * <p>Supported names are RS256/384/512, PS256/384/512 and ES256/384/512 (case-sensitive).
     * {@code null} and any other value resolve to RS256; unsupported values are logged at debug level.
     *
     * @param jwtAlgorithm algorithm name, may be {@code null}
     * @return the resolved algorithm, never {@code null}
     */
    @VisibleForTesting
    JWSAlgorithm determineSignatureAlgorithm(String jwtAlgorithm) {
        if (jwtAlgorithm == null) {
            jwtAlgorithm = "RS256";
        }
        switch (jwtAlgorithm) {
            case "RS256": return JWSAlgorithm.RS256;
            case "RS384": return JWSAlgorithm.RS384;
            case "RS512": return JWSAlgorithm.RS512;
            case "PS256": return JWSAlgorithm.PS256;
            case "PS384": return JWSAlgorithm.PS384;
            case "PS512": return JWSAlgorithm.PS512;
            case "ES256": return JWSAlgorithm.ES256;
            case "ES384": return JWSAlgorithm.ES384;
            case "ES512": return JWSAlgorithm.ES512;
            default: break;
        }
        LOGGER.debug("Defaulting to RS256. Provided JWT algorithm not supported " + jwtAlgorithm);
        return JWSAlgorithm.RS256;
    }

    /**
     * Reads the first PEM object from {@code jwtKeyFile} and converts it to a {@link PrivateKey}.
     *
     * @return the private key
     * @throws DatabricksException if the file cannot be read or its content is not a usable key
     */
    private PrivateKey getPrivateKey() {
        // Closing pemParser also closes reader; closing a FileReader twice is harmless.
        try (FileReader reader = new FileReader(this.jwtKeyFile);
             PEMParser pemParser = new PEMParser(reader)) {
            return this.convertPrivateKey(pemParser.readObject());
        }
        catch (DatabricksSQLException | IOException e) {
            String errorMessage = "Failed to parse private key: " + e.getMessage();
            LOGGER.error(errorMessage);
            throw new DatabricksException(errorMessage, e);
        }
    }

    /**
     * Converts an object produced by {@link PEMParser#readObject()} into a {@link PrivateKey}.
     *
     * <p>If a passphrase is configured, the object must be an encrypted PKCS#8 key, which is
     * decrypted with that passphrase. Otherwise the object must be a key pair or an unencrypted
     * PKCS#8 key.
     *
     * @param pemObject the parsed PEM object; may be {@code null} when the file held no PEM object
     * @return the private key
     * @throws DatabricksParsingException if the object has an unsupported type or cannot be
     *     decrypted or converted
     */
    PrivateKey convertPrivateKey(Object pemObject) throws DatabricksParsingException {
        try {
            PrivateKeyInfo privateKeyInfo = this.extractPrivateKeyInfo(pemObject);
            JcaPEMKeyConverter keyConverter = new JcaPEMKeyConverter().setProvider(bouncyCastleProvider);
            return keyConverter.getPrivateKey(privateKeyInfo);
        }
        catch (PEMException | OperatorCreationException | PKCSException e) {
            String errorMessage = "Cannot decrypt private JWT key " + e.getMessage();
            LOGGER.error(errorMessage);
            throw new DatabricksParsingException(errorMessage, DatabricksDriverErrorCode.VOLUME_OPERATION_PARSING_ERROR);
        }
    }

    /**
     * Extracts (and decrypts, when a passphrase is set) the {@link PrivateKeyInfo} carried by
     * {@code pemObject}.
     *
     * <p>The key type is checked with {@code instanceof} rather than by catching
     * {@link ClassCastException}. As a result, a wrong-typed or {@code null} object produces a
     * {@link DatabricksParsingException} instead of an unchecked exception.
     */
    private PrivateKeyInfo extractPrivateKeyInfo(Object pemObject) throws OperatorCreationException, PKCSException, DatabricksParsingException {
        if (this.jwtKeyPassphrase != null) {
            if (!(pemObject instanceof PKCS8EncryptedPrivateKeyInfo)) {
                throw unsupportedKeyFormat(pemObject, "an encrypted PKCS#8 private key (a passphrase was provided)");
            }
            JceOpenSSLPKCS8DecryptorProviderBuilder decryptorProviderBuilder = new JceOpenSSLPKCS8DecryptorProviderBuilder();
            decryptorProviderBuilder.setProvider(bouncyCastleProvider);
            InputDecryptorProvider decryptorProvider = decryptorProviderBuilder.build(this.jwtKeyPassphrase.toCharArray());
            return ((PKCS8EncryptedPrivateKeyInfo)pemObject).decryptPrivateKeyInfo(decryptorProvider);
        }
        if (pemObject instanceof PEMKeyPair) {
            return ((PEMKeyPair)pemObject).getPrivateKeyInfo();
        }
        if (pemObject instanceof PrivateKeyInfo) {
            return (PrivateKeyInfo)pemObject;
        }
        throw unsupportedKeyFormat(pemObject, "an unencrypted private key");
    }

    /** Logs and builds the parsing error raised when the PEM content is not the expected key type. */
    private static DatabricksParsingException unsupportedKeyFormat(Object pemObject, String expected) {
        String found = pemObject == null ? "no PEM object" : pemObject.getClass().getName();
        String errorMessage = "Unsupported private JWT key format: expected " + expected + " but found " + found;
        LOGGER.error(errorMessage);
        return new DatabricksParsingException(errorMessage, DatabricksDriverErrorCode.VOLUME_OPERATION_PARSING_ERROR);
    }

    /**
     * Builds and signs the client-assertion JWT.
     *
     * <p>{@code sub} and {@code iss} are the client id, {@code aud} is the token URL, {@code jti}
     * is a random UUID, {@code iat} is now, and {@code exp} is {@link #JWT_ASSERTION_LIFETIME}
     * after {@code iat}. The JWS header carries the configured algorithm and {@code kid}.
     *
     * @param privateKey an RSA or EC private key matching the configured algorithm
     * @return the signed JWT
     * @throws DatabricksException if the key type is unsupported or signing fails
     */
    @VisibleForTesting
    SignedJWT fetchSignedJWT(PrivateKey privateKey) {
        try {
            JWSSigner signer;
            if (privateKey instanceof RSAPrivateKey) {
                signer = new RSASSASigner(privateKey);
            } else if (privateKey instanceof ECPrivateKey) {
                signer = new ECDSASigner((ECPrivateKey)privateKey);
            } else {
                String errorMessage = "Unsupported private key type: " + privateKey.getClass().getName();
                LOGGER.error(errorMessage);
                throw new DatabricksException(errorMessage);
            }
            // Work on an Instant so the claims are independent of the local time zone / DST, and
            // give exp a real window: an exp equal to iat is already expired on arrival (RFC 7523 §3).
            Instant issuedAt = Instant.now();
            JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
                    .subject(this.clientId)
                    .issuer(this.clientId)
                    .issueTime(Date.from(issuedAt))
                    .expirationTime(Date.from(issuedAt.plus(JWT_ASSERTION_LIFETIME)))
                    .jwtID(UUID.randomUUID().toString())
                    .audience(this.tokenUrl)
                    .build();
            JWSHeader header = new JWSHeader.Builder(this.jwtAlgorithm).keyID(this.jwtKid).build();
            SignedJWT signedJWT = new SignedJWT(header, claimsSet);
            signedJWT.sign(signer);
            return signedJWT;
        }
        catch (JOSEException e) {
            String errorMessage = "Error signing the JWT: " + e.getMessage();
            LOGGER.error(errorMessage);
            throw new DatabricksException(errorMessage, e);
        }
    }

    /**
     * Fluent builder for {@link JwtPrivateKeyClientCredentials}.
     *
     * <p>{@code clientId}, {@code jwtKeyFile} and {@code jwtKid} are mandatory. Scopes default to an
     * empty list.
     */
    public static class Builder {
        private String clientId;
        private String tokenUrl;
        private String jwtKeyFile;
        private String jwtKid;
        private String jwtKeyPassphrase;
        private String jwtAlgorithm;
        private IDatabricksHttpClient hc;
        private List<String> scopes = Collections.emptyList();

        /** Sets the OAuth client id (required). */
        public Builder withClientId(String clientId) {
            this.clientId = clientId;
            return this;
        }

        /** Sets the token endpoint URL. */
        public Builder withTokenUrl(String tokenUrl) {
            this.tokenUrl = tokenUrl;
            return this;
        }

        /** Sets the requested scopes; {@code null} omits the {@code scope} parameter. */
        public Builder withScopes(List<String> scopes) {
            this.scopes = scopes;
            return this;
        }

        /** Sets the HTTP client used for the token request. */
        public Builder withHttpClient(IDatabricksHttpClient hc) {
            this.hc = hc;
            return this;
        }

        /** Sets the JWS algorithm name; unsupported or {@code null} values fall back to RS256. */
        public Builder withJwtAlgorithm(String jwtAlgorithm) {
            this.jwtAlgorithm = jwtAlgorithm;
            return this;
        }

        /** Sets the passphrase of an encrypted PKCS#8 key file. */
        public Builder withJwtKeyPassphrase(String jwtKeyPassphrase) {
            this.jwtKeyPassphrase = jwtKeyPassphrase;
            return this;
        }

        /** Sets the key id placed in the JWS {@code kid} header (required). */
        public Builder withJwtKid(String jwtKid) {
            this.jwtKid = jwtKid;
            return this;
        }

        /** Sets the path of the PEM private-key file (required). */
        public Builder withJwtKeyFile(String jwtKeyFile) {
            this.jwtKeyFile = jwtKeyFile;
            return this;
        }

        /**
         * Validates the mandatory fields and creates the token source.
         *
         * @throws NullPointerException if the client id, key file or key id is missing
         */
        public JwtPrivateKeyClientCredentials build() {
            Objects.requireNonNull(this.clientId, "clientId must be specified");
            Objects.requireNonNull(this.jwtKeyFile, "JWT key file must be specified");
            Objects.requireNonNull(this.jwtKid, "JWT KID must be specified");
            return new JwtPrivateKeyClientCredentials(this.hc, this.clientId, this.jwtKeyFile, this.jwtKid, this.jwtKeyPassphrase, this.jwtAlgorithm, this.tokenUrl, this.scopes);
        }
    }
}

