package org.apache.hadoop.hive.ql.optimizer.calcite.stats;

import io.prestosql.hive.$internal.org.slf4j.Logger;
import io.prestosql.hive.$internal.org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelVisitor;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SemiJoin;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMdRowCount;
import org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/calcite/stats/HiveRelMdRowCount.class */
/**
 * Hive-specific row count metadata handler.
 *
 * <p>Refines Calcite's {@link RelMdRowCount} estimates for:
 * <ul>
 *   <li>{@link Join}: when the join condition is a single-column equality and one input is unique
 *       on its join column (a primary-key / foreign-key, "PK-FK", join), the estimate is
 *       {@code fkRowCount * min(1, pkSelectivity * ndvScalingFactor)}; otherwise
 *       {@code join.getRows()} is used;</li>
 *   <li>{@link SemiJoin}: the same analysis, falling back to Calcite's estimate;</li>
 *   <li>{@link Sort}: when a FETCH is present, the input row count is capped at
 *       {@code OFFSET + FETCH}.</li>
 * </ul>
 */
public class HiveRelMdRowCount extends RelMdRowCount {
    protected static final Logger LOG = LoggerFactory.getLogger(HiveRelMdRowCount.class.getName());

    /** Reflective metadata provider that serves {@code ROW_COUNT} requests with this handler. */
    public static final RelMetadataProvider SOURCE =
            ReflectiveRelMetadataProvider.reflectiveSource(BuiltInMethod.ROW_COUNT.method, new HiveRelMdRowCount());

    /** Row count and join-key distinct count (NDV) of the foreign-key side of a PK-FK join. */
    public static class FKSideInfo {
        public final double rowCount;
        public final double distinctCount;

        public FKSideInfo(double rowCount, double distinctCount) {
            this.rowCount = rowCount;
            this.distinctCount = distinctCount;
        }

        @Override
        public String toString() {
            return String.format("FKInfo(rowCount=%.2f,ndv=%.2f)", rowCount, distinctCount);
        }
    }

    /**
     * Visitor that checks whether the subtree below a join input is "simple" with respect to one
     * join key column: every node must be a {@link TableScan}, a {@link Project} that forwards the
     * key column as a plain {@link RexInputRef}, or a {@link Filter} whose referenced columns form
     * a unique key of the filter. {@link HepRelVertex} wrappers are unwrapped first.
     */
    public static class IsSimpleTreeOnJoinKey extends RelVisitor {
        /** Index of the join key column in the output of the node currently being visited. */
        int joinKey;
        boolean simpleTree = true;
        RelMetadataQuery mq;

        /**
         * Returns whether the tree rooted at {@code relNode} is simple on column {@code i}.
         */
        static boolean check(RelNode relNode, int i, RelMetadataQuery relMetadataQuery) {
            IsSimpleTreeOnJoinKey visitor = new IsSimpleTreeOnJoinKey(i, relMetadataQuery);
            visitor.go(relNode);
            return visitor.simpleTree;
        }

        IsSimpleTreeOnJoinKey(int joinKey, RelMetadataQuery mq) {
            this.joinKey = joinKey;
            this.mq = mq;
        }

        @Override
        public void visit(RelNode node, int ordinal, RelNode parent) {
            RelNode current = node instanceof HepRelVertex ? ((HepRelVertex) node).getCurrentRel() : node;
            if (current instanceof TableScan) {
                simpleTree = true;
            } else if (current instanceof Project) {
                simpleTree = isSimple((Project) current);
            } else if (current instanceof Filter) {
                simpleTree = isSimple((Filter) current, mq);
            } else {
                simpleTree = false;
            }
            // Only descend while the tree is still simple; the first non-simple node decides.
            if (simpleTree) {
                super.visit(current, ordinal, parent);
            }
        }

        /** A projection is simple if it forwards the key column unchanged; tracks its input index. */
        private boolean isSimple(Project project) {
            RexNode keyExpr = project.getProjects().get(joinKey);
            if (!(keyExpr instanceof RexInputRef)) {
                return false;
            }
            joinKey = ((RexInputRef) keyExpr).getIndex();
            return true;
        }

        /** A filter is simple if the columns referenced by its condition form a unique key of it. */
        private boolean isSimple(Filter filter, RelMetadataQuery relMetadataQuery) {
            return isKey(RelOptUtil.InputFinder.bits(filter.getCondition()), filter, relMetadataQuery);
        }
    }

