/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.metadata;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptPredicateList;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.MetadataDef;
import org.apache.calcite.rel.metadata.MetadataHandler;
import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMdExpressionLineage;
import org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Metadata handler that derives {@link BuiltInMetadata.AllPredicates} for the core
 * relational operators.
 *
 * <p>Predicates contributed by filters, calcs and join conditions are rewritten using
 * the expression lineage of the input fields they reference. Predicates coming from
 * several inputs (joins, set operations) have their table references renumbered so
 * that occurrences of the same table on different inputs stay distinct.
 *
 * <p>Every handler returns {@code null} when the predicates cannot be determined, for
 * example when the predicates, table references or lineage of an input are unknown.
 */
public class RelMdAllPredicates
    implements MetadataHandler<BuiltInMetadata.AllPredicates> {
    public static final RelMetadataProvider SOURCE = ReflectiveRelMetadataProvider.reflectiveSource(
            new RelMdAllPredicates(), BuiltInMetadata.AllPredicates.Handler.class);

    @Override
    public MetadataDef<BuiltInMetadata.AllPredicates> getDef() {
        return BuiltInMetadata.AllPredicates.DEF;
    }

    /** Catch-all for operators without a dedicated handler: predicates are unknown. */
    public @Nullable RelOptPredicateList getAllPredicates(RelNode rel, RelMetadataQuery mq) {
        return null;
    }

    /** Delegates to the relational expression currently held by the HEP vertex. */
    public @Nullable RelOptPredicateList getAllPredicates(HepRelVertex rel, RelMetadataQuery mq) {
        return mq.getAllPredicates(rel.getCurrentRel());
    }

    /**
     * Delegates to the subset's best expression, falling back to its original one;
     * returns {@code null} if neither is available.
     */
    public @Nullable RelOptPredicateList getAllPredicates(RelSubset rel, RelMetadataQuery mq) {
        RelNode bestOrOriginal = Util.first(rel.getBest(), rel.getOriginal());
        if (bestOrOriginal == null) {
            return null;
        }
        return mq.getAllPredicates(bestOrOriginal);
    }

    /**
     * Uses the table's own {@link BuiltInMetadata.AllPredicates.Handler} if the table
     * can be unwrapped to one; otherwise a scan contributes no predicates.
     */
    public @Nullable RelOptPredicateList getAllPredicates(TableScan scan, RelMetadataQuery mq) {
        BuiltInMetadata.AllPredicates.Handler handler =
                scan.getTable().unwrap(BuiltInMetadata.AllPredicates.Handler.class);
        if (handler != null) {
            return handler.getAllPredicates(scan, mq);
        }
        return RelOptPredicateList.EMPTY;
    }

    /** A projection adds no predicates; returns those of its input. */
    public @Nullable RelOptPredicateList getAllPredicates(Project project, RelMetadataQuery mq) {
        return mq.getAllPredicates(project.getInput());
    }

    /** Returns the input's predicates plus the lineage-rewritten filter condition. */
    public @Nullable RelOptPredicateList getAllPredicates(Filter filter, RelMetadataQuery mq) {
        return getAllFilterPredicates(filter.getInput(), mq, filter.getCondition());
    }

    /**
     * Treats a calc with a condition like a filter on the expanded condition; a calc
     * without a condition returns its input's predicates.
     */
    public @Nullable RelOptPredicateList getAllPredicates(Calc calc, RelMetadataQuery mq) {
        RexProgram rexProgram = calc.getProgram();
        if (rexProgram.getCondition() != null) {
            RexNode condition = rexProgram.expandLocalRef(rexProgram.getCondition());
            return getAllFilterPredicates(calc.getInput(), mq, condition);
        }
        return mq.getAllPredicates(calc.getInput());
    }

    /**
     * Combines the predicates of {@code rel} with predicate {@code pred}, which is
     * evaluated on top of {@code rel}.
     *
     * @param rel  input of the filtering operator
     * @param mq   metadata query
     * @param pred condition referencing the fields of {@code rel}
     * @return union of the input predicates and the rewritten condition, or
     *         {@code null} if either part cannot be determined
     */
    private static @Nullable RelOptPredicateList getAllFilterPredicates(RelNode rel, RelMetadataQuery mq, RexNode pred) {
        RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        RelOptPredicateList predsBelow = mq.getAllPredicates(rel);
        if (predsBelow == null) {
            return null;
        }
        Set<RexNode> allExprs = inferLineageExpressions(
                rel, mq, rexBuilder, pred, rel.getRowType().getFieldList());
        if (allExprs == null) {
            return null;
        }
        return predsBelow.union(rexBuilder, RelOptPredicateList.of(rexBuilder, allExprs));
    }

    /**
     * Returns the predicates of both join inputs plus the lineage-rewritten join
     * condition. Outer joins are not supported and yield {@code null}.
     */
    public @Nullable RelOptPredicateList getAllPredicates(Join join, RelMetadataQuery mq) {
        if (join.getJoinType().isOuterJoin()) {
            // Outer joins are not handled; report the predicates as unknown.
            return null;
        }
        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        // Inputs are distinguished by position rather than identity. If both sides are
        // the same RelNode instance (which may happen, e.g., when both resolve to one
        // RelSubset), the right side must still have its table references renumbered.
        RelOptPredicateList inputPreds = collectInputPredicates(join.getInputs(), mq, rexBuilder);
        if (inputPreds == null) {
            return null;
        }
        RexNode pred = join.getCondition();
        // The join condition references the concatenation of left and right fields.
        RelDataType fullRowType = SqlValidatorUtil.createJoinType(
                rexBuilder.getTypeFactory(),
                join.getLeft().getRowType(),
                join.getRight().getRowType(),
                null,
                ImmutableList.of());
        Set<RexNode> conditionExprs = inferLineageExpressions(
                join, mq, rexBuilder, pred, fullRowType.getFieldList());
        if (conditionExprs == null) {
            return null;
        }
        return inputPreds.union(rexBuilder, RelOptPredicateList.of(rexBuilder, conditionExprs));
    }

    /** An aggregate adds no predicates; returns those of its input. */
    public @Nullable RelOptPredicateList getAllPredicates(Aggregate agg, RelMetadataQuery mq) {
        return mq.getAllPredicates(agg.getInput());
    }

    /** A table modification adds no predicates; returns those of its input. */
    public @Nullable RelOptPredicateList getAllPredicates(TableModify tableModify, RelMetadataQuery mq) {
        return mq.getAllPredicates(tableModify.getInput());
    }

    /** Returns the union of the predicates of all inputs of the set operation. */
    public @Nullable RelOptPredicateList getAllPredicates(SetOp setOp, RelMetadataQuery mq) {
        return collectInputPredicates(setOp.getInputs(), mq, setOp.getCluster().getRexBuilder());
    }

    /** A sort adds no predicates; returns those of its input. */
    public @Nullable RelOptPredicateList getAllPredicates(Sort sort, RelMetadataQuery mq) {
        return mq.getAllPredicates(sort.getInput());
    }

    /** An exchange adds no predicates; returns those of its input. */
    public @Nullable RelOptPredicateList getAllPredicates(Exchange exchange, RelMetadataQuery mq) {
        return mq.getAllPredicates(exchange.getInput());
    }

    /**
     * Unions the predicates of several inputs. The first input's table references are
     * kept as-is. Each later input has its references renumbered past the references
     * of the same qualified table already seen on earlier inputs.
     *
     * @return the combined predicates, or {@code null} if the predicates or table
     *         references of any input are unknown
     */
    private static @Nullable RelOptPredicateList collectInputPredicates(
            List<RelNode> inputs, RelMetadataQuery mq, RexBuilder rexBuilder) {
        HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs =
                HashMultimap.create();
        RelOptPredicateList result = RelOptPredicateList.EMPTY;
        for (int i = 0; i < inputs.size(); ++i) {
            RelNode input = inputs.get(i);
            RelOptPredicateList inputPreds = mq.getAllPredicates(input);
            if (inputPreds == null) {
                return null;
            }
            Set<RexTableInputRef.RelTableRef> tableRefs = mq.getTableReferences(input);
            if (tableRefs == null) {
                return null;
            }
            if (i == 0) {
                // References of the first input remain unchanged.
                for (RexTableInputRef.RelTableRef ref : tableRefs) {
                    qualifiedNamesToRefs.put(ref.getQualifiedName(), ref);
                }
                result = result.union(rexBuilder, inputPreds);
                continue;
            }
            Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> shifted =
                    shiftTableReferences(qualifiedNamesToRefs, tableRefs);
            // Register the renumbered references so that later inputs shift past them too.
            for (RexTableInputRef.RelTableRef newRef : shifted.values()) {
                qualifiedNamesToRefs.put(newRef.getQualifiedName(), newRef);
            }
            result = result.union(rexBuilder, swapTableReferences(rexBuilder, inputPreds, shifted));
        }
        return result;
    }

    /**
     * Maps each reference in {@code tableRefs} to a reference of the same table whose
     * entity number is offset by the count of references already registered under
     * that qualified name.
     */
    private static Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> shiftTableReferences(
            HashMultimap<List<String>, RexTableInputRef.RelTableRef> qualifiedNamesToRefs,
            Set<RexTableInputRef.RelTableRef> tableRefs) {
        Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> mapping = new HashMap<>();
        for (RexTableInputRef.RelTableRef ref : tableRefs) {
            // Multimap.get never returns null; an unknown name yields an empty set.
            int shift = qualifiedNamesToRefs.get(ref.getQualifiedName()).size();
            mapping.put(ref, RexTableInputRef.RelTableRef.of(ref.getTable(), shift + ref.getEntityNumber()));
        }
        return mapping;
    }

    /** Rewrites every pulled-up predicate of {@code preds} through {@code mapping}. */
    private static RelOptPredicateList swapTableReferences(
            RexBuilder rexBuilder,
            RelOptPredicateList preds,
            Map<RexTableInputRef.RelTableRef, RexTableInputRef.RelTableRef> mapping) {
        List<RexNode> swapped = Util.transform(preds.pulledUpPredicates,
                e -> RexUtil.swapTableReferences(rexBuilder, e, mapping));
        return RelOptPredicateList.of(rexBuilder, swapped);
    }

    /**
     * Rewrites {@code pred} by replacing each input field it references with that
     * field's expression lineage on {@code rel}.
     *
     * @param rel        node on which the lineage of the referenced fields is queried
     * @param mq         metadata query
     * @param rexBuilder builder used to create the rewritten expressions
     * @param pred       expression to rewrite
     * @param fields     field list used to build the input references of {@code pred}
     * @return all possible rewritten expressions, or {@code null} if the lineage of a
     *         referenced field is unknown or the rewrite is not possible
     */
    private static @Nullable Set<RexNode> inferLineageExpressions(
            RelNode rel, RelMetadataQuery mq, RexBuilder rexBuilder, RexNode pred,
            List<RelDataTypeField> fields) {
        Set<RelDataTypeField> extraFields = new LinkedHashSet<>();
        RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(extraFields);
        pred.accept(inputFinder);
        ImmutableBitSet fieldsUsed = inputFinder.build();
        Map<RexInputRef, Set<RexNode>> mapping = new LinkedHashMap<>();
        for (int idx : fieldsUsed) {
            RexInputRef ref = RexInputRef.of(idx, fields);
            Set<RexNode> originalExprs = mq.getExpressionLineage(rel, ref);
            if (originalExprs == null) {
                return null;
            }
            mapping.put(ref, originalExprs);
        }
        return RelMdExpressionLineage.createAllPossibleExpressions(rexBuilder, pred, mapping);
    }
}

