/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.rowpattern;

import com.google.common.collect.ImmutableMap;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.util.AstUtils;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;

/**
 * Static helpers that decide whether two {@link LogicalIndexExtractor.ExpressionAndValuePointers}
 * describe the same computation up to a positional renaming of layout symbols.
 * <p>
 * Two instances are considered equivalent when:
 * <ul>
 *   <li>their layouts have the same size,</li>
 *   <li>the value pointers at each layout position are of the same class and pairwise
 *       equivalent, and</li>
 *   <li>their expressions are structurally equal once the i-th left layout symbol is read as
 *       the i-th right layout symbol.</li>
 * </ul>
 */
public final class ExpressionAndValuePointersEquivalence {
    private ExpressionAndValuePointersEquivalence() {
    }

    /**
     * Checks equivalence using plain {@link Symbol#equals} to compare ordinary input symbols.
     *
     * @param left the first expression together with its value pointers
     * @param right the second expression together with its value pointers
     * @return {@code true} if both describe the same computation
     */
    public static boolean equivalent(LogicalIndexExtractor.ExpressionAndValuePointers left, LogicalIndexExtractor.ExpressionAndValuePointers right) {
        return equivalent(left, right, Symbol::equals);
    }

    /**
     * Checks equivalence using a caller-supplied relation for ordinary input symbols.
     * <p>
     * Classifier and match-number symbols are never passed to {@code symbolEquivalence}.
     * A pair of them only has to play the same role on both sides.
     *
     * @param left the first expression together with its value pointers
     * @param right the second expression together with its value pointers
     * @param symbolEquivalence decides whether two ordinary input symbols are interchangeable
     * @return {@code true} if both describe the same computation
     * @throws UnsupportedOperationException if a value pointer is neither a
     *         {@link ScalarValuePointer} nor an {@link AggregationValuePointer}
     */
    public static boolean equivalent(LogicalIndexExtractor.ExpressionAndValuePointers left, LogicalIndexExtractor.ExpressionAndValuePointers right, BiFunction<Symbol, Symbol, Boolean> symbolEquivalence) {
        int layoutSize = left.getLayout().size();
        if (layoutSize != right.getLayout().size()) {
            return false;
        }

        // Value pointers are indexed by layout position. This assumes getValuePointers() is
        // at least as long as getLayout() on both sides (TODO confirm invariant).
        for (int i = 0; i < layoutSize; i++) {
            ValuePointer leftPointer = left.getValuePointers().get(i);
            ValuePointer rightPointer = right.getValuePointers().get(i);

            // Pointers of different concrete kinds can never be equivalent.
            if (leftPointer.getClass() != rightPointer.getClass()) {
                return false;
            }

            boolean pointersEquivalent;
            if (leftPointer instanceof ScalarValuePointer) {
                pointersEquivalent = equivalent(
                        (ScalarValuePointer) leftPointer,
                        (ScalarValuePointer) rightPointer,
                        left.getClassifierSymbols(),
                        left.getMatchNumberSymbols(),
                        right.getClassifierSymbols(),
                        right.getMatchNumberSymbols(),
                        symbolEquivalence);
            }
            else if (leftPointer instanceof AggregationValuePointer) {
                pointersEquivalent = equivalent(
                        (AggregationValuePointer) leftPointer,
                        (AggregationValuePointer) rightPointer,
                        symbolEquivalence);
            }
            else {
                throw new UnsupportedOperationException("unexpected ValuePointer type: " + leftPointer.getClass().getSimpleName());
            }

            if (!pointersEquivalent) {
                return false;
            }
        }

        // Layout symbols correspond positionally: the i-th left symbol stands for the i-th right one.
        // ImmutableMap.Builder#build() throws IllegalArgumentException if a left symbol repeats.
        ImmutableMap.Builder<Symbol, Symbol> mapping = ImmutableMap.builder();
        for (int i = 0; i < layoutSize; i++) {
            mapping.put(left.getLayout().get(i), right.getLayout().get(i));
        }

        return AstUtils.treeEqual(left.getExpression(), right.getExpression(), mappingComparator(mapping.build()));
    }

