/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.adapter.enumerable;

import com.google.common.collect.ImmutableList;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.calcite.adapter.enumerable.AggImpState;
import org.apache.calcite.adapter.enumerable.AggregateLambdaFactory;
import org.apache.calcite.adapter.enumerable.EnumerableAggregateBase;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor;
import org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.calcite.adapter.enumerable.impl.AggResultContextImpl;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.function.Function0;
import org.apache.calcite.linq4j.function.Function2;
import org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Enumerable implementation of a simple {@link Aggregate} that groups rows with
 * {@code BuiltInMethod.SORTED_GROUP_BY}.
 *
 * <p>The generated code groups by comparing keys with a comparator derived from this
 * node's collation. {@link #passThroughTraits} asks the input to be sorted on the
 * grouping columns. This presumes that adjacent equal keys form a group; confirm
 * against the {@code SORTED_GROUP_BY} runtime implementation.
 */
public class EnumerableSortedAggregate
extends EnumerableAggregateBase
implements EnumerableRel {
    /**
     * Creates an EnumerableSortedAggregate.
     *
     * <p>An empty hint list is passed to the superclass. When assertions are enabled,
     * the constructor checks that the resulting node's convention is an
     * {@link EnumerableConvention}.
     *
     * @param cluster   cluster this node belongs to
     * @param traitSet  traits of this node
     * @param input     input relational expression
     * @param groupSet  input columns to group on
     * @param groupSets grouping sets, may be null
     * @param aggCalls  aggregate calls to evaluate per group
     */
    public EnumerableSortedAggregate(RelOptCluster cluster, RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
        super(cluster, traitSet, ImmutableList.of(), input, groupSet, groupSets, aggCalls);
        assert this.getConvention() instanceof EnumerableConvention;
    }

    /** Returns a new EnumerableSortedAggregate in this node's cluster with the given attributes. */
    @Override
    public EnumerableSortedAggregate copy(RelTraitSet traitSet, RelNode input, ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
        return new EnumerableSortedAggregate(this.getCluster(), traitSet, input, groupSet, groupSets, aggCalls);
    }

    /**
     * Tries to satisfy a collation required by the parent by pushing a matching
     * collation down to the input.
     *
     * <p>The group keys occupy output columns {@code 0 .. groupCount - 1}. There are
     * three cases:
     * <ul>
     *   <li>If the required collation covers exactly those columns, {@code required} is
     *   returned unchanged. The input is asked for the same collation translated to
     *   the input's grouping columns.</li>
     *   <li>If it covers a strict subset of them, the remaining group columns are
     *   appended in ascending column order, using the default {@link RelFieldCollation}
     *   direction. This node's traits are then replaced with that extended collation,
     *   and the input is asked for its translation.</li>
     *   <li>Otherwise, for example when the collation references an aggregate result
     *   column, nothing can be pushed down and {@code null} is returned.</li>
     * </ul>
     *
     * <p>{@code null} is also returned when this aggregate is not simple.
     *
     * @param required traits required by the parent; must carry a collation
     * @return the traits of this node paired with the traits required of the input,
     *     or {@code null} if the requirement cannot be passed through
     * @throws NullPointerException if {@code required} has no collation trait
     */
    @Override
    public @Nullable Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(RelTraitSet required) {
        if (!Aggregate.isSimple(this)) {
            return null;
        }
        RelTraitSet inputTraits = this.getInput().getTraitSet();
        RelCollation collation = Objects.requireNonNull(required.getCollation(),
                () -> "collation trait is null, required traits are " + required);
        ImmutableBitSet requiredKeys = ImmutableBitSet.of(RelCollations.ordinals(collation));
        // Positions of the group keys in this node's output row.
        ImmutableBitSet groupKeys = ImmutableBitSet.range(this.groupSet.cardinality());
        // Used to translate a collation on the output group-key positions into the
        // corresponding collation on the input's grouping columns.
        Mapping mapping = Mappings.source(this.groupSet.toList(), this.input.getRowType().getFieldCount());
        if (requiredKeys.equals(groupKeys)) {
            RelCollation inputCollation = RexUtil.apply(mapping, collation);
            return Pair.of(required, ImmutableList.of(inputTraits.replace(inputCollation)));
        }
        if (groupKeys.contains(requiredKeys)) {
            // e.g. GROUP BY a, b, c ORDER BY c, b: keep the required prefix and append
            // the group keys it does not mention.
            List<RelFieldCollation> fieldCollations = new ArrayList<>(collation.getFieldCollations());
            groupKeys.except(requiredKeys).forEach(k -> fieldCollations.add(new RelFieldCollation(k)));
            RelCollation aggCollation = RelCollations.of(fieldCollations);
            RelCollation inputCollation = RexUtil.apply(mapping, aggCollation);
            return Pair.of(this.traitSet.replace(aggCollation), ImmutableList.of(inputTraits.replace(inputCollation)));
        }
        // The required collation references columns outside the group keys.
        return null;
    }

    /**
     * Generates Java code that evaluates this aggregate.
     *
     * <p>The generated block first evaluates the input. It then builds an accumulator
     * initializer and accumulator adders for every aggregate call, and finally returns
     * the result of calling {@code SORTED_GROUP_BY} on the input with these arguments:
     * <ul>
     *   <li>a key selector over the grouping columns;</li>
     *   <li>the initializer, adder and result selector obtained from an
     *   {@link AggregateLambdaFactory};</li>
     *   <li>a comparator on the key rows, derived from this node's collation trait.</li>
     * </ul>
     *
     * <p>Each output row holds the group keys followed by each aggregate's result, in
     * {@code aggCalls} order.
     *
     * @param implementor code-generation context
     * @param pref        preferred physical row format; forwarded to the input and used
     *                    for this node's physical type
     * @return the generated block together with this node's physical type
     * @throws RuntimeException the exception returned by {@link Util#needToImplement}
     *     if {@link Aggregate#isSimple} is false for this node
     * @throws NullPointerException if an aggregate's context or state has not been
     *     initialized, or if this node has no collation trait
     */
    @Override
    public EnumerableRel.Result implement(EnumerableRelImplementor implementor, EnumerableRel.Prefer pref) {
        if (!Aggregate.isSimple(this)) {
            throw Util.needToImplement("EnumerableSortedAggregate");
        }
        JavaTypeFactory typeFactory = implementor.getTypeFactory();
        BlockBuilder builder = new BlockBuilder();
        EnumerableRel child = (EnumerableRel) this.getInput();
        EnumerableRel.Result result = implementor.visitChild(this, 0, child, pref);
        Expression childExp = builder.append("child", result.block);
        PhysType physType = PhysTypeImpl.of(typeFactory, this.getRowType(), pref.preferCustom());
        PhysType inputPhysType = result.physType;
        ParameterExpression parameter = Expressions.parameter(inputPhysType.getJavaRowType(), "a0");
        // Key rows are the input projected onto the grouping columns, in LIST format.
        PhysType keyPhysType = inputPhysType.project(this.groupSet.asList(),
                this.getGroupType() != Aggregate.Group.SIMPLE, JavaRowFormat.LIST);
        int groupCount = this.getGroupCount();

        List<AggImpState> aggs = new ArrayList<>(this.aggCalls.size());
        for (Ord<AggregateCall> call : Ord.zip(this.aggCalls)) {
            aggs.add(new AggImpState(call.i, call.e, false));
        }

        // Accumulator initializer: a zero-argument lambda that creates the state of all aggregates.
        List<Expression> initExpressions = new ArrayList<>();
        BlockBuilder initBlock = new BlockBuilder();
        List<Type> aggStateTypes = this.createAggStateTypes(initExpressions, initBlock, aggs, typeFactory);
        PhysType accPhysType = PhysTypeImpl.of(typeFactory, typeFactory.createSyntheticType(aggStateTypes));
        this.declareParentAccumulator(initExpressions, initBlock, accPhysType);
        Expression accumulatorInitializer = builder.append("accumulatorInitializer",
                Expressions.lambda(Function0.class, initBlock.toBlock()));

        // Accumulator adders: fold one input row ("in") into the accumulator ("acc").
        ParameterExpression inParameter = Expressions.parameter(inputPhysType.getJavaRowType(), "in");
        ParameterExpression acc_ = Expressions.parameter(accPhysType.getJavaRowType(), "acc");
        this.createAccumulatorAdders(inParameter, aggs, accPhysType, acc_, inputPhysType, builder, implementor, typeFactory);

        ParameterExpression lambdaFactory = Expressions.parameter(AggregateLambdaFactory.class, builder.newName("lambdaFactory"));
        this.implementLambdaFactory(builder, inputPhysType, aggs, accumulatorInitializer, false, lambdaFactory);

        // Result selector body: group keys first, then each aggregate's result.
        BlockBuilder resultBlock = new BlockBuilder();
        List<Expression> results = Expressions.list();
        Type keyType = keyPhysType.getJavaRowType();
        ParameterExpression key_ = Expressions.parameter(keyType, "key");
        for (int j = 0; j < groupCount; ++j) {
            results.add(keyPhysType.fieldReference(key_, j));
        }
        for (AggImpState agg : aggs) {
            results.add(agg.implementor.implementResult(
                    Objects.requireNonNull(agg.context, () -> "agg.context is null for " + agg),
                    new AggResultContextImpl(resultBlock, agg.call,
                            Objects.requireNonNull(agg.state, () -> "agg.state is null for " + agg),
                            key_, keyPhysType)));
        }
        resultBlock.add(physType.record(results));

        Expression keySelector_ = builder.append("keySelector",
                inputPhysType.generateSelector(parameter, this.groupSet.asList(), keyPhysType.getFormat()));
        // Key comparator derived from this node's collation, whose group-key positions
        // line up with the fields of the key row.
        Expression comparator = keyPhysType.generateComparator(
                Objects.requireNonNull(this.getTraitSet().getCollation(),
                        () -> "getTraitSet().getCollation() is null, current traits are " + this.getTraitSet()));
        Expression resultSelector_ = builder.append("resultSelector",
                Expressions.lambda(Function2.class, resultBlock.toBlock(), key_, acc_));

        Expression accInitializerCall = Expressions.call(lambdaFactory,
                BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_INITIALIZER.method);
        Expression accAdderCall = Expressions.call(lambdaFactory,
                BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_ADDER.method);
        Expression resultSelectorCall = Expressions.call(lambdaFactory,
                BuiltInMethod.AGG_LAMBDA_FACTORY_ACC_RESULT_SELECTOR.method, resultSelector_);
        builder.add(Expressions.return_(null,
                Expressions.call(childExp, BuiltInMethod.SORTED_GROUP_BY.method,
                        Expressions.list(keySelector_, accInitializerCall, accAdderCall,
                                resultSelectorCall, comparator))));
        return implementor.result(physType, builder.toBlock());
    }
}

