/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.QuantifiedComparisonExpression;
import io.trino.sql.tree.SubscriptExpression;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;

/**
 * Rule over {@link ApplyNode} that unwraps single-field rows used in {@code IN}
 * predicates and quantified comparisons.
 *
 * <p>For every subquery assignment of the form {@code x IN (y)} or
 * {@code x <operator> <quantifier> (y)} where the analyzed type of {@code x} is a
 * {@link RowType} with exactly one field, the rule:
 * <ul>
 *   <li>adds a projection of {@code x[1]} on top of the apply node's input,</li>
 *   <li>adds a projection of {@code y[1]} on top of the apply node's subquery,</li>
 *   <li>replaces the assignment with the same predicate expressed over the two
 *       newly allocated symbols.</li>
 * </ul>
 * All other assignments are carried over unchanged. The rewritten apply node is
 * wrapped in an identity projection over the original node's output symbols.
 */
public class UnwrapSingleColumnRowInApply
implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode();
    private final TypeAnalyzer typeAnalyzer;

    /**
     * @param typeAnalyzer used to determine the type of the predicate's value expression
     * @throws NullPointerException if {@code typeAnalyzer} is null
     */
    public UnwrapSingleColumnRowInApply(TypeAnalyzer typeAnalyzer) {
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the subquery assignments of {@code node} that compare single-field rows.
     *
     * @param node the apply node matched by {@link #getPattern()}
     * @param captures pattern captures (not used by this rule)
     * @param context supplies the session, symbol allocator and plan node id allocator
     * @return {@link Rule.Result#empty()} when no assignment could be unwrapped; otherwise a
     *         projection over a new apply node that reuses the original node's id,
     *         correlation and origin subquery
     */
    @Override
    public Rule.Result apply(ApplyNode node, Captures captures, Rule.Context context) {
        // Start from identity projections so every existing output symbol stays available.
        Assignments.Builder inputAssignments = Assignments.builder().putIdentities(node.getInput().getOutputSymbols());
        Assignments.Builder nestedPlanAssignments = Assignments.builder().putIdentities(node.getSubquery().getOutputSymbols());
        Assignments.Builder applyAssignments = Assignments.builder();
        boolean applied = false;

        for (Map.Entry<Symbol, Expression> assignment : node.getSubqueryAssignments().entrySet()) {
            Symbol output = assignment.getKey();
            Expression expression = assignment.getValue();

            // Typed as Optional<Unwrapping>: Optional<Object> cannot hold the helper's result
            // (generics are invariant) and would force an unchecked downcast on read.
            Optional<Unwrapping> unwrapped = Optional.empty();
            if (expression instanceof InPredicate) {
                InPredicate predicate = (InPredicate) expression;
                unwrapped = this.unwrapSingleColumnRow(
                        context,
                        predicate.getValue(),
                        predicate.getValueList(),
                        (valueSymbol, listSymbol) -> new InPredicate(valueSymbol.toSymbolReference(), listSymbol.toSymbolReference()));
            } else if (expression instanceof QuantifiedComparisonExpression) {
                QuantifiedComparisonExpression comparison = (QuantifiedComparisonExpression) expression;
                unwrapped = this.unwrapSingleColumnRow(
                        context,
                        comparison.getValue(),
                        comparison.getSubquery(),
                        (valueSymbol, listSymbol) -> new QuantifiedComparisonExpression(
                                comparison.getOperator(),
                                comparison.getQuantifier(),
                                valueSymbol.toSymbolReference(),
                                listSymbol.toSymbolReference()));
            }

            if (unwrapped.isPresent()) {
                applied = true;
                Unwrapping unwrapping = unwrapped.get();
                inputAssignments.add(unwrapping.getInputAssignment());
                nestedPlanAssignments.add(unwrapping.getNestedPlanAssignment());
                applyAssignments.put(output, unwrapping.getExpression());
            } else {
                // Not a candidate: keep the original assignment as-is.
                applyAssignments.put(assignment);
            }
        }

        if (!applied) {
            return Rule.Result.empty();
        }

        // Kept as a single nested expression so plan node ids are allocated in the same
        // order as before: outer projection, input projection, subquery projection.
        return Rule.Result.ofPlanNode(
                new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        new ApplyNode(
                                node.getId(),
                                new ProjectNode(context.getIdAllocator().getNextId(), node.getInput(), inputAssignments.build()),
                                new ProjectNode(context.getIdAllocator().getNextId(), node.getSubquery(), nestedPlanAssignments.build()),
                                applyAssignments.build(),
                                node.getCorrelation(),
                                node.getOriginSubquery()),
                        // Expose exactly the output symbols of the original apply node.
                        Assignments.identity(node.getOutputSymbols())));
    }

    /**
     * Builds the pieces needed to replace a row comparison with a comparison of the rows'
     * only field.
     *
     * <p>Only the type of {@code value} is inspected; the symbol created for {@code list}
     * is given the same element type.
     * NOTE(review): this presumes {@code list} has the same single-field row type as
     * {@code value} - confirm against the analyzer's guarantees.
     *
     * @param context supplies the session and the symbol allocator
     * @param value left-hand side of the predicate
     * @param list right-hand side of the predicate (value list or subquery expression)
     * @param function creates the rewritten predicate from the new value and list symbols
     * @return the unwrapping, or empty when {@code value} is not a row with exactly one field
     */
    private Optional<Unwrapping> unwrapSingleColumnRow(Rule.Context context, Expression value, Expression list, BiFunction<Symbol, Symbol, Expression> function) {
        Type type = this.typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), value);
        if (!(type instanceof RowType)) {
            return Optional.empty();
        }
        RowType rowType = (RowType) type;
        if (rowType.getFields().size() != 1) {
            return Optional.empty();
        }

        Type elementType = rowType.getTypeParameters().get(0);
        Symbol valueSymbol = context.getSymbolAllocator().newSymbol("input", elementType);
        Symbol listSymbol = context.getSymbolAllocator().newSymbol("subquery", elementType);

        // Subscript "1" selects the row's only field.
        Assignments.Assignment inputAssignment = new Assignments.Assignment(valueSymbol, new SubscriptExpression(value, new LongLiteral("1")));
        Assignments.Assignment nestedPlanAssignment = new Assignments.Assignment(listSymbol, new SubscriptExpression(list, new LongLiteral("1")));
        Expression comparison = function.apply(valueSymbol, listSymbol);
        return Optional.of(new Unwrapping(comparison, inputAssignment, nestedPlanAssignment));
    }

    /**
     * Result of unwrapping one assignment: the rewritten predicate plus the two projections
     * (one for the input side, one for the subquery side) that define the symbols it uses.
     */
    private static class Unwrapping {
        private final Expression expression;
        private final Assignments.Assignment inputAssignment;
        private final Assignments.Assignment nestedPlanAssignment;

        public Unwrapping(Expression expression, Assignments.Assignment inputAssignment, Assignments.Assignment nestedPlanAssignment) {
            this.expression = Objects.requireNonNull(expression, "expression is null");
            this.inputAssignment = Objects.requireNonNull(inputAssignment, "inputAssignment is null");
            this.nestedPlanAssignment = Objects.requireNonNull(nestedPlanAssignment, "nestedPlanAssignment is null");
        }

        /** Returns the rewritten predicate over the new symbols. */
        public Expression getExpression() {
            return this.expression;
        }

        /** Returns the assignment to add to the projection over the apply node's input. */
        public Assignments.Assignment getInputAssignment() {
            return this.inputAssignment;
        }

        /** Returns the assignment to add to the projection over the apply node's subquery. */
        public Assignments.Assignment getNestedPlanAssignment() {
            return this.nestedPlanAssignment;
        }
    }
}