    /**
     * Compares two scalar value pointers.
     * <p>
     * The logical index pointers must be equal. The input symbols must play the same role
     * (classifier, match number, or ordinary) in their respective sets. Only ordinary input
     * symbols are compared with {@code symbolEquivalence}.
     */
    private static boolean equivalent(ScalarValuePointer left, ScalarValuePointer right, Set<Symbol> leftClassifierSymbols, Set<Symbol> leftMatchNumberSymbols, Set<Symbol> rightClassifierSymbols, Set<Symbol> rightMatchNumberSymbols, BiFunction<Symbol, Symbol, Boolean> symbolEquivalence) {
        if (!left.getLogicalIndexPointer().equals(right.getLogicalIndexPointer())) {
            return false;
        }

        Symbol leftInputSymbol = left.getInputSymbol();
        Symbol rightInputSymbol = right.getInputSymbol();

        boolean leftIsClassifier = leftClassifierSymbols.contains(leftInputSymbol);
        boolean leftIsMatchNumber = leftMatchNumberSymbols.contains(leftInputSymbol);
        boolean rightIsClassifier = rightClassifierSymbols.contains(rightInputSymbol);
        boolean rightIsMatchNumber = rightMatchNumberSymbols.contains(rightInputSymbol);

        if (leftIsClassifier != rightIsClassifier || leftIsMatchNumber != rightIsMatchNumber) {
            return false;
        }

        // Special symbols only need to play the same role; ordinary ones go through the relation.
        if (!leftIsClassifier && !leftIsMatchNumber) {
            return symbolEquivalence.apply(leftInputSymbol, rightInputSymbol);
        }
        return true;
    }

    /**
     * Compares two aggregation value pointers.
     * <p>
     * Function and set descriptor must be equal and the argument counts must match. Each
     * argument pair is then compared structurally, treating the pointers' own classifier and
     * match-number symbols as role placeholders.
     */
    private static boolean equivalent(AggregationValuePointer left, AggregationValuePointer right, BiFunction<Symbol, Symbol, Boolean> symbolEquivalence) {
        if (!left.getFunction().equals(right.getFunction())
                || !left.getSetDescriptor().equals(right.getSetDescriptor())
                || left.getArguments().size() != right.getArguments().size()) {
            return false;
        }

        BiFunction<Node, Node, Boolean> comparator = subsetComparator(
                left.getClassifierSymbol(),
                left.getMatchNumberSymbol(),
                right.getClassifierSymbol(),
                right.getMatchNumberSymbol(),
                symbolEquivalence);

        for (int i = 0; i < left.getArguments().size(); i++) {
            if (!AstUtils.treeEqual(left.getArguments().get(i), right.getArguments().get(i), comparator)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds a node comparator for aggregation arguments.
     * <p>
     * For a pair of symbol references, the comparator returns a definite answer. The symbols
     * must play the same role (the given classifier symbol, the given match-number symbol, or
     * neither), and ordinary symbols are checked with {@code symbolEquivalence}. For any other
     * pair it returns {@code false} when the nodes are not shallow-equal. Otherwise it returns
     * {@code null}.
     * NOTE(review): {@code null} presumably tells {@link AstUtils#treeEqual} to keep comparing
     * the children; confirm against AstUtils.
     */
    private static BiFunction<Node, Node, Boolean> subsetComparator(Symbol leftClassifierSymbol, Symbol leftMatchNumberSymbol, Symbol rightClassifierSymbol, Symbol rightMatchNumberSymbol, BiFunction<Symbol, Symbol, Boolean> symbolEquivalence) {
        return (left, right) -> {
            if (left instanceof SymbolReference && right instanceof SymbolReference) {
                Symbol leftSymbol = Symbol.from((SymbolReference) left);
                Symbol rightSymbol = Symbol.from((SymbolReference) right);

                boolean leftIsClassifier = leftSymbol.equals(leftClassifierSymbol);
                boolean leftIsMatchNumber = leftSymbol.equals(leftMatchNumberSymbol);
                boolean rightIsClassifier = rightSymbol.equals(rightClassifierSymbol);
                boolean rightIsMatchNumber = rightSymbol.equals(rightMatchNumberSymbol);

                if (leftIsClassifier != rightIsClassifier || leftIsMatchNumber != rightIsMatchNumber) {
                    return false;
                }
                if (!leftIsClassifier && !leftIsMatchNumber) {
                    // Kept boxed: the relation's result is passed through unchanged.
                    return symbolEquivalence.apply(leftSymbol, rightSymbol);
                }
                return true;
            }
            if (!left.shallowEquals(right)) {
                return false;
            }
            return null;
        };
    }

    /**
     * Builds a node comparator for the top-level expressions.
     * <p>
     * A pair of symbol references is equal exactly when {@code mapping} sends the left symbol
     * to the right one. A left symbol absent from {@code mapping} therefore never matches. For
     * any other pair the comparator returns {@code false} when the nodes are not shallow-equal,
     * and {@code null} otherwise.
     * NOTE(review): {@code null} presumably lets {@link AstUtils#treeEqual} recurse into the
     * children; confirm.
     */
    private static BiFunction<Node, Node, Boolean> mappingComparator(Map<Symbol, Symbol> mapping) {
        return (left, right) -> {
            if (left instanceof SymbolReference && right instanceof SymbolReference) {
                Symbol leftSymbol = Symbol.from((SymbolReference) left);
                Symbol rightSymbol = Symbol.from((SymbolReference) right);
                return rightSymbol.equals(mapping.get(leftSymbol));
            }
            if (!left.shallowEquals(right)) {
                return false;
            }
            return null;
        };
    }
}

