/*
 * Decompiled with CFR 0.152.
 */
package com.twosigma.beakerx.socket;

import com.twosigma.beakerx.handler.Handler;
import com.twosigma.beakerx.kernel.Config;
import com.twosigma.beakerx.kernel.KernelFunctionality;
import com.twosigma.beakerx.kernel.KernelSockets;
import com.twosigma.beakerx.kernel.SocketCloseAction;
import com.twosigma.beakerx.kernel.msg.JupyterMessages;
import com.twosigma.beakerx.message.Header;
import com.twosigma.beakerx.message.Message;
import com.twosigma.beakerx.message.MessageSerializer;
import com.twosigma.beakerx.security.HashedMessageAuthenticationCode;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zeromq.ZFrame;
import org.zeromq.ZMQ;
import org.zeromq.ZMsg;

/**
 * ZeroMQ-backed {@link KernelSockets} implementation.
 *
 * <p>Binds five sockets on {@code transport://host:port` addresses taken from the {@link Config}:
 * heartbeat ({@code REP}), iopub ({@code PUB}), and control, stdin and shell ({@code ROUTER}).
 * {@link #run()} polls control, heartbeat, shell and stdin, and dispatches each incoming message
 * until a {@code shutdown_request} arrives on the control socket. Outgoing messages are signed
 * with the configured HMAC key; sends are serialized by the {@code synchronized} {@code sendMsg}.
 */
public class KernelSocketsZMQ extends KernelSockets {
    public static final Logger logger = LoggerFactory.getLogger(KernelSocketsZMQ.class);

    /** Frame separating routing identities from the signed message parts. */
    public static final String DELIM = "<IDS|MSG>";
    private static final byte[] DELIM_BYTES = DELIM.getBytes(StandardCharsets.UTF_8);

    /** Frames that must follow the delimiter: HMAC, header, parent header, metadata, content. */
    private static final int FRAMES_AFTER_DELIM = 5;

    // Poller indexes; these must match the registration order in configureSockets().
    private static final int CONTROL_POLL_INDEX = 0;
    private static final int HEARTBEAT_POLL_INDEX = 1;
    private static final int SHELL_POLL_INDEX = 2;
    private static final int STDIN_POLL_INDEX = 3;

    private final KernelFunctionality kernel;
    private final SocketCloseAction closeAction;
    private final HashedMessageAuthenticationCode hmac;
    private final ZMQ.Context context;
    private ZMQ.Socket hearbeatSocket;
    private ZMQ.Socket controlSocket;
    private ZMQ.Socket shellSocket;
    private ZMQ.Socket iopubSocket;
    private ZMQ.Socket stdinSocket;
    private ZMQ.Poller sockets;
    private boolean shutdownSystem = false;

    /**
     * Creates a ZMQ context with one I/O thread and binds all kernel sockets.
     *
     * @param kernel provides the handlers for shell messages
     * @param configuration connection settings (transport, host, ports, HMAC key)
     * @param closeAction invoked when the socket loop terminates
     * @throws RuntimeException if a socket cannot be bound; sockets created so far are closed first
     */
    public KernelSocketsZMQ(KernelFunctionality kernel, Config configuration, SocketCloseAction closeAction) {
        this.closeAction = closeAction;
        this.kernel = kernel;
        this.hmac = new HashedMessageAuthenticationCode(configuration.getKey());
        this.context = ZMQ.context(1);
        try {
            configureSockets(configuration);
        } catch (RuntimeException e) {
            // A failed bind (e.g. port already in use) would otherwise leave the
            // already-created sockets and the context open.
            closeSockets();
            throw e;
        }
    }

    /** Binds every socket and registers the inbound ones with the poller. */
    private void configureSockets(Config configuration) {
        String connection = configuration.getTransport() + "://" + configuration.getHost();
        this.hearbeatSocket = getNewSocket(ZMQ.REP, configuration.getHeartbeat(), connection, this.context);
        this.iopubSocket = getNewSocket(ZMQ.PUB, configuration.getIopub(), connection, this.context);
        this.controlSocket = getNewSocket(ZMQ.ROUTER, configuration.getControl(), connection, this.context);
        this.stdinSocket = getNewSocket(ZMQ.ROUTER, configuration.getStdin(), connection, this.context);
        this.shellSocket = getNewSocket(ZMQ.ROUTER, configuration.getShell(), connection, this.context);
        this.sockets = new ZMQ.Poller(4);
        // Registration order defines the *_POLL_INDEX constants above.
        this.sockets.register(this.controlSocket, ZMQ.Poller.POLLIN);
        this.sockets.register(this.hearbeatSocket, ZMQ.Poller.POLLIN);
        this.sockets.register(this.shellSocket, ZMQ.Poller.POLLIN);
        this.sockets.register(this.stdinSocket, ZMQ.Poller.POLLIN);
    }

