/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hive.druid.org.apache.calcite.adapter.enumerable;

import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.hive.druid.com.google.common.collect.ImmutableList;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.EnumUtils;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.EnumerableRelImplementor;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.EnumerableTraitsUtils;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.hive.druid.org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.hive.druid.org.apache.calcite.linq4j.tree.BlockBuilder;
import org.apache.hive.druid.org.apache.calcite.linq4j.tree.Expression;
import org.apache.hive.druid.org.apache.calcite.linq4j.tree.Expressions;
import org.apache.hive.druid.org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.hive.druid.org.apache.calcite.linq4j.tree.Primitive;
import org.apache.hive.druid.org.apache.calcite.linq4j.tree.Statement;
import org.apache.hive.druid.org.apache.calcite.plan.DeriveMode;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptCluster;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptCost;
import org.apache.hive.druid.org.apache.calcite.plan.RelOptPlanner;
import org.apache.hive.druid.org.apache.calcite.plan.RelTrait;
import org.apache.hive.druid.org.apache.calcite.plan.RelTraitSet;
import org.apache.hive.druid.org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.hive.druid.org.apache.calcite.rel.RelNode;
import org.apache.hive.druid.org.apache.calcite.rel.RelWriter;
import org.apache.hive.druid.org.apache.calcite.rel.core.CorrelationId;
import org.apache.hive.druid.org.apache.calcite.rel.core.Join;
import org.apache.hive.druid.org.apache.calcite.rel.core.JoinRelType;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelMdCollation;
import org.apache.hive.druid.org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.hive.druid.org.apache.calcite.rex.RexNode;
import org.apache.hive.druid.org.apache.calcite.util.BuiltInMethod;
import org.apache.hive.druid.org.apache.calcite.util.ImmutableBitSet;
import org.apache.hive.druid.org.apache.calcite.util.Pair;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * {@link Join} in the {@link EnumerableConvention} that re-evaluates its right
 * input once per <em>batch</em> of left rows instead of once per left row.
 *
 * <p>The batch size is the number of correlation variables in
 * {@code variablesSet}. Each variable is bound to one element of the current
 * batch list inside the generated right-side lambda (see {@link #implement}).
 * The generated code delegates to {@code BuiltInMethod.CORRELATE_BATCH_JOIN}.
 */
public class EnumerableBatchNestedLoopJoin
extends Join
implements EnumerableRel {
    /**
     * Required-columns bit set supplied by the creator.
     *
     * <p>This class only stores it and propagates it through {@link #copy}.
     * NOTE(review): its exact meaning is defined by callers; confirm there.
     */
    private final ImmutableBitSet requiredColumns;

    /**
     * Creates an EnumerableBatchNestedLoopJoin with explicit traits.
     * Prefer {@link #create}, which derives the collation trait.
     *
     * <p>No join hints are attached (an empty hint list is passed to
     * {@link Join}).
     */
    protected EnumerableBatchNestedLoopJoin(RelOptCluster cluster, RelTraitSet traits, RelNode left, RelNode right, RexNode condition, Set<CorrelationId> variablesSet, ImmutableBitSet requiredColumns, JoinRelType joinType) {
        super(cluster, traits, ImmutableList.of(), left, right, condition, variablesSet, joinType);
        this.requiredColumns = requiredColumns;
    }

    /**
     * Creates an EnumerableBatchNestedLoopJoin in the enumerable convention.
     *
     * <p>The collation trait is computed lazily via
     * {@link RelMdCollation#enumerableBatchNestedLoopJoin}.
     *
     * @param left            left input; also supplies the cluster and metadata query
     * @param right           right input, re-evaluated per batch of left rows
     * @param condition       join condition
     * @param requiredColumns required-columns bit set stored on the node
     * @param variablesSet    correlation variables; the size is the batch size
     * @param joinType        join type
     * @return a new join node
     */
    public static EnumerableBatchNestedLoopJoin create(RelNode left, RelNode right, RexNode condition, ImmutableBitSet requiredColumns, Set<CorrelationId> variablesSet, JoinRelType joinType) {
        RelOptCluster cluster = left.getCluster();
        RelMetadataQuery mq = cluster.getMetadataQuery();
        RelTraitSet traitSet = cluster.traitSetOf((RelTrait) EnumerableConvention.INSTANCE)
                .replaceIfs(RelCollationTraitDef.INSTANCE,
                        () -> RelMdCollation.enumerableBatchNestedLoopJoin(mq, left, right, joinType));
        return new EnumerableBatchNestedLoopJoin(cluster, traitSet, left, right, condition, variablesSet, requiredColumns, joinType);
    }

    /** Delegates trait pass-through to the shared join logic in {@link EnumerableTraitsUtils}. */
    @Override
    public @Nullable Pair<RelTraitSet, List<RelTraitSet>> passThroughTraits(RelTraitSet required) {
        return EnumerableTraitsUtils.passThroughTraitsForJoin(required, this.joinType, this.getLeft().getRowType().getFieldCount(), this.traitSet);
    }

    /** Delegates trait derivation to the shared join logic in {@link EnumerableTraitsUtils}. */
    @Override
    public @Nullable Pair<RelTraitSet, List<RelTraitSet>> deriveTraits(RelTraitSet childTraits, int childId) {
        return EnumerableTraitsUtils.deriveTraitsForJoin(childTraits, childId, this.joinType, this.traitSet, this.right.getTraitSet());
    }

    /**
     * Returns {@link DeriveMode#PROHIBITED} for FULL and RIGHT joins, and
     * {@link DeriveMode#LEFT_FIRST} for every other join type.
     */
    @Override
    public DeriveMode getDeriveMode() {
        if (this.joinType == JoinRelType.FULL || this.joinType == JoinRelType.RIGHT) {
            return DeriveMode.PROHIBITED;
        }
        return DeriveMode.LEFT_FIRST;
    }

    /**
     * Copies this join with new traits, condition, inputs and join type.
     *
     * <p>The correlation variables and required columns are carried over.
     * {@code semiJoinDone} is ignored.
     */
    @Override
    public EnumerableBatchNestedLoopJoin copy(RelTraitSet traitSet, RexNode condition, RelNode left, RelNode right, JoinRelType joinType, boolean semiJoinDone) {
        return new EnumerableBatchNestedLoopJoin(this.getCluster(), traitSet, left, right, condition, this.variablesSet, this.requiredColumns, joinType);
    }

    /**
     * Estimates the cost as the output rows plus the left rows, plus the
     * right input's cost repeated once per extra restart of the right side.
     *
     * @return an infinite cost if either input's row estimate is infinite;
     *         {@code null} if the planner has no cost for the right input
     */
    @Override
    public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        double rowCount = mq.getRowCount(this);
        double rightRowCount = this.right.estimateRowCount(mq);
        double leftRowCount = this.left.estimateRowCount(mq);
        if (Double.isInfinite(leftRowCount) || Double.isInfinite(rightRowCount)) {
            return planner.getCostFactory().makeInfiniteCost();
        }
        // The right side restarts once per batch of left rows. A batch holds
        // one row per correlation variable.
        double restartCount = mq.getRowCount(this.getLeft()) / (double) this.variablesSet.size();
        RelOptCost rightCost = planner.getCost(this.getRight(), mq);
        if (rightCost == null) {
            return null;
        }
        // Charge (restartCount - 1) rescans of the right input, but never fewer than one.
        RelOptCost rescanCost = rightCost.multiplyBy(Math.max(1.0, restartCount - 1.0));
        return planner.getCostFactory().makeCost(rowCount + leftRowCount, 0.0, 0.0).plus(rescanCost);
    }

    /** Adds the batch size (number of correlation variables) to the plan description. */
    @Override
    public RelWriter explainTerms(RelWriter pw) {
        super.explainTerms(pw);
        return pw.item("batchSize", this.variablesSet.size());
    }

    /**
     * Generates code that calls {@code CORRELATE_BATCH_JOIN}.
     *
     * <p>The call receives:
     * <ul>
     *   <li>the left enumerable;</li>
     *   <li>a lambda over a batch list ({@code corrListN}) that binds each
     *       correlation variable to one list element and then evaluates the
     *       right input;</li>
     *   <li>the join selector;</li>
     *   <li>the join predicate;</li>
     *   <li>the batch size.</li>
     * </ul>
     */
    @Override
    public EnumerableRel.Result implement(EnumerableRelImplementor implementor, EnumerableRel.Prefer pref) {
        BlockBuilder builder = new BlockBuilder();
        EnumerableRel.Result leftResult = implementor.visitChild(this, 0, (EnumerableRel) this.left, pref);
        Expression leftExpression = builder.append("left", leftResult.block);

        List<String> corrVar = new ArrayList<>();
        for (CorrelationId id : this.variablesSet) {
            corrVar.add(id.getName());
        }

        BlockBuilder corrBlock = new BlockBuilder();
        Type corrVarType = leftResult.physType.getJavaRowType();
        // Parameter of the right-side lambda that holds the current batch.
        // The rel id suffix makes the name distinct per join node.
        ParameterExpression corrArgList = Expressions.parameter(Modifier.FINAL, List.class, "corrList" + Integer.toUnsignedString(this.getId()));
        if (!Primitive.is(corrVarType)) {
            // Object row type: declare "final T var_i = (T) corrList.get(i)".
            for (int i = 0; i < corrVar.size(); ++i) {
                String name = corrVar.get(i);
                ParameterExpression corrArg = Expressions.parameter(Modifier.FINAL, corrVarType, name);
                corrBlock.add(Expressions.declare(Modifier.FINAL, corrArg,
                        Expressions.convert_(
                                Expressions.call(corrArgList, BuiltInMethod.LIST_GET.method, Expressions.constant(i)),
                                corrVarType)));
                implementor.registerCorrelVariable(name, corrArg, corrBlock, leftResult.physType);
            }
        } else {
            // Primitive row type: bind the boxed list element, then register
            // the unboxed value as the correlation variable.
            for (int i = 0; i < corrVar.size(); ++i) {
                String name = corrVar.get(i);
                ParameterExpression boxed = Expressions.parameter(Modifier.FINAL, Primitive.box(corrVarType), "$box" + name);
                corrBlock.add(Expressions.declare(Modifier.FINAL, boxed,
                        Expressions.call(corrArgList, BuiltInMethod.LIST_GET.method, Expressions.constant(i))));
                ParameterExpression corrRef = (ParameterExpression) corrBlock.append(name, Expressions.unbox(boxed));
                implementor.registerCorrelVariable(name, corrRef, corrBlock, leftResult.physType);
            }
        }

        // The right input is generated while the correlation variables are
        // registered, so its code can reference them.
        EnumerableRel.Result rightResult = implementor.visitChild(this, 1, (EnumerableRel) this.right, pref);
        corrBlock.add(rightResult.block);
        for (String name : corrVar) {
            implementor.clearCorrelVariable(name);
        }

        PhysType physType = PhysTypeImpl.of(implementor.getTypeFactory(), this.getRowType(), pref.prefer(JavaRowFormat.CUSTOM));
        Expression selector = EnumUtils.joinSelector(this.joinType, physType, ImmutableList.of(leftResult.physType, rightResult.physType));
        Expression predicate = EnumUtils.generatePredicate(implementor, this.getCluster().getRexBuilder(), this.left, this.right, leftResult.physType, rightResult.physType, this.condition);
        builder.append(Expressions.call(BuiltInMethod.CORRELATE_BATCH_JOIN.method,
                Expressions.constant(EnumUtils.toLinq4jJoinType(this.joinType)),
                leftExpression,
                Expressions.lambda(corrBlock.toBlock(), corrArgList),
                selector,
                predicate,
                Expressions.constant(this.variablesSet.size())));
        return implementor.result(physType, builder.toBlock());
    }
}

