/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.dlic.auth.http.saml;

import com.amazon.dlic.auth.http.saml.Saml2SettingsProvider;
import com.amazon.dlic.auth.http.saml.SamlConfigException;
import com.amazon.dlic.auth.http.saml.SamlNameIdFormat;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.nimbusds.jose.Algorithm;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import com.onelogin.saml2.authn.SamlResponse;
import com.onelogin.saml2.exception.ValidationError;
import com.onelogin.saml2.settings.Saml2Settings;
import com.onelogin.saml2.util.Util;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.Permission;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Base64;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.xml.xpath.XPathExpressionException;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.joda.time.DateTime;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.SpecialPermission;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.RestRequest;
import org.opensearch.security.DefaultObjectMapper;
import org.opensearch.security.authtoken.jwt.KeyPaddingUtil;
import org.opensearch.security.dlic.rest.api.AuthTokenProcessorAction;
import org.opensearch.security.filter.SecurityResponse;

/**
 * Exchanges a base64-encoded SAML response, posted as a JSON body, for an HS512-signed JWT.
 *
 * <p>On success the response body carries an {@code Authorization} value of the form {@code "bearer <jwt>"}.
 * The JWT contains the subject, optional roles, and SAML session details. When validation or token creation
 * fails, {@link #handle} returns {@code Optional.empty()}.
 * NOTE(review): how callers map an empty result to an HTTP response is not visible here; confirm against the
 * caller.
 */
class AuthTokenProcessorHandler {
    private static final Logger log = LogManager.getLogger(AuthTokenProcessorHandler.class);
    // Separate logger for raw SAML responses and issued tokens, which are only written at debug level.
    private static final Logger token_log = LogManager.getLogger("com.amazon.dlic.auth.http.saml.Token");
    // jwt.expiry syntax: "<base>" or "<base> + <minutes>", surrounding whitespace ignored (e.g. "session + 30").
    private static final Pattern EXPIRY_SETTINGS_PATTERN = Pattern.compile("\\s*(\\w+)\\s*(?:\\+\\s*(\\w+))?\\s*");
    // Lifetime used in AUTO mode when the SAML response has no session end and no positive offset is configured.
    private static final long DEFAULT_JWT_LIFETIME_MILLIS = 60L * 60L * 1000L;
    // Largest accepted jwt.expiry offset in minutes. It keeps the offset, once converted to milliseconds, far
    // below Long.MAX_VALUE so that adding it to a timestamp in getJwtExpiration cannot overflow.
    private static final long MAX_EXPIRY_OFFSET_MINUTES = Long.MAX_VALUE / 2L / 60_000L;

    private final Saml2SettingsProvider saml2SettingsProvider;
    private final String jwtSubjectKey;
    private final String jwtRolesKey;
    // SAML attribute holding the subject; null means the SAML NameID is used instead.
    private final String samlSubjectKey;
    // SAML attribute holding the roles; null means no roles claim is emitted.
    private final String samlRolesKey;
    private final String kibanaRootUrl;
    // Offset in seconds applied on top of the expiry base value (configured in minutes via jwt.expiry).
    private long expiryOffset = 0L;
    private ExpiryBaseValue expiryBaseValue = ExpiryBaseValue.AUTO;
    private final JWK signingKey;
    private final JWSHeader jwsHeader;
    // Optional pattern used to split each role attribute value into several roles; null disables splitting.
    private final Pattern samlRolesSeparatorPattern;