    /** Publishes {@code message} on the iopub socket. */
    @Override
    public void publish(Message message) {
        sendMsg(this.iopubSocket, message);
    }

    /** Sends {@code message} on the shell socket. */
    @Override
    public void send(Message message) {
        sendMsg(this.shellSocket, message);
    }

    /**
     * Serializes, signs and sends a message as a multipart ZMQ message:
     * identities, {@value #DELIM}, HMAC, header, parent header, metadata, content, buffers.
     */
    private synchronized void sendMsg(ZMQ.Socket socket, Message message) {
        String header = MessageSerializer.toJson(message.getHeader());
        String parent = MessageSerializer.toJson(message.getParentHeader());
        String meta = MessageSerializer.toJson(message.getMetadata());
        String content = MessageSerializer.toJson(message.getContent());
        String digest = this.hmac.sign(Arrays.asList(header, parent, meta, content));

        ZMsg newZmsg = new ZMsg();
        message.getIdentities().forEach(identity -> newZmsg.add(identity));
        newZmsg.add(DELIM);
        newZmsg.add(digest.getBytes(StandardCharsets.UTF_8));
        newZmsg.add(header.getBytes(StandardCharsets.UTF_8));
        newZmsg.add(parent.getBytes(StandardCharsets.UTF_8));
        newZmsg.add(meta.getBytes(StandardCharsets.UTF_8));
        newZmsg.add(content.getBytes(StandardCharsets.UTF_8));
        message.getBuffers().forEach(buffer -> newZmsg.add(buffer));
        newZmsg.send(socket);
    }

    /**
     * Receives and decodes one multipart message from {@code socket}.
     *
     * <p>Expected frame layout: zero or more routing identities, {@value #DELIM}, HMAC, header,
     * parent header, metadata, content. Any frames after content are ignored.
     *
     * @throws IllegalStateException if nothing was received or the message has too few frames
     * @throws RuntimeException if the delimiter is missing or the signature does not match
     */
    private Message readMessage(ZMQ.Socket socket) {
        ZMsg zmsg = null;
        Message message = new Message();
        try {
            zmsg = ZMsg.recvMsg(socket);
            if (zmsg == null) {
                throw new IllegalStateException("No message received from socket");
            }
            ZFrame[] parts = new ZFrame[zmsg.size()];
            zmsg.toArray(parts);

            int delim = indexOfDelim(parts);
            int required = delim + 1 + FRAMES_AFTER_DELIM;
            if (parts.length < required) {
                throw new IllegalStateException(
                        "Incomplete message: expected at least " + required + " frames but got " + parts.length);
            }
            byte[] expectedSig = parts[delim + 1].getData();
            byte[] header = parts[delim + 2].getData();
            byte[] parent = parts[delim + 3].getData();
            byte[] metadata = parts[delim + 4].getData();
            byte[] content = parts[delim + 5].getData();
            verifySignatures(expectedSig, header, parent, metadata, content);

            // Every frame before the delimiter is a routing identity needed to address the reply.
            for (int i = 0; i < delim; i++) {
                byte[] identity = parts[i].getData();
                if (identity != null) {
                    message.getIdentities().add(identity);
                }
            }
            message.setHeader(parse(header, Header.class));
            message.setParentHeader(parse(parent, Header.class));
            message.setMetadata(parse(metadata, LinkedHashMap.class));
            message.setContent(parse(content, LinkedHashMap.class));
        } finally {
            if (zmsg != null) {
                zmsg.destroy();
            }
        }
        return message;
    }

    /**
     * Polls the inbound sockets and dispatches messages until shutdown is requested.
     * Sockets and the close action are always released when the loop exits, normally or not.
     */
    @Override
    public void run() {
        try {
            while (!isShutdown()) {
                this.sockets.poll();
                // One message is handled per poll; control has priority, then heartbeat, shell, stdin.
                if (isControlMsg()) {
                    handleControlMsg();
                } else if (isHeartbeatMsg()) {
                    handleHeartbeat();
                } else if (isShellMsg()) {
                    handleShell();
                } else if (isStdinMsg()) {
                    handleStdIn();
                }
            }
        } finally {
            close();
        }
    }

    /** Reads one frame from stdin and logs it (decoded as UTF-8). */
    private void handleStdIn() {
        byte[] buffer = this.stdinSocket.recv();
        if (buffer != null) {
            logger.info("Stdin: {}", new String(buffer, StandardCharsets.UTF_8));
        }
    }

    /** Reads a shell message and passes it to the kernel handler registered for its type, if any. */
    private void handleShell() {
        Message message = readMessage(this.shellSocket);
        Handler<Message> handler = this.kernel.getHandler(message.type());
        if (handler != null) {
            handler.handle(message);
        }
    }

