package org.jgroups.protocols;

import io.searchbox.params.Parameters;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.security.auth.Subject;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.login.LoginContext;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServerFactory;
import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.Message;
import org.jgroups.annotations.MBean;
import org.jgroups.annotations.Property;
import org.jgroups.auth.Krb5Token;
import org.jgroups.auth.sasl.SaslClientCallbackHandler;
import org.jgroups.auth.sasl.SaslClientContext;
import org.jgroups.auth.sasl.SaslContext;
import org.jgroups.auth.sasl.SaslServerContext;
import org.jgroups.auth.sasl.SaslUtils;
import org.jgroups.conf.ClassConfigurator;
import org.jgroups.conf.PropertyConverters;
import org.jgroups.protocols.pbcast.GMS;
import org.jgroups.protocols.pbcast.JoinRsp;
import org.jgroups.stack.Protocol;
import org.jgroups.util.MessageBatch;

/* loaded from: input_file:WEB-INF/lib/jgroups-4.0.8.Final.jar:org/jgroups/protocols/SASL.class */

/**
 * Protocol that authenticates group membership requests with SASL.
 *
 * <p>On the way down, a JOIN request (or a MERGE request addressed to another node) gets a fresh
 * {@link SaslClientContext} registered for the destination, and that context adds a SASL header to
 * the message. On the way up, the receiving node challenges such a request with a
 * {@link SaslServerContext}. It waits up to {@link #timeout} ms for the exchange to complete, and passes
 * the request up only if authentication succeeded. Otherwise a JOIN/MERGE rejection is sent back.
 * CHALLENGE and RESPONSE messages exchanged during the handshake are consumed by this protocol.
 */
@MBean(description = "Provides SASL authentication")
public class SASL extends Protocol {
    public static final short GMS_ID = ClassConfigurator.getProtocolId(GMS.class);
    public static final short SASL_ID = ClassConfigurator.getProtocolId(SASL.class);
    public static final String SASL_PROTOCOL_NAME = "jgroups";

    @Property(name = "login_module_name", description = "The name of the JAAS login module to use to obtain a subject for creating the SASL client and server (optional). Only required by some SASL mechs (e.g. GSSAPI)")
    protected String login_module_name;

    @Property(name = "client_name", description = "The name to use when a node is acting as a client (i.e. it is not the coordinator. Will also be used to obtain the subject if using a JAAS login module")
    protected String client_name;

    // Property names are spelled out as literals matching the field names rather than borrowed from
    // unrelated classes that merely happen to define constants with the same value.
    @Property(name = "client_password", description = "The password to use when a node is acting as a client (i.e. it is not the coordinator. Will also be used to obtain the subject if using a JAAS login module", exposeAsManagedAttribute = false)
    protected String client_password;

    @Property(name = "mech", description = "The name of the mech to require for authentication. Can be any mech supported by your local SASL provider. The JDK comes standard with CRAM-MD5, DIGEST-MD5, GSSAPI, NTLM")
    protected String mech;

    @Property(name = "server_name", description = "The fully qualified server name")
    protected String server_name;

    @Property(name = "client_callback_handler", description = "The CallbackHandler to use when a node acts as a client (i.e. it is not the coordinator")
    protected CallbackHandler client_callback_handler;

    @Property(name = "server_callback_handler", description = "The CallbackHandler to use when a node acts as a server (i.e. it is the coordinator")
    protected CallbackHandler server_callback_handler;

    protected Subject client_subject;
    protected Subject server_subject;
    protected Address local_addr;
    private SaslServerFactory saslServerFactory;
    private SaslClientFactory saslClientFactory;

    @Property(name = "sasl_props", description = "Properties specific to the chosen mech", converter = PropertyConverters.StringProperties.class)
    protected Map<String, String> sasl_props = new HashMap<>();

    @Property(name = "timeout", description = "How long to wait (in ms) for a response to a challenge")
    protected long timeout = 5000;

