/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.sql.fun;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCallBinding;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.fun.SqlCase;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorImpl;
import org.apache.calcite.sql.validate.SqlValidatorScope;
import org.apache.calcite.sql.validate.implicit.TypeCoercion;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Static;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Operator for SQL {@code CASE} expressions.
 *
 * <p>Calls to this operator are {@link SqlCase} nodes. Each holds an optional
 * value operand, a list of {@code WHEN} operands, a parallel list of
 * {@code THEN} operands, and an {@code ELSE} operand.
 *
 * <p>This class is a singleton; use {@link #INSTANCE}.
 */
public class SqlCaseOperator extends SqlOperator {
    /** The singleton instance. */
    public static final SqlCaseOperator INSTANCE = new SqlCaseOperator();

    private SqlCaseOperator() {
        super("CASE", SqlKind.CASE, 200, true, null, InferTypes.RETURN_TYPE, null);
    }

    /**
     * Validates every WHEN operand, every THEN operand and, if present, the
     * ELSE operand as expressions in {@code operandScope}.
     *
     * <p>NOTE(review): the value operand ({@code CASE value WHEN ...}) is not
     * validated here. Presumably it has already been folded into the WHEN
     * conditions by this point; confirm.
     */
    @Override
    public void validateCall(SqlCall call, SqlValidator validator,
            SqlValidatorScope scope, SqlValidatorScope operandScope) {
        SqlCase sqlCase = (SqlCase) call;
        SqlNodeList whenOperands = sqlCase.getWhenOperands();
        SqlNodeList thenOperands = sqlCase.getThenOperands();
        SqlNode elseOperand = sqlCase.getElseOperand();
        for (SqlNode whenOperand : whenOperands) {
            whenOperand.validateExpr(validator, operandScope);
        }
        for (SqlNode thenOperand : thenOperands) {
            thenOperand.validateExpr(validator, operandScope);
        }
        if (elseOperand != null) {
            elseOperand.validateExpr(validator, operandScope);
        }
    }

    /** Derives the type of a {@code CASE} call by validating its operands. */
    @Override
    public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
        return this.validateOperands(validator, scope, call);
    }

    /**
     * Checks that every WHEN operand has a boolean type, and that at least
     * one result (a THEN operand or the ELSE operand) is not a NULL literal.
     *
     * <p>If every result is a NULL literal, an error is thrown only when
     * {@code throwOnFailure} is set and implicit type coercion is disabled.
     * Otherwise {@code false} is returned.
     *
     * @param callBinding    binding of the {@code CASE} call
     * @param throwOnFailure whether to throw instead of returning false
     * @return whether the operand types are acceptable
     */
    @Override
    public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
        SqlCase caseCall = (SqlCase) callBinding.getCall();
        SqlNodeList whenList = caseCall.getWhenOperands();
        SqlNodeList thenList = caseCall.getThenOperands();
        assert whenList.size() == thenList.size();

        // Every WHEN condition must be boolean.
        for (SqlNode whenNode : whenList) {
            RelDataType type = SqlTypeUtil.deriveType(callBinding, whenNode);
            if (!SqlTypeUtil.inBooleanFamily(type)) {
                if (throwOnFailure) {
                    throw callBinding.newError(Static.RESOURCE.expectedBoolean());
                }
                return false;
            }
        }

        // At least one result must be something other than a NULL literal.
        boolean foundNotNull = false;
        for (SqlNode thenNode : thenList) {
            if (!SqlUtil.isNullLiteral(thenNode, false)) {
                foundNotNull = true;
                break;
            }
        }
        if (!foundNotNull && !SqlUtil.isNullLiteral(caseCall.getElseOperand(), false)) {
            foundNotNull = true;
        }
        if (!foundNotNull) {
            if (throwOnFailure && !callBinding.isTypeCoercionEnabled()) {
                throw callBinding.newError(Static.RESOURCE.mustNotNullInElse());
            }
            return false;
        }
        return true;
    }

    /**
     * Infers the result type of the {@code CASE} call.
     *
     * <p>Validator-backed bindings ({@link SqlCallBinding}) are resolved
     * against the validator. Any other binding is resolved from its operand
     * types alone.
     */
    @Override
    public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
        if (opBinding instanceof SqlCallBinding) {
            return SqlCaseOperator.inferTypeFromValidator((SqlCallBinding) opBinding);
        }
        return SqlCaseOperator.inferTypeFromOperands(opBinding);
    }

    /**
     * Computes the return type as the least restrictive type of all THEN
     * operands plus the ELSE operand.
     *
     * <p>When no least restrictive type exists and type coercion is enabled,
     * the operator attempts CASE/WHEN coercion and re-derives the call's type.
     * Each NULL-literal result is then recorded in the validator with the
     * final return type.
     *
     * @throws RuntimeException (validation error "illegal mixing of types")
     *     when no common type can be found or coerced
     */
    private static RelDataType inferTypeFromValidator(SqlCallBinding callBinding) {
        SqlCase caseCall = (SqlCase) callBinding.getCall();
        SqlNodeList whenOperands = caseCall.getWhenOperands();
        SqlNodeList thenList = caseCall.getThenOperands();
        RelDataTypeFactory typeFactory = callBinding.getTypeFactory();
        // Results that are NULL literals; they receive the final type below.
        List<SqlNode> nullList = new ArrayList<>();
        // Types of every result: each THEN operand, then the ELSE operand.
        List<RelDataType> argTypes = new ArrayList<>();

        for (int i = 0; i < thenList.size(); ++i) {
            SqlNode thenNode = thenList.get(i);
            RelDataType type = SqlTypeUtil.deriveType(callBinding, thenNode);
            SqlNode whenNode = whenOperands.get(i);
            // For "WHEN x IS NOT NULL THEN x", x is known to be non-null in
            // this branch, so that branch's result type is made NOT NULL.
            if (whenNode.getKind() == SqlKind.IS_NOT_NULL
                    && type.isNullable()
                    && ((SqlBasicCall) whenNode).getOperandList().get(0)
                        .equalsDeep(thenNode, Litmus.IGNORE)) {
                type = typeFactory.createTypeWithNullability(type, false);
            }
            argTypes.add(type);
            if (SqlUtil.isNullLiteral(thenNode, false)) {
                nullList.add(thenNode);
            }
        }

        SqlNode elseOp = Objects.requireNonNull(caseCall.getElseOperand(),
                () -> "elseOperand for " + caseCall);
        argTypes.add(SqlTypeUtil.deriveType(callBinding, elseOp));
        if (SqlUtil.isNullLiteral(elseOp, false)) {
            nullList.add(elseOp);
        }

        RelDataType ret = typeFactory.leastRestrictive(argTypes);
        if (ret == null) {
            boolean coerced = false;
            if (callBinding.isTypeCoercionEnabled()) {
                TypeCoercion typeCoercion = callBinding.getValidator().getTypeCoercion();
                // Only the existence of a wider type matters here. The return
                // type is re-derived from the coerced call, not taken from it.
                RelDataType commonType = typeCoercion.getWiderTypeFor(argTypes, true);
                if (commonType != null) {
                    coerced = typeCoercion.caseWhenCoercion(callBinding);
                    if (coerced) {
                        ret = SqlTypeUtil.deriveType(callBinding);
                    }
                }
            }
            if (!coerced) {
                throw callBinding.newValidationError(Static.RESOURCE.illegalMixingOfTypes());
            }
        }

        SqlValidatorImpl validator = (SqlValidatorImpl) callBinding.getValidator();
        Objects.requireNonNull(ret, () -> "return type for " + callBinding);
        for (SqlNode nullNode : nullList) {
            validator.setValidatedNodeType(nullNode, ret);
        }
        return ret;
    }

    /**
     * Computes the return type from a flat operand list laid out as
     * {@code [when0, then0, when1, then1, ..., else]}. The result is the least
     * restrictive type of the THEN types (odd indexes) and the ELSE type
     * (last operand).
     *
     * <p>For a {@link RexCallBinding}, a THEN operand guarded by
     * {@code IS NOT NULL} on that same expression is treated as NOT NULL.
     *
     * @throws NullPointerException if no least restrictive type exists
     */
    private static RelDataType inferTypeFromOperands(SqlOperatorBinding opBinding) {
        RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
        List<RelDataType> argTypes = opBinding.collectOperandTypes();
        assert argTypes.size() % 2 == 1 : "odd number of arguments expected: " + argTypes.size();
        assert argTypes.size() > 1
                : "CASE must have more than 1 argument. Given " + argTypes.size() + ", " + argTypes;
        List<RelDataType> thenTypes = new ArrayList<>();
        for (int j = 1; j < argTypes.size() - 1; j += 2) {
            RelDataType argType = argTypes.get(j);
            if (opBinding instanceof RexCallBinding) {
                RexCallBinding rexCallBinding = (RexCallBinding) opBinding;
                RexNode whenNode = rexCallBinding.operands().get(j - 1);
                RexNode thenNode = rexCallBinding.operands().get(j);
                if (whenNode.getKind() == SqlKind.IS_NOT_NULL
                        && argType.isNullable()
                        && ((RexCall) whenNode).getOperands().get(0).equals(thenNode)) {
                    argType = typeFactory.createTypeWithNullability(argType, false);
                }
            }
            thenTypes.add(argType);
        }
        thenTypes.add(Iterables.getLast(argTypes));
        return Objects.requireNonNull(typeFactory.leastRestrictive(thenTypes),
                () -> "Can't find leastRestrictive type for " + thenTypes);
    }

    /** Accepts any number of operands. */
    @Override
    public SqlOperandCountRange getOperandCountRange() {
        return SqlOperandCountRanges.any();
    }

    @Override
    public SqlSyntax getSyntax() {
        return SqlSyntax.SPECIAL;
    }

    /**
     * Creates a {@link SqlCase}.
     *
     * @param functionQualifier must be null
     * @param pos               parser position
     * @param operands          exactly four operands, in this order:
     *     value (may be null), WHEN list, THEN list, ELSE operand
     */
    @Override
    public SqlCall createCall(@Nullable SqlLiteral functionQualifier, SqlParserPos pos,
            SqlNode ... operands) {
        assert functionQualifier == null;
        assert operands.length == 4;
        return new SqlCase(pos, operands[0], (SqlNodeList) operands[1],
                (SqlNodeList) operands[2], operands[3]);
    }

    /**
     * Writes {@code CASE [value] WHEN w THEN t ... [ELSE e] END}.
     */
    @Override
    public void unparse(SqlWriter writer, SqlCall call_, int leftPrec, int rightPrec) {
        SqlCase kase = (SqlCase) call_;
        SqlWriter.Frame frame = writer.startList(SqlWriter.FrameTypeEnum.CASE, "CASE", "END");
        assert kase.whenList.size() == kase.thenList.size();
        SqlNode value = kase.value;
        if (value != null) {
            value.unparse(writer, 0, 0);
        }
        for (Pair<SqlNode, SqlNode> pair : Pair.zip(kase.whenList, kase.thenList)) {
            writer.sep("WHEN");
            pair.left.unparse(writer, 0, 0);
            writer.sep("THEN");
            pair.right.unparse(writer, 0, 0);
        }
        SqlNode elseExpr = kase.elseExpr;
        if (elseExpr != null) {
            writer.sep("ELSE");
            elseExpr.unparse(writer, 0, 0);
        }
        writer.endList(frame);
    }
}