    /** Echoes a heartbeat frame back to the sender. */
    private void handleHeartbeat() {
        byte[] buffer = this.hearbeatSocket.recv(0);
        if (buffer != null) {
            this.hearbeatSocket.send(buffer);
        }
    }

    /**
     * Handles a control message. On {@code shutdown_request}, replies with {@code shutdown_reply}
     * (echoing the request content) and flags the loop for shutdown; other types are ignored.
     */
    private void handleControlMsg() {
        Message message = readMessage(this.controlSocket);
        Header header = message.getHeader();
        // Enum identity comparison is null-safe for a missing or unknown type.
        if (header != null && header.getTypeEnum() == JupyterMessages.SHUTDOWN_REQUEST) {
            Message reply = new Message();
            reply.setHeader(new Header(JupyterMessages.SHUTDOWN_REPLY, header.getSession()));
            reply.setParentHeader(header);
            reply.setContent(message.getContent());
            sendMsg(this.controlSocket, reply);
            shutdown();
        }
    }

    /** Creates a socket of the given ZMQ {@code type} and binds it to {@code connection:port}. */
    private ZMQ.Socket getNewSocket(int type, int port, String connection, ZMQ.Context context) {
        ZMQ.Socket socket = context.socket(type);
        socket.bind(connection + ":" + port);
        return socket;
    }

    /** Runs the close action, then releases sockets and context even if the action throws. */
    private void close() {
        try {
            this.closeAction.close();
        } finally {
            closeSockets();
        }
    }

    /**
     * Closes every socket and the context. Each close is attempted independently so that one
     * failure does not leave the others open; failures are logged rather than propagated.
     */
    private void closeSockets() {
        closeSocket(this.shellSocket, "shell");
        closeSocket(this.controlSocket, "control");
        closeSocket(this.iopubSocket, "iopub");
        closeSocket(this.stdinSocket, "stdin");
        closeSocket(this.hearbeatSocket, "heartbeat");
        try {
            this.context.close();
        } catch (Exception e) {
            logger.warn("Failed to close ZMQ context", e);
        }
    }

    /** Best-effort close of a single socket; {@code null} is ignored. */
    private static void closeSocket(ZMQ.Socket socket, String name) {
        if (socket == null) {
            return;
        }
        try {
            socket.close();
        } catch (Exception e) {
            logger.warn("Failed to close {} socket", name, e);
        }
    }

    /**
     * Checks the received HMAC against one computed over header, parent, metadata and content.
     * Uses a constant-time comparison so the check does not leak how many leading bytes matched.
     *
     * @throws RuntimeException if the signatures differ
     */
    private void verifySignatures(byte[] expectedSig, byte[] header, byte[] parent, byte[] metadata, byte[] content) {
        String actualSig = this.hmac.signBytes(new ArrayList<byte[]>(Arrays.asList(header, parent, metadata, content)));
        byte[] actualSigBytes = actualSig.getBytes(StandardCharsets.UTF_8);
        byte[] received = expectedSig != null ? expectedSig : new byte[0];
        if (!MessageDigest.isEqual(received, actualSigBytes)) {
            throw new RuntimeException("Signatures do not match.");
        }
    }

    /**
     * Returns the index of the first {@value #DELIM} frame.
     *
     * @throws RuntimeException if no frame equals the delimiter
     */
    private static int indexOfDelim(ZFrame[] parts) {
        for (int i = 0; i < parts.length; i++) {
            if (parts[i] != null && Arrays.equals(DELIM_BYTES, parts[i].getData())) {
                return i;
            }
        }
        throw new RuntimeException("Delimiter " + DELIM + " not found");
    }

    private boolean isStdinMsg() {
        return this.sockets.pollin(STDIN_POLL_INDEX);
    }

    private boolean isShellMsg() {
        return this.sockets.pollin(SHELL_POLL_INDEX);
    }

    private boolean isHeartbeatMsg() {
        return this.sockets.pollin(HEARTBEAT_POLL_INDEX);
    }

    private boolean isControlMsg() {
        return this.sockets.pollin(CONTROL_POLL_INDEX);
    }

    /** Flags the run loop to exit after the current iteration. */
    private void shutdown() {
        logger.debug("kernel shutdown");
        this.shutdownSystem = true;
    }

    private boolean isShutdown() {
        return this.shutdownSystem;
    }

    /** Decodes UTF-8 JSON {@code bytes} into {@code theClass}; returns {@code null} for {@code null} input. */
    @SuppressWarnings("unchecked") // decompiled cast kept; MessageSerializer.parse's return type is not visible here
    private <T> T parse(byte[] bytes, Class<T> theClass) {
        return bytes != null ? (T) MessageSerializer.parse(new String(bytes, StandardCharsets.UTF_8), theClass) : null;
    }
}