    /**
     * Per-peer SASL contexts keyed by the peer's address. serverChallenge() blocks in
     * awaitCompletion() while the peer's RESPONSE messages are handled by up(), and other joins may
     * add entries at the same time. The map is therefore read and written concurrently and must be
     * thread-safe.
     */
    protected final Map<Address, SaslContext> sasl_context = new ConcurrentHashMap<>();

    /** Instantiates the named {@link CallbackHandler} implementation through its no-arg constructor. */
    private static CallbackHandler newCallbackHandler(String className) throws Exception {
        return Class.forName(className).asSubclass(CallbackHandler.class).getDeclaredConstructor().newInstance();
    }

    /** Sets the client callback handler to a new instance of the given class name. */
    @Property(name = "client_callback_handler_class")
    public void setClientCallbackHandlerClass(String str) throws Exception {
        this.client_callback_handler = newCallbackHandler(str);
    }

    /** Returns the class name of the client callback handler, or {@code null} if none is set. */
    public String getClientCallbackHandlerClass() {
        return client_callback_handler != null ? client_callback_handler.getClass().getName() : null;
    }

    public CallbackHandler getClientCallbackHandler() {
        return this.client_callback_handler;
    }

    public void setClientCallbackHandler(CallbackHandler callbackHandler) {
        this.client_callback_handler = callbackHandler;
    }

    /** Sets the server callback handler to a new instance of the given class name. */
    @Property(name = "server_callback_handler_class")
    public void setServerCallbackHandlerClass(String str) throws Exception {
        this.server_callback_handler = newCallbackHandler(str);
    }

    /** Returns the class name of the server callback handler, or {@code null} if none is set. */
    public String getServerCallbackHandlerClass() {
        return server_callback_handler != null ? server_callback_handler.getClass().getName() : null;
    }

    public CallbackHandler getServerCallbackHandler() {
        return this.server_callback_handler;
    }

    public void setServerCallbackHandler(CallbackHandler callbackHandler) {
        this.server_callback_handler = callbackHandler;
    }

    public void setLoginModuleName(String str) {
        this.login_module_name = str;
    }

    public String getLoginModulename() {
        return this.login_module_name;
    }

    public void setMech(String str) {
        this.mech = str;
    }

    public String getMech() {
        return this.mech;
    }

    public void setSaslProps(Map<String, String> map) {
        this.sasl_props = map;
    }

    public Map<String, String> getSaslProps() {
        return this.sasl_props;
    }

    public void setClientSubject(Subject subject) {
        this.client_subject = subject;
    }

    public Subject getClientSubject() {
        return this.client_subject;
    }

    public void setServerSubject(Subject subject) {
        this.server_subject = subject;
    }

    public Subject getServerSubject() {
        return this.server_subject;
    }

    public void setServerName(String str) {
        this.server_name = str;
    }

    /** Returns the configured server name, or {@code null} if none was set. */
    public String getServerName() {
        return this.server_name;
    }

    /**
     * Returns the configured server name; the argument is ignored.
     *
     * @deprecated the parameter has no effect; use {@link #getServerName()}
     */
    @Deprecated
    public String getServerName(String str) {
        return getServerName();
    }

    public void setTimeout(long j) {
        this.timeout = j;
    }

    public long getTimeout() {
        return this.timeout;
    }

    /** Returns the local address recorded from the last SET_LOCAL_ADDRESS event, or {@code null}. */
    public Address getAddress() {
        return this.local_addr;
    }