    /**
     * Reads the SAML and JWT settings and prepares the signing key.
     *
     * @param settings SAML authenticator settings (roles_key, subject_key, roles_separator, kibana_url,
     *     jwt.expiry, exchange_key)
     * @param jwtSettings JWT settings (roles_key, subject_key, key.k)
     * @param saml2SettingsProvider source of the cached SAML settings
     * @throws Exception if no signing key is configured or the key material cannot be processed
     */
    AuthTokenProcessorHandler(Settings settings, Settings jwtSettings, Saml2SettingsProvider saml2SettingsProvider) throws Exception {
        this.saml2SettingsProvider = saml2SettingsProvider;
        this.jwtRolesKey = jwtSettings.get("roles_key", "roles");
        this.jwtSubjectKey = jwtSettings.get("subject_key", "sub");
        this.samlRolesKey = emptyToNull(settings.get("roles_key"));
        this.samlSubjectKey = emptyToNull(settings.get("subject_key"));
        this.kibanaRootUrl = settings.get("kibana_url");
        // The misspelled key "roles_seperator" is honored as a fallback, presumably for backward compatibility.
        String samlRolesSeparator = settings.get("roles_separator", settings.get("roles_seperator"));
        this.samlRolesSeparatorPattern = samlRolesSeparator != null ? Pattern.compile(samlRolesSeparator) : null;
        if (this.samlRolesKey == null) {
            log.warn("roles_key is not configured, will only extract subject from SAML");
        }
        this.initJwtExpirySettings(settings);
        this.signingKey = this.createJwkFromSettings(settings, jwtSettings);
        this.jwsHeader = this.createJwsHeaderFromSettings();
    }