    /** Result of {@link #analyzeJoinForPKFK}: which side is the FK side plus per-side statistics. */
    public static class PKFKRelationInfo {
        /** 0 when the FK side is the join's left input, 1 when it is the right input. */
        public final int fkSide;
        /** PK-side NDV divided by FK-side NDV when the PK side is simple, otherwise 1.0. */
        public final double ndvScalingFactor;
        public final FKSideInfo fkInfo;
        public final PKSideInfo pkInfo;
        /** Whether the PK-side subtree passed {@link IsSimpleTreeOnJoinKey#check}. */
        public final boolean isPKSideSimple;

        PKFKRelationInfo(int fkSide, FKSideInfo fkInfo, PKSideInfo pkInfo, double ndvScalingFactor,
                boolean isPKSideSimple) {
            this.fkSide = fkSide;
            this.fkInfo = fkInfo;
            this.pkInfo = pkInfo;
            this.ndvScalingFactor = ndvScalingFactor;
            this.isPKSideSimple = isPKSideSimple;
        }

        @Override
        public String toString() {
            return String.format("Primary - Foreign Key join:\n\tfkSide = %d\n\tFKInfo:%s\n\tPKInfo:%s\n\tisPKSideSimple:%s\n\tNDV Scaling Factor:%.2f\n",
                    fkSide, fkInfo, pkInfo, isPKSideSimple, ndvScalingFactor);
        }
    }

    /** Primary-key side statistics: row count, join-key NDV and the selectivity applied to it. */
    public static class PKSideInfo extends FKSideInfo {
        public final double selectivity;

        public PKSideInfo(double rowCount, double distinctCount, double selectivity) {
            super(rowCount, distinctCount);
            this.selectivity = selectivity;
        }

        @Override
        public String toString() {
            return String.format("PKInfo(rowCount=%.2f,ndv=%.2f,selectivity=%.2f)", rowCount, distinctCount, selectivity);
        }
    }

    protected HiveRelMdRowCount() {
    }

    /**
     * Estimates a join's row count, using the PK-FK estimate when one is available.
     *
     * @return the PK-FK estimate, or {@code join.getRows()} when no PK-FK relation is found
     */
    public Double getRowCount(Join join, RelMetadataQuery relMetadataQuery) {
        PKFKRelationInfo pkfk = analyzeJoinForPKFK(join, relMetadataQuery);
        if (pkfk == null) {
            return join.getRows();
        }
        return pkFkRowCount(join, pkfk);
    }

    /**
     * Estimates a semi-join's row count. {@link #analyzeJoinForPKFK} currently declines
     * semi-joins, so this delegates to Calcite's estimate.
     */
    public Double getRowCount(SemiJoin semiJoin, RelMetadataQuery relMetadataQuery) {
        PKFKRelationInfo pkfk = analyzeJoinForPKFK(semiJoin, relMetadataQuery);
        if (pkfk == null) {
            return super.getRowCount(semiJoin, relMetadataQuery);
        }
        return pkFkRowCount(semiJoin, pkfk);
    }