    /**
     * Looks up the SASL client/server factories for {@link #mech} and prepares credentials.
     *
     * <p>If no client callback handler was configured but a client password was, a
     * {@link SaslClientCallbackHandler} is created from {@code client_name}/{@code client_password}.
     * When {@code login_module_name} is set, missing server and client subjects are obtained by JAAS
     * logins. The client login uses the client name and password.
     *
     * @throws Exception if a factory lookup or a JAAS login fails
     */
    @Override
    public void init() throws Exception {
        super.init();
        saslServerFactory = SaslUtils.getSaslServerFactory(mech, sasl_props);
        saslClientFactory = SaslUtils.getSaslClientFactory(mech, sasl_props);
        char[] password = client_password == null ? new char[0] : client_password.toCharArray();
        if (client_callback_handler == null && client_password != null) {
            client_callback_handler = new SaslClientCallbackHandler(client_name, password);
        }
        if (server_subject == null && login_module_name != null) {
            LoginContext serverLogin = new LoginContext(login_module_name);
            serverLogin.login();
            server_subject = serverLogin.getSubject();
        }
        if (client_subject == null && login_module_name != null) {
            LoginContext clientLogin = new LoginContext(login_module_name, new SaslClientCallbackHandler(client_name, password));
            clientLogin.login();
            client_subject = clientLogin.getSubject();
        }
    }

    @Override
    public void stop() {
        super.stop();
        cleanup();
    }

    @Override
    public void destroy() {
        super.destroy();
        cleanup();
    }

    /** Disposes every outstanding SASL context and forgets all of them. */
    private void cleanup() {
        sasl_context.values().forEach(SaslContext::dispose);
        sasl_context.clear();
    }

    /**
     * Authenticates incoming JOIN/MERGE requests and consumes SASL handshake messages.
     *
     * <p>A request that needs authentication but carries no SASL header is rejected and dropped. This
     * matches {@link #up(MessageBatch)}, so that a remote peer cannot trigger an exception on the
     * receive path. A request whose challenge fails is dropped. CHALLENGE/RESPONSE messages are fed
     * to the peer's context and never passed up.
     *
     * @throws IllegalStateException if a SASL message arrives from a peer without a known context
     */
    @Override
    public Object up(Message message) {
        SaslHeader saslHeader = (SaslHeader) message.getHeader(SASL_ID);
        GMS.GmsHeader gmsHeader = (GMS.GmsHeader) message.getHeader(GMS_ID);
        Address src = message.getSrc();
        if (needsAuthentication(gmsHeader, src)) {
            if (saslHeader == null) {
                log.warn("Found GMS join or merge request but no SASL header");
                sendRejectionMessage(gmsHeader.getType(), src, "join or merge without an SASL header");
                return null;
            }
            return serverChallenge(gmsHeader, saslHeader, message) ? up_prot.up(message) : null;
        }
        if (saslHeader == null) {
            return up_prot.up(message);
        }
        SaslContext saslContext = sasl_context.get(src);
        if (saslContext == null) {
            throw new IllegalStateException(String.format("Cannot find server context to challenge SASL request from %s", src));
        }
        switch (saslHeader.getType()) {
            case CHALLENGE:
                processSaslMessage(saslContext, saslHeader, src, "CHALLENGE", "RESPONSE",
                        "computed response is null but challenge-response cycle not complete!");
                return null;
            case RESPONSE:
                processSaslMessage(saslContext, saslHeader, src, "RESPONSE", "CHALLENGE",
                        "computed challenge is null but challenge-response cycle not complete!");
                return null;
            default:
                return null;
        }
    }

    /**
     * Feeds one incoming handshake message to {@code ctx} and sends the computed reply, if any.
     * A {@code null} reply is accepted only if the context reports success. On any
     * {@link SaslException} the peer's context is disposed and a warning is logged; nothing is rethrown.
     *
     * @param received label of the incoming message type, used in log output
     * @param reply label of the outgoing message type, used in log output
     * @param incompleteMsg exception message when no reply was computed but the exchange is unfinished
     */
    private void processSaslMessage(SaslContext ctx, SaslHeader hdr, Address src, String received, String reply,
                                    String incompleteMsg) {
        try {
            if (log.isTraceEnabled()) {
                log.trace("%s: received %s from %s", getAddress(), received, src);
            }
            Message next = ctx.nextMessage(src, hdr);
            if (next != null) {
                if (log.isTraceEnabled()) {
                    log.trace("%s: sending %s to %s", getAddress(), reply, src);
                }
                down_prot.down(next);
                return;
            }
            if (!ctx.isSuccessful()) {
                throw new SaslException(incompleteMsg);
            }
            if (log.isTraceEnabled()) {
                log.trace("%s: authentication complete from %s", getAddress(), src);
            }
        } catch (SaslException e) {
            disposeContext(src);
            if (log.isWarnEnabled()) {
                log.warn(getAddress() + ": failed to validate " + received + " from " + src + ", token", e);
            }
        }
    }