    /**
     * Handles a token-exchange REST request with elevated privileges.
     *
     * @param restRequest the incoming request; must be a JSON POST
     * @return the response to send, or empty when the SAML response was rejected
     * @throws Exception any checked exception raised while handling, unwrapped from the privileged action
     */
    Optional<SecurityResponse> handle(RestRequest restRequest) throws Exception {
        SecurityManager sm = System.getSecurityManager();
        if (sm != null) {
            sm.checkPermission(new SpecialPermission());
        }
        try {
            // The explicit cast selects the exception-throwing overload; a bare lambda matches both overloads.
            return AccessController.doPrivileged(
                (PrivilegedExceptionAction<Optional<SecurityResponse>>) () -> this.handleLowLevel(restRequest)
            );
        } catch (PrivilegedActionException e) {
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Validates the SAML response and, if it is valid, wraps a freshly signed JWT in the response body.
     *
     * @return the response body, or null if validation or JWT creation failed (the reason is logged)
     */
    private AuthTokenProcessorAction.Response handleImpl(String samlResponseBase64, String samlRequestId, String acsEndpoint, Saml2Settings saml2Settings, String requestPath) {
        if (token_log.isDebugEnabled()) {
            try {
                token_log.debug("SAMLResponse for {}\n{}", samlRequestId, new String(Util.base64decoder(samlResponseBase64), StandardCharsets.UTF_8));
            } catch (Exception e) {
                token_log.warn("SAMLResponse for {} cannot be decoded from base64\n{}", samlRequestId, samlResponseBase64, e);
            }
        }
        try {
            SamlResponse samlResponse = new SamlResponse(saml2Settings, acsEndpoint, samlResponseBase64);
            if (!samlResponse.isValid(samlRequestId)) {
                // Include the library's reason; without it rejected logins are hard to diagnose.
                log.warn("Error while validating SAML response in {}: {}", requestPath, samlResponse.getError());
                return null;
            }
            AuthTokenProcessorAction.Response responseBody = new AuthTokenProcessorAction.Response();
            responseBody.setAuthorization("bearer " + this.createJwt(samlResponse));
            return responseBody;
        } catch (ValidationError e) {
            log.warn("Error while validating SAML response", e);
            return null;
        } catch (Exception e) {
            log.error("Error while converting SAML to JWT", e);
            return null;
        }
    }

    /**
     * Parses the JSON request body and runs the SAML-to-JWT exchange.
     *
     * <p>Accepted fields: {@code SAMLResponse} (required), {@code RequestId} and {@code acsEndpoint} (optional).
     * A relative {@code acsEndpoint} is resolved against kibana_url.
     *
     * @return a 200 JSON response on success, a 400 response for unparseable JSON, or empty on rejection
     * @throws OpenSearchSecurityException (unchecked) for wrong media type, wrong method, or missing SAMLResponse
     */
    private Optional<SecurityResponse> handleLowLevel(RestRequest restRequest) throws SamlConfigException, IOException {
        try {
            if (restRequest.getMediaType() != XContentType.JSON) {
                throw new OpenSearchSecurityException(restRequest.path() + " expects content with type application/json", RestStatus.UNSUPPORTED_MEDIA_TYPE);
            }
            if (restRequest.method() != RestRequest.Method.POST) {
                throw new OpenSearchSecurityException(restRequest.path() + " expects POST requests", RestStatus.METHOD_NOT_ALLOWED);
            }
            Saml2Settings saml2Settings = this.saml2SettingsProvider.getCached();
            BytesReference bytesReference = restRequest.requiredContent();
            JsonNode jsonRoot = DefaultObjectMapper.objectMapper.readTree(BytesReference.toBytes(bytesReference));
            if (!(jsonRoot instanceof ObjectNode)) {
                throw new JsonParseException(null, "Unexpected json format: " + String.valueOf(jsonRoot));
            }
            ObjectNode json = (ObjectNode) jsonRoot;

            JsonNode samlResponseNode = json.get("SAMLResponse");
            // An explicit JSON null counts as missing; asText() would otherwise yield the literal string "null".
            if (samlResponseNode == null || samlResponseNode.isNull()) {
                log.warn("SAMLResponse is missing from request ");
                throw new OpenSearchSecurityException("SAMLResponse is missing from request", RestStatus.BAD_REQUEST);
            }
            String samlResponseBase64 = samlResponseNode.asText();

            JsonNode requestIdNode = json.get("RequestId");
            String samlRequestId = requestIdNode != null ? requestIdNode.textValue() : null;

            String acsEndpoint = saml2Settings.getSpAssertionConsumerServiceUrl().toString();
            JsonNode acsEndpointNode = json.get("acsEndpoint");
            if (acsEndpointNode != null && acsEndpointNode.textValue() != null) {
                acsEndpoint = this.getAbsoluteAcsEndpoint(acsEndpointNode.textValue());
            }

            AuthTokenProcessorAction.Response responseBody = this.handleImpl(samlResponseBase64, samlRequestId, acsEndpoint, saml2Settings, restRequest.path());
            if (responseBody == null) {
                return Optional.empty();
            }
            String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody);
            return Optional.of(new SecurityResponse(200, null, responseBodyString, XContentType.JSON.mediaType()));
        } catch (JsonProcessingException e) {
            log.warn("Error while parsing JSON for {}", restRequest.path(), e);
            return Optional.of(new SecurityResponse(400, "JSON could not be parsed"));
        }
    }

    /** Builds the JWS header used for every issued token (HS512, no additional parameters). */
    private JWSHeader createJwsHeaderFromSettings() {
        return new JWSHeader.Builder(JWSAlgorithm.HS512).build();
    }

    /**
     * Creates the HS512 signing key.
     *
     * <p>{@code exchange_key} from {@code settings} takes precedence. Otherwise {@code key.k} from
     * {@code jwtSettings} is used. Both values are base64url-decoded before use.
     *
     * @throws Exception if neither exchange_key nor key.k is set, or the key cannot be built
     */
    JWK createJwkFromSettings(Settings settings, Settings jwtSettings) throws Exception {
        String exchangeKey = settings.get("exchange_key");
        if (!Strings.isNullOrEmpty(exchangeKey)) {
            return createHs512SigningKey(exchangeKey);
        }
        Settings jwkSettings = jwtSettings.getAsSettings("key");
        String k = jwkSettings.get("k");
        // Previously this check could never fire, so a missing key surfaced as an NPE from the base64 decoder.
        if (Strings.isNullOrEmpty(k)) {
            throw new Exception("Settings for key exchange missing. Please specify at least the option exchange_key with a shared secret.");
        }
        return createHs512SigningKey(k);
    }

    /** Decodes a base64url secret, pads it for HS512 and wraps it as a signature-use octet key. */
    private static JWK createHs512SigningKey(String base64UrlSecret) throws Exception {
        String paddedSecret = KeyPaddingUtil.padSecret(new String(Base64.getUrlDecoder().decode(base64UrlSecret), StandardCharsets.UTF_8), JWSAlgorithm.HS512);
        return new OctetSequenceKey.Builder(paddedSecret.getBytes(StandardCharsets.UTF_8))
            .algorithm(JWSAlgorithm.HS512)
            .keyUse(KeyUse.SIGNATURE)
            .build();
    }

    /**
     * Builds and signs the JWT for a validated SAML response.
     *
     * <p>Claims:
     * <ul>
     *   <li>subject, under the configured JWT subject key;</li>
     *   <li>{@code saml_ni} (NameID), only when the subject comes from an attribute;</li>
     *   <li>{@code saml_nif} (short NameID format), when the response has a NameID format;</li>
     *   <li>{@code saml_si} (session index), when the response has one;</li>
     *   <li>roles, when a SAML roles attribute is configured.</li>
     * </ul>
     *
     * @return the compact serialized JWT
     */
    private String createJwt(SamlResponse samlResponse) throws Exception {
        JWTClaimsSet.Builder claims = new JWTClaimsSet.Builder()
            .notBeforeTime(new Date())
            .expirationTime(new Date(this.getJwtExpiration(samlResponse)))
            .claim(this.jwtSubjectKey, this.extractSubject(samlResponse));
        if (this.samlSubjectKey != null) {
            claims.claim("saml_ni", samlResponse.getNameId());
        }
        String nameIdFormat = samlResponse.getNameIdFormat();
        if (nameIdFormat != null) {
            claims.claim("saml_nif", SamlNameIdFormat.getByUri(nameIdFormat).getShortName());
        }
        String sessionIndex = samlResponse.getSessionIndex();
        if (sessionIndex != null) {
            claims.claim("saml_si", sessionIndex);
        }
        if (this.samlRolesKey != null && this.jwtRolesKey != null) {
            claims.claim(this.jwtRolesKey, this.extractRoles(samlResponse));
        }
        SignedJWT jwt = new SignedJWT(this.jwsHeader, claims.build());
        jwt.sign(new DefaultJWSSignerFactory().createJWSSigner(this.signingKey));
        String encodedJwt = jwt.serialize();
        if (token_log.isDebugEnabled()) {
            token_log.debug("Created JWT: {}\n{}\n{}", encodedJwt, jwt.getHeader(), jwt.getJWTClaimsSet());
        }
        return encodedJwt;
    }

    /**
     * Computes the JWT expiration time in epoch milliseconds.
     *
     * <ul>
     *   <li>NOW: current time plus the offset.</li>
     *   <li>SESSION: the SAML SessionNotOnOrAfter plus the offset. Fails if the response has no session end.</li>
     *   <li>AUTO: SessionNotOnOrAfter if present (offset ignored). Otherwise current time plus the offset, or
     *       plus one hour when the offset is not positive.</li>
     * </ul>
     */
    private long getJwtExpiration(SamlResponse samlResponse) throws Exception {
        DateTime sessionNotOnOrAfter = samlResponse.getSessionNotOnOrAfter();
        long offsetMillis = this.expiryOffset * 1000L;
        switch (this.expiryBaseValue) {
            case NOW:
                return System.currentTimeMillis() + offsetMillis;
            case SESSION:
                if (sessionNotOnOrAfter == null) {
                    throw new Exception("Error while determining JWT expiration time: SamlResponse did not contain sessionNotOnOrAfter value");
                }
                return sessionNotOnOrAfter.getMillis() + offsetMillis;
            default: // AUTO
                if (sessionNotOnOrAfter != null) {
                    return sessionNotOnOrAfter.getMillis();
                }
                return System.currentTimeMillis() + (this.expiryOffset > 0L ? offsetMillis : DEFAULT_JWT_LIFETIME_MILLIS);
        }
    }

    /**
     * Parses the optional {@code jwt.expiry} setting into {@link #expiryBaseValue} and {@link #expiryOffset}.
     * Invalid values are logged and leave both defaults untouched.
     */
    private void initJwtExpirySettings(Settings settings) {
        String expiry = settings.get("jwt.expiry");
        if (Strings.isNullOrEmpty(expiry)) {
            return;
        }
        Matcher matcher = EXPIRY_SETTINGS_PATTERN.matcher(expiry);
        if (!matcher.matches()) {
            log.error("Invalid value for jwt.expiry: {}; using defaults.", expiry);
            return;
        }
        String baseValue = matcher.group(1);
        String offset = matcher.group(2);
        long offsetSeconds = 0L;
        if (offset != null) {
            offsetSeconds = parseOffsetSeconds(offset);
            if (offsetSeconds < 0L) {
                log.error("Invalid offset value for jwt.expiry: {}; using defaults.", expiry);
                return;
            }
        }
        try {
            // Locale.ROOT: a locale-sensitive upper-casing (e.g. Turkish dotted I) would break the enum lookup.
            this.expiryBaseValue = ExpiryBaseValue.valueOf(baseValue.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            log.error("Invalid base value for jwt.expiry: {}; using defaults", expiry);
            return;
        }
        if (offset != null) {
            this.expiryOffset = offsetSeconds;
        }
    }

    /**
     * Converts an offset in minutes to seconds.
     *
     * @return the offset in seconds, or -1 if the text is not a number or exceeds {@link #MAX_EXPIRY_OFFSET_MINUTES}
     */
    private static long parseOffsetSeconds(String offsetMinutes) {
        if (!StringUtils.isNumeric(offsetMinutes)) {
            return -1L;
        }
        try {
            // long arithmetic: the previous int multiplication overflowed silently for large offsets.
            long minutes = Long.parseLong(offsetMinutes);
            return minutes <= MAX_EXPIRY_OFFSET_MINUTES ? minutes * 60L : -1L;
        } catch (NumberFormatException e) {
            return -1L;
        }
    }

    /**
     * Returns the subject for the JWT.
     *
     * <p>This is the SAML NameID when no subject attribute is configured. Otherwise it is the first value of
     * that attribute, or null if the attribute is absent or empty.
     */
    private String extractSubject(SamlResponse samlResponse) throws Exception {
        if (this.samlSubjectKey == null) {
            return samlResponse.getNameId();
        }
        List<String> values = (List<String>) samlResponse.getAttributes().get(this.samlSubjectKey);
        if (values == null || values.isEmpty()) {
            return null;
        }
        return values.get(0);
    }

    /**
     * Returns the roles from the configured SAML attribute, split by the separator pattern if one is set.
     *
     * @return an empty array if no roles attribute is configured, or null if the attribute is absent or empty
     */
    private String[] extractRoles(SamlResponse samlResponse) throws XPathExpressionException, ValidationError {
        if (this.samlRolesKey == null) {
            return new String[0];
        }
        List<String> values = (List<String>) samlResponse.getAttributes().get(this.samlRolesKey);
        if (values == null || values.isEmpty()) {
            return null;
        }
        if (this.samlRolesSeparatorPattern != null) {
            values = this.splitRoles(values);
        }
        return values.toArray(new String[0]);
    }

    /** Splits every value by the separator pattern and drops empty fragments. */
    private List<String> splitRoles(List<String> values) {
        return values.stream()
            .flatMap(this.samlRolesSeparatorPattern::splitAsStream)
            .filter(role -> !Strings.isNullOrEmpty(role))
            .collect(Collectors.toList());
    }

    /**
     * Resolves a relative ACS endpoint against kibana_url.
     *
     * <p>Absolute endpoints are returned unchanged. The input is also returned unchanged, after logging, when it
     * cannot be parsed or when kibana_url is not configured.
     */
    private String getAbsoluteAcsEndpoint(String acsEndpoint) {
        try {
            URI acsEndpointUri = new URI(acsEndpoint);
            if (acsEndpointUri.isAbsolute()) {
                return acsEndpoint;
            }
            if (this.kibanaRootUrl == null) {
                // Without this guard new URI(null) threw an unchecked NPE out of the request handler.
                log.error("Cannot resolve relative acsEndpoint {}: kibana_url is not configured", acsEndpoint);
                return acsEndpoint;
            }
            return new URI(this.kibanaRootUrl).resolve(acsEndpointUri).toString();
        } catch (URISyntaxException e) {
            log.error("Could not parse URI for acsEndpoint: {}", acsEndpoint, e);
            return acsEndpoint;
        }
    }

    /** Returns the key used to sign the issued JWTs. */
    public JWK getSigningKey() {
        return this.signingKey;
    }

    /** Base value for the JWT expiry computation; see {@link #getJwtExpiration}. */
    private enum ExpiryBaseValue {
        AUTO,
        NOW,
        SESSION
    }
}