    /** Computes {@code fkRowCount * min(1, pkSelectivity * ndvScalingFactor)} and debug-logs it. */
    private static Double pkFkRowCount(RelNode join, PKFKRelationInfo pkfk) {
        double selectivity = Math.min(1.0d, pkfk.pkInfo.selectivity * pkfk.ndvScalingFactor);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Identified Primary - Foreign Key relation: {} {}", RelOptUtil.toString(join), pkfk);
        }
        return pkfk.fkInfo.rowCount * selectivity;
    }

    /**
     * Estimates a sort's row count: the input row count, capped at {@code OFFSET + FETCH} when a
     * FETCH is present.
     *
     * @return the capped estimate, or the input's row count (possibly {@code null}) otherwise
     */
    public Double getRowCount(Sort sort, RelMetadataQuery relMetadataQuery) {
        Double rowCount = relMetadataQuery.getRowCount(sort.getInput());
        if (rowCount == null || sort.fetch == null) {
            return rowCount;
        }
        // Add in double arithmetic: OFFSET + FETCH as ints can overflow.
        double offset = sort.offset == null ? 0.0d : RexLiteral.intValue(sort.offset);
        double offsetPlusFetch = offset + RexLiteral.intValue(sort.fetch);
        if (offsetPlusFetch < rowCount) {
            return offsetPlusFetch;
        }
        return rowCount;
    }

    /**
     * Checks whether {@code join} is a PK-FK join and collects the statistics used to estimate it.
     *
     * <p>Requirements: not a {@link SemiJoin}; after pushing single-side conjuncts to the inputs,
     * exactly one single-column equality remains; one input is unique on its join column (the
     * left input only for INNER/RIGHT joins, the right input only for INNER/LEFT joins). When both
     * inputs qualify, the one with fewer rows is treated as the PK side.
     *
     * @return the analysis, or {@code null} if the join does not qualify or required statistics
     *     are unavailable
     */
    public static PKFKRelationInfo analyzeJoinForPKFK(Join join, RelMetadataQuery relMetadataQuery) {
        if (join instanceof SemiJoin) {
            return null;
        }
        List<RexNode> conjunctions = RelOptUtil.conjunctions(join.getCondition());
        if (conjunctions.isEmpty()) {
            return null;
        }

        RelNode left = join.getInputs().get(0);
        RelNode right = join.getInputs().get(1);
        JoinRelType joinType = join.getJoinType();

        // Move single-side conjuncts into leftFilters/rightFilters; joinFilters keeps the rest.
        List<RexNode> leftFilters = new ArrayList<>();
        List<RexNode> rightFilters = new ArrayList<>();
        List<RexNode> joinFilters = new ArrayList<>(conjunctions);
        RelOptUtil.classifyFilters(join, joinFilters, joinType, false,
                !joinType.generatesNullsOnRight(), !joinType.generatesNullsOnLeft(),
                joinFilters, leftFilters, rightFilters);

        Pair<Integer, Integer> joinKeys = canHandleJoin(join, leftFilters, rightFilters, joinFilters);
        if (joinKeys == null) {
            return null;
        }
        int leftKey = joinKeys.left;
        int rightKey = joinKeys.right;
        ImmutableBitSet leftKeyBits = ImmutableBitSet.of(leftKey);
        ImmutableBitSet rightKeyBits = ImmutableBitSet.of(rightKey);

        boolean leftIsPK = (joinType == JoinRelType.INNER || joinType == JoinRelType.RIGHT)
                && isKey(leftKeyBits, left, relMetadataQuery);
        boolean rightIsPK = (joinType == JoinRelType.INNER || joinType == JoinRelType.LEFT)
                && isKey(rightKeyBits, right, relMetadataQuery);
        if (!leftIsPK && !rightIsPK) {
            return null;
        }

        Double leftRows = relMetadataQuery.getRowCount(left);
        Double rightRows = relMetadataQuery.getRowCount(right);
        if (leftRows == null || rightRows == null) {
            return null;
        }
        double leftRowCount = leftRows;
        double rightRowCount = rightRows;

        // When both inputs are unique on the key, prefer the smaller one as the PK side.
        if (leftIsPK && rightIsPK && rightRowCount < leftRowCount) {
            leftIsPK = false;
        }
        int pkSide = leftIsPK ? 0 : 1;

        boolean isPKSideSimple = IsSimpleTreeOnJoinKey.check(
                pkSide == 0 ? left : right, pkSide == 0 ? leftKey : rightKey, relMetadataQuery);

        // NDVs are only computed (and the scaling factor applied) for a simple PK side.
        double leftNdv = -1.0d;
        double rightNdv = -1.0d;
        double ndvScalingFactor = 1.0d;
        if (isPKSideSimple) {
            RexBuilder rexBuilder = join.getCluster().getRexBuilder();
            Double leftDistinct = relMetadataQuery.getDistinctRowCount(
                    left, leftKeyBits, RexUtil.composeConjunction(rexBuilder, leftFilters, true));
            Double rightDistinct = relMetadataQuery.getDistinctRowCount(
                    right, rightKeyBits, RexUtil.composeConjunction(rexBuilder, rightFilters, true));
            if (leftDistinct == null || rightDistinct == null) {
                return null;
            }
            leftNdv = leftDistinct;
            rightNdv = rightDistinct;
            ndvScalingFactor = pkSide == 0 ? leftNdv / rightNdv : rightNdv / leftNdv;
        }

        if (pkSide == 0) {
            // PK on the left, FK on the right.
            double selectivity = joinType.generatesNullsOnRight()
                    ? 1.0d : pkSelectivity(join, relMetadataQuery, true, left, leftRowCount);
            return new PKFKRelationInfo(1, new FKSideInfo(rightRowCount, rightNdv),
                    new PKSideInfo(leftRowCount, leftNdv, selectivity), ndvScalingFactor, isPKSideSimple);
        }
        // PK on the right, FK on the left: the FK side is input 0.
        double selectivity = joinType.generatesNullsOnLeft()
                ? 1.0d : pkSelectivity(join, relMetadataQuery, false, right, rightRowCount);
        return new PKFKRelationInfo(0, new FKSideInfo(leftRowCount, leftNdv),
                new PKSideInfo(rightRowCount, rightNdv, selectivity), ndvScalingFactor, isPKSideSimple);
    }

    /**
     * Fraction of the underlying table's rows that survive on the PK side: {@code childRowCount}
     * divided by the row count of the {@link HiveTableScan} returned by
     * {@code HiveRelMdUniqueKeys.getTableScan(child, true)}.
     *
     * @return 1.0 when the join generates nulls on the side opposite the PK child, when no table
     *     scan is found, or when its row count is unknown or non-positive
     */
    private static double pkSelectivity(Join join, RelMetadataQuery relMetadataQuery, boolean leftChild,
            RelNode child, double childRowCount) {
        JoinRelType joinType = join.getJoinType();
        if (leftChild ? joinType.generatesNullsOnRight() : joinType.generatesNullsOnLeft()) {
            return 1.0d;
        }
        HiveTableScan tableScan = HiveRelMdUniqueKeys.getTableScan(child, true);
        if (tableScan == null) {
            return 1.0d;
        }
        Double tableRowCount = relMetadataQuery.getRowCount(tableScan);
        // Guard against NaN/Infinity from unknown or empty table statistics.
        if (tableRowCount == null || tableRowCount <= 0.0d) {
            return 1.0d;
        }
        return childRowCount / tableRowCount;
    }

    /**
     * Returns whether {@code immutableBitSet} is exactly one of the unique keys reported for
     * {@code relNode}; {@code false} when no unique-key information is available.
     */
    public static boolean isKey(ImmutableBitSet immutableBitSet, RelNode relNode, RelMetadataQuery relMetadataQuery) {
        Set<ImmutableBitSet> uniqueKeys = relMetadataQuery.getUniqueKeys(relNode);
        return uniqueKeys != null && uniqueKeys.contains(immutableBitSet);
    }

    /**
     * Extracts the join columns if the remaining join condition is a single equality between one
     * column of each input.
     *
     * @return (left input column index, right input column index), or {@code null} when there is
     *     not exactly one join filter, it is not an {@code =} over one column per operand, or both
     *     operands reference the same input
     */
    private static Pair<Integer, Integer> canHandleJoin(Join join, List<RexNode> leftFilters,
            List<RexNode> rightFilters, List<RexNode> joinFilters) {
        if (joinFilters.size() != 1) {
            return null;
        }
        RexNode condition = joinFilters.get(0);
        if (!(condition instanceof RexCall) || ((RexCall) condition).getOperator() != SqlStdOperatorTable.EQUALS) {
            return null;
        }
        RexCall equality = (RexCall) condition;
        ImmutableBitSet leftCols = RelOptUtil.InputFinder.bits(equality.getOperands().get(0));
        ImmutableBitSet rightCols = RelOptUtil.InputFinder.bits(equality.getOperands().get(1));
        if (leftCols.cardinality() != 1 || rightCols.cardinality() != 1) {
            return null;
        }

        int nFieldsLeft = join.getLeft().getRowType().getFieldList().size();
        int nFieldsRight = join.getRight().getRowType().getFieldList().size();
        int nSysFields = join.getSystemFieldList().size();
        ImmutableBitSet leftFields = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft);
        ImmutableBitSet rightFields =
                ImmutableBitSet.range(nSysFields + nFieldsLeft, nSysFields + nFieldsLeft + nFieldsRight);

        // Flip the operands if the condition is written as "right = left".
        if (rightFields.contains(leftCols)) {
            ImmutableBitSet tmp = leftCols;
            leftCols = rightCols;
            rightCols = tmp;
        }
        // Both operands must come from different inputs, otherwise the indices below are invalid.
        if (!leftFields.contains(leftCols) || !rightFields.contains(rightCols)) {
            return null;
        }

        return new Pair<>(leftCols.nextSetBit(0) - nSysFields, rightCols.nextSetBit(0) - (nSysFields + nFieldsLeft));
    }
}