    /** Removes the context registered for {@code address}, if any, and disposes it. */
    private void disposeContext(Address address) {
        SaslContext removed = sasl_context.remove(address);
        if (removed != null) {
            removed.dispose();
        }
    }

    /**
     * Batch variant of {@link #up(Message)} for JOIN/MERGE requests. Requests without a SASL header
     * are rejected, and requests failing the challenge are removed. The remainder of the batch is
     * passed up if it is not empty.
     */
    @Override
    public void up(MessageBatch batch) {
        for (Message msg : batch) {
            GMS.GmsHeader gmsHeader = (GMS.GmsHeader) msg.getHeader(GMS_ID);
            if (!needsAuthentication(gmsHeader, msg.getSrc())) {
                continue;
            }
            SaslHeader saslHeader = (SaslHeader) msg.getHeader(SASL_ID);
            if (saslHeader == null) {
                log.warn("Found GMS join or merge request but no SASL header");
                sendRejectionMessage(gmsHeader.getType(), batch.sender(), "join or merge without an SASL header");
                batch.remove(msg);
            } else if (!serverChallenge(gmsHeader, saslHeader, msg)) {
                batch.remove(msg);
            }
        }
        if (!batch.isEmpty()) {
            up_prot.up(batch);
        }
    }

    /** Records the local address from SET_LOCAL_ADDRESS events; every event is passed down. */
    @Override
    public Object down(Event event) {
        if (event.getType() == Event.SET_LOCAL_ADDRESS) {
            local_addr = (Address) event.getArg();
        }
        return down_prot.down(event);
    }

    /**
     * For outgoing requests that need authentication, registers a new client context for the
     * destination and lets it add its SASL header before passing the message down.
     *
     * @throws SecurityException wrapping any failure to create the context or add the header
     */
    @Override
    public Object down(Message message) {
        GMS.GmsHeader gmsHeader = (GMS.GmsHeader) message.getHeader(GMS_ID);
        Address dest = message.getDest();
        if (needsAuthentication(gmsHeader, dest)) {
            SaslClientContext saslClientContext = null;
            try {
                saslClientContext = new SaslClientContext(saslClientFactory, mech,
                        server_name != null ? server_name : dest.toString(),
                        client_callback_handler, sasl_props, client_subject);
                sasl_context.put(dest, saslClientContext);
                saslClientContext.addHeader(message, null);
            } catch (Exception e) {
                if (saslClientContext != null) {
                    disposeContext(dest);
                }
                throw new SecurityException(e);
            }
        }
        return down_prot.down(message);
    }

    /** Returns true if {@code address} equals the recorded local address. */
    private boolean isSelf(Address address) {
        return address.equals(local_addr);
    }

    /**
     * Returns true for JOIN requests (with or without state transfer) and for MERGE requests whose
     * peer address is not the local node. Returns false for everything else, including a missing
     * GMS header.
     */
    private boolean needsAuthentication(GMS.GmsHeader gmsHeader, Address address) {
        if (gmsHeader == null) {
            return false;
        }
        switch (gmsHeader.getType()) {
            case GMS.GmsHeader.JOIN_REQ:
            case GMS.GmsHeader.JOIN_REQ_WITH_STATE_TRANSFER:
                return true;
            case GMS.GmsHeader.MERGE_REQ:
                return !isSelf(address);
            default:
                return false;
        }
    }

    /**
     * Runs the server side of the SASL exchange for a JOIN/MERGE request and waits up to
     * {@link #timeout} ms for it to complete.
     *
     * <p>Returns {@code true} (pass the request up) only if authentication succeeded, or if the GMS
     * type is not one that is challenged. On failure, including a {@link SaslException}, a rejection
     * is sent and {@code false} is returned. Previously a {@code SaslException} led to
     * {@code true}, which let unauthenticated requests through. On interruption the thread's interrupt
     * flag is restored and {@code false} is returned. The server context is disposed on every path,
     * except after a successful exchange whose context reports {@code needsWrapping()}. That context
     * is retained, as before.
     */
    protected boolean serverChallenge(GMS.GmsHeader gmsHeader, SaslHeader saslHeader, Message message) {
        switch (gmsHeader.getType()) {
            case GMS.GmsHeader.JOIN_REQ:
            case GMS.GmsHeader.JOIN_REQ_WITH_STATE_TRANSFER:
            case GMS.GmsHeader.MERGE_REQ:
                break;
            default:
                return true;
        }
        Address src = message.getSrc();
        SaslServerContext ctx = null;
        boolean keepContext = false;
        try {
            ctx = new SaslServerContext(saslServerFactory, mech,
                    server_name != null ? server_name : local_addr.toString(),
                    server_callback_handler, sasl_props, server_subject);
            sasl_context.put(src, ctx);
            down_prot.down(ctx.nextMessage(src, saslHeader));
            ctx.awaitCompletion(timeout);
            if (ctx.isSuccessful()) {
                if (log.isDebugEnabled()) {
                    log.debug("Authentication successful for %s", ctx.getAuthorizationID());
                }
                keepContext = ctx.needsWrapping();
                return true;
            }
            log.warn("failed to validate SaslHeader from %s, header: %s", src, saslHeader);
            sendRejectionMessage(gmsHeader.getType(), src, "authentication failed");
            return false;
        } catch (SaslException e) {
            log.warn(String.format("failed to validate SaslHeader from %s, header: %s", src, saslHeader), e);
            sendRejectionMessage(gmsHeader.getType(), src, "authentication failed");
            return false;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        } finally {
            if (ctx != null && !keepContext) {
                disposeContext(src);
            }
        }
    }

    /** Sends a JOIN or MERGE rejection to {@code address}, depending on the request type {@code b}. */
    protected void sendRejectionMessage(byte b, Address address, String str) {
        switch (b) {
            case GMS.GmsHeader.JOIN_REQ:
            case GMS.GmsHeader.JOIN_REQ_WITH_STATE_TRANSFER:
                sendJoinRejectionMessage(address, str);
                return;
            case GMS.GmsHeader.MERGE_REQ:
                sendMergeRejectionMessage(address);
                return;
            default:
                log.error("type " + ((int) b) + " unknown");
        }
    }

    /** Sends a JOIN_RSP carrying {@code str} as the error to {@code address}; no-op if it is null. */
    protected void sendJoinRejectionMessage(Address address, String str) {
        if (address == null) {
            return;
        }
        Message rsp = new Message(address)
                .putHeader(GMS_ID, new GMS.GmsHeader(GMS.GmsHeader.JOIN_RSP))
                .setBuffer(GMS.marshal(new JoinRsp(str)));
        down_prot.down(rsp);
    }

    /** Sends an OOB MERGE_RSP flagged as merge-rejected to {@code address}. */
    protected void sendMergeRejectionMessage(Address address) {
        Message rsp = new Message(address).setFlag(Message.Flag.OOB);
        GMS.GmsHeader hdr = new GMS.GmsHeader(GMS.GmsHeader.MERGE_RSP);
        hdr.setMergeRejected(true);
        rsp.putHeader(GMS_ID, hdr);
        if (log.isDebugEnabled()) {
            log.debug("merge response=" + hdr);
        }
        down_prot.down(rsp);
    }
}
