/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.sql.calcite.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlBinaryOperator;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.JoinAlgorithm;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.QueryUtils;
import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel;
import org.apache.druid.sql.calcite.rel.DruidQueryRel;
import org.apache.druid.sql.calcite.rel.DruidRel;
import org.apache.druid.sql.calcite.rel.PartialDruidQuery;

/**
 * Planner rule that matches a Calcite {@link Join} whose two inputs are both {@link DruidRel}s and replaces it with a
 * {@link DruidJoinQueryRel}.
 *
 * <p>When the join algorithm does not require a subquery, a SELECT_PROJECT stage on either input may be pulled above
 * the join. The join input then becomes the bare scan of that input's {@link PartialDruidQuery}, and the projection is
 * re-applied on top of the new join. Join sub-conditions that the analysis marks as unsupported are dropped from the
 * join condition and re-applied as a filter above the join.
 */
public class DruidJoinRule extends RelOptRule {
    private final boolean enableLeftScanDirect;
    private final PlannerContext plannerContext;

    private DruidJoinRule(PlannerContext plannerContext) {
        super(
            operand(
                Join.class,
                operand(DruidRel.class, any()),
                operand(DruidRel.class, any())
            )
        );
        this.enableLeftScanDirect = plannerContext.queryContext().getEnableJoinLeftScanDirect();
        this.plannerContext = plannerContext;
    }

    /**
     * Creates a new rule instance bound to the given planner context.
     *
     * @param plannerContext context used to read query settings and record planning errors
     * @return a new {@code DruidJoinRule}
     */
    public static DruidJoinRule instance(PlannerContext plannerContext) {
        return new DruidJoinRule(plannerContext);
    }

    /**
     * Returns whether this rule applies to the matched join.
     *
     * <p>The rule applies only when the join condition can be handled and both inputs expose a partial Druid query.
     * The condition check runs first and also records (or clears) the planning error on the planner context.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final Join join = call.rel(0);
        final DruidRel<?> left = call.rel(1);
        final DruidRel<?> right = call.rel(2);
        return canHandleCondition(
                   join.getCondition(),
                   join.getLeft().getRowType(),
                   right,
                   join.getJoinType(),
                   join.getSystemFieldList(),
                   join.getCluster().getRexBuilder()
               )
               && left.getPartialDruidQuery() != null
               && right.getPartialDruidQuery() != null;
    }

    /**
     * Rewrites the matched join into a {@link DruidJoinQueryRel}.
     *
     * <p>The result is topped by a project that reproduces the original join's columns, an optional post-join filter
     * built from unsupported sub-conditions, and a final conversion to the original join's row type.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Join join = call.rel(0);
        final DruidRel<?> left = call.rel(1);
        final DruidRel<?> right = call.rel(2);
        final RexBuilder rexBuilder = join.getCluster().getRexBuilder();

        final DruidRel<?> newLeft;
        final DruidRel<?> newRight;
        final Filter leftFilter;
        // Expressions for the project placed on top of the new join: left-side columns first, then right-side ones.
        final List<RexNode> newProjectExprs = new ArrayList<>();

        // Not final: rewritten each time a projection is pulled above the join.
        ConditionAnalysis conditionAnalysis =
            analyzeCondition(join.getCondition(), join.getLeft().getRowType(), rexBuilder);
        plannerContext.setPlanningError(conditionAnalysis.errorStr);

        final boolean isLeftDirectAccessPossible = enableLeftScanDirect && left instanceof DruidQueryRel;
        final JoinAlgorithm joinAlgorithm = QueryUtils.getJoinAlgorithm(join, plannerContext);
        final PartialDruidQuery leftQuery = left.getPartialDruidQuery();
        final PartialDruidQuery rightQuery = right.getPartialDruidQuery();

        if (!joinAlgorithm.requiresSubquery()
            && leftQuery.stage() == PartialDruidQuery.Stage.SELECT_PROJECT
            && (isLeftDirectAccessPossible || leftQuery.getWhereFilter() == null)) {
            // Pull the left projection above the join. The left input becomes its bare scan, and its WHERE filter
            // (possibly null) is handed to the DruidJoinQueryRel.
            final Project leftProject = leftQuery.getSelectProject();
            leftFilter = leftQuery.getWhereFilter();
            newProjectExprs.addAll(leftProject.getProjects());
            newLeft = left.withPartialQuery(PartialDruidQuery.create(leftQuery.getScan()));
            conditionAnalysis = conditionAnalysis.pushThroughLeftProject(leftProject);
        } else {
            // Leave the left input as-is; pass its columns straight through.
            for (int i = 0; i < left.getRowType().getFieldCount(); i++) {
                newProjectExprs.add(rexBuilder.makeInputRef(join.getRowType().getFieldList().get(i).getType(), i));
            }
            newLeft = left;
            leftFilter = null;
        }

        final int newLeftFieldCount = newLeft.getRowType().getFieldCount();
        if (!joinAlgorithm.requiresSubquery()
            && rightQuery.stage() == PartialDruidQuery.Stage.SELECT_PROJECT
            && rightQuery.getWhereFilter() == null
            && !rightQuery.getSelectProject().isMapping()
            && conditionAnalysis.onlyUsesMappingsFromRightProject(rightQuery.getSelectProject())) {
            // Pull the right projection above the join, shifting its refs past the (possibly widened) left side.
            final Project rightProject = rightQuery.getSelectProject();
            final boolean nullsOnRight = join.getJoinType().generatesNullsOnRight();
            for (RexNode rexNode : RexUtil.shift(rightProject.getProjects(), newLeftFieldCount)) {
                // When the join can produce nulls for the right side, pulled-up literals must be nullable too.
                newProjectExprs.add(nullsOnRight ? makeNullableIfLiteral(rexNode, rexBuilder) : rexNode);
            }
            newRight = right.withPartialQuery(PartialDruidQuery.create(rightQuery.getScan()));
            conditionAnalysis = conditionAnalysis.pushThroughRightProject(rightProject);
        } else {
            // Leave the right input as-is; pass its columns straight through at their new offsets.
            final int originalLeftFieldCount = left.getRowType().getFieldCount();
            for (int i = 0; i < right.getRowType().getFieldCount(); i++) {
                newProjectExprs.add(
                    rexBuilder.makeInputRef(
                        join.getRowType().getFieldList().get(originalLeftFieldCount + i).getType(),
                        newLeftFieldCount + i
                    )
                );
            }
            newRight = right;
        }

        final DruidJoinQueryRel druidJoin = DruidJoinQueryRel.create(
            join.copy(
                join.getTraitSet(),
                conditionAnalysis.getConditionWithUnsupportedSubConditionsIgnored(rexBuilder),
                newLeft,
                newRight,
                join.getJoinType(),
                join.isSemiJoinDone()
            ),
            leftFilter,
            left.getPlannerContext()
        );

        RelBuilder relBuilder = call.builder()
            .push(druidJoin)
            .project(
                RexUtil.fixUp(rexBuilder, newProjectExprs, RelOptUtil.getFieldTypeList(druidJoin.getRowType()))
            );

        // Re-apply, above the join, every sub-condition that could not be part of the join condition itself.
        final RexNode postJoinFilter =
            RexUtil.composeConjunction(rexBuilder, conditionAnalysis.getUnsupportedOnSubConditions(), true);
        if (postJoinFilter != null) {
            relBuilder = relBuilder.filter(postJoinFilter);
        }
        relBuilder.convert(join.getRowType(), false);
        call.transformTo(relBuilder.build());
    }

    /**
     * Returns a copy of {@code rexNode} with a nullable type if it is a literal; any other node is returned unchanged.
     */
    private static RexNode makeNullableIfLiteral(RexNode rexNode, RexBuilder rexBuilder) {
        if (!rexNode.isA(SqlKind.LITERAL)) {
            return rexNode;
        }
        final RelDataType nullableType =
            rexBuilder.getTypeFactory().createTypeWithNullability(rexNode.getType(), true);
        return rexBuilder.makeLiteral(RexLiteral.value(rexNode), nullableType, true);
    }

    /**
     * Returns whether a join with the given condition can be planned by this rule.
     *
     * <p>Side effect: records the condition analysis error (possibly {@code null}) as the planning error, and
     * overwrites it with a lookup-specific message when rejecting a lookup join.
     *
     * @param condition       the join condition
     * @param leftRowType     row type of the left input; its field count splits left from right references
     * @param right           the right input, or {@code null} to skip the lookup-specific check
     * @param joinType        the join type
     * @param systemFieldList the join's system fields
     * @param rexBuilder      builder used while analyzing the condition
     * @return {@code false} if a direct lookup right side is referenced through more than one distinct column, or if
     *     unsupported sub-conditions exist on a join that is not a plain INNER join without system fields
     */
    @VisibleForTesting
    public boolean canHandleCondition(RexNode condition, RelDataType leftRowType, DruidRel<?> right, JoinRelType joinType, List<RelDataTypeField> systemFieldList, RexBuilder rexBuilder) {
        final ConditionAnalysis conditionAnalysis = analyzeCondition(condition, leftRowType, rexBuilder);
        plannerContext.setPlanningError(conditionAnalysis.errorStr);

        if (right != null
            && !DruidJoinQueryRel.computeRightRequiresSubquery(plannerContext, DruidJoinQueryRel.getSomeDruidChild(right))
            && right instanceof DruidQueryRel
            && ((DruidQueryRel) right).getDruidTable().getDataSource() instanceof LookupDataSource) {
            final long distinctRightColumns =
                conditionAnalysis.rightColumns.stream().map(RexSlot::getIndex).distinct().count();
            if (distinctRightColumns > 1L) {
                plannerContext.setPlanningError("SQL is resulting in a join involving lookup where value column is used in the condition.");
                return false;
            }
        }

        // Only a plain INNER join without system fields may keep unsupported sub-conditions: onMatch re-applies them
        // as a filter above the join.
        return (joinType == JoinRelType.INNER && systemFieldList.isEmpty())
               || conditionAnalysis.getUnsupportedOnSubConditions().isEmpty();
    }

    /**
     * Splits a join condition on AND and classifies every conjunct.
     *
     * <p>Each conjunct is classified as one of the following:
     * <ul>
     *   <li>a literal (or a CAST of a literal to the same SQL type name);</li>
     *   <li>an equality ({@code =} or {@code IS NOT DISTINCT FROM}) between an expression over left-side fields and a
     *       single right-side input reference;</li>
     *   <li>a bare boolean input reference {@code x}, treated as {@code TRUE = x};</li>
     *   <li>unsupported.</li>
     * </ul>
     * Error messages for unsupported conjuncts are joined with newlines into {@link ConditionAnalysis#errorStr}.
     *
     * @param condition   the join condition
     * @param leftRowType row type of the left input; input refs below its field count are left-side refs
     * @param rexBuilder  builder used to create the {@code TRUE} literal for bare boolean refs
     * @return the analysis of {@code condition}
     */
    public static ConditionAnalysis analyzeCondition(RexNode condition, RelDataType leftRowType, RexBuilder rexBuilder) {
        final List<RexNode> subConditions = decomposeAnd(condition);
        final List<RexEquality> equalitySubConditions = new ArrayList<>();
        final List<RexLiteral> literalSubConditions = new ArrayList<>();
        final List<RexNode> unSupportedSubConditions = new ArrayList<>();
        final Set<RexInputRef> rightColumns = new HashSet<>();
        final int numLeftFields = leftRowType.getFieldCount();
        final List<String> errors = new ArrayList<>();

        for (RexNode subCondition : subConditions) {
            if (RexUtil.isLiteral(subCondition, true)) {
                if (subCondition.isA(SqlKind.CAST)) {
                    // With allowCast=true, isLiteral accepts CAST(literal). The cast is dropped only when it keeps the
                    // same SQL type name; otherwise the whole conjunct is treated as unsupported.
                    final RexNode castOperand = ((RexCall) subCondition).getOperands().get(0);
                    if (subCondition.getType().getSqlTypeName() == castOperand.getType().getSqlTypeName()) {
                        literalSubConditions.add((RexLiteral) castOperand);
                    } else {
                        unSupportedSubConditions.add(subCondition);
                    }
                } else {
                    literalSubConditions.add((RexLiteral) subCondition);
                }
                continue;
            }

            final RexNode firstOperand;
            final RexNode secondOperand;
            final SqlKind comparisonKind;
            if (subCondition.isA(SqlKind.INPUT_REF)) {
                // A bare column reference "x" is rewritten as "TRUE = x"; only boolean columns are accepted.
                firstOperand = rexBuilder.makeLiteral(true);
                secondOperand = subCondition;
                comparisonKind = SqlKind.EQUALS;
                if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) {
                    errors.add(StringUtils.format("SQL requires a join with '%s' condition where the column is of the type %s, that is not supported", subCondition.getKind(), secondOperand.getType().getSqlTypeName()));
                    unSupportedSubConditions.add(subCondition);
                    continue;
                }
            } else if (subCondition.isA(SqlKind.EQUALS) || subCondition.isA(SqlKind.IS_NOT_DISTINCT_FROM)) {
                final List<RexNode> operands = ((RexCall) subCondition).getOperands();
                Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size());
                firstOperand = operands.get(0);
                secondOperand = operands.get(1);
                comparisonKind = subCondition.getKind();
            } else {
                errors.add(StringUtils.format("SQL requires a join with '%s' condition that is not supported.", subCondition.getKind()));
                unSupportedSubConditions.add(subCondition);
                continue;
            }

            if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) {
                equalitySubConditions.add(new RexEquality(firstOperand, (RexInputRef) secondOperand, comparisonKind));
                rightColumns.add((RexInputRef) secondOperand);
            } else if (isRightInputRef(firstOperand, numLeftFields) && isLeftExpression(secondOperand, numLeftFields)) {
                // Normalize so the left-side expression is always stored first.
                equalitySubConditions.add(new RexEquality(secondOperand, (RexInputRef) firstOperand, comparisonKind));
                rightColumns.add((RexInputRef) firstOperand);
            } else {
                errors.add(StringUtils.format("SQL is resulting in a join that has unsupported operand types."));
                unSupportedSubConditions.add(subCondition);
            }
        }

        final String errorStr = errors.isEmpty() ? null : Joiner.on('\n').join(errors);
        return new ConditionAnalysis(numLeftFields, equalitySubConditions, literalSubConditions, unSupportedSubConditions, rightColumns, errorStr);
    }

    /**
     * Flattens nested ANDs into their conjuncts, in left-to-right order.
     *
     * <p>A condition that is not an AND is returned as a single-element list.
     */
    @VisibleForTesting
    static List<RexNode> decomposeAnd(RexNode condition) {
        final List<RexNode> retVal = new ArrayList<>();
        final Stack<RexNode> stack = new Stack<>();
        stack.push(condition);
        while (!stack.empty()) {
            final RexNode current = stack.pop();
            if (current.isA(SqlKind.AND)) {
                // Push in reverse so operands are popped, and therefore emitted, in their original order.
                final List<RexNode> operands = ((RexCall) current).getOperands();
                for (int i = operands.size() - 1; i >= 0; i--) {
                    stack.push(operands.get(i));
                }
            } else {
                retVal.add(current);
            }
        }
        return retVal;
    }

    /** Returns whether every input reference in {@code rexNode} points into the first {@code numLeftFields} fields. */
    private static boolean isLeftExpression(RexNode rexNode, int numLeftFields) {
        return ImmutableBitSet.range(numLeftFields).contains(RelOptUtil.InputFinder.bits(rexNode));
    }

    /** Returns whether {@code rexNode} is a single input reference to a field at or beyond {@code numLeftFields}. */
    private static boolean isRightInputRef(RexNode rexNode, int numLeftFields) {
        return rexNode.isA(SqlKind.INPUT_REF) && ((RexInputRef) rexNode).getIndex() >= numLeftFields;
    }

    /**
     * One supported equality conjunct: a left-side expression compared with a right-side input reference using either
     * {@link SqlKind#EQUALS} or {@link SqlKind#IS_NOT_DISTINCT_FROM}.
     */
    static class RexEquality {
        private final RexNode left;
        private final RexInputRef right;
        private final SqlKind kind;

        public RexEquality(RexNode left, RexInputRef right, SqlKind kind) {
            this.left = left;
            this.right = right;
            this.kind = kind;
        }

        /**
         * Rebuilds this equality as a binary call.
         *
         * @throws DruidException if {@code kind} is neither EQUALS nor IS_NOT_DISTINCT_FROM
         */
        public RexNode makeCall(RexBuilder builder) {
            final SqlBinaryOperator operator;
            if (kind == SqlKind.EQUALS) {
                operator = SqlStdOperatorTable.EQUALS;
            } else if (kind == SqlKind.IS_NOT_DISTINCT_FROM) {
                operator = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM;
            } else {
                throw DruidException.defensive("Unexpected operator kind[%s]", kind);
            }
            return builder.makeCall(operator, left, right);
        }

        @Override
        public String toString() {
            return "RexEquality{left=" + left + ", right=" + right + ", kind=" + kind + "}";
        }
    }

    /** Result of {@link #analyzeCondition}: the join condition's conjuncts, grouped by how they can be handled. */
    public static class ConditionAnalysis {
        // Number of fields on the left side; right-side input refs start at this index.
        private final int numLeftFields;
        private final List<RexEquality> equalitySubConditions;
        private final List<RexLiteral> literalSubConditions;
        private final List<RexNode> unsupportedOnSubConditions;
        // Right-side refs collected from the original condition; the push-through methods do not rewrite them.
        private final Set<RexInputRef> rightColumns;
        /** Newline-joined error messages for unsupported conjuncts, or {@code null} if there were none. */
        public final String errorStr;

        ConditionAnalysis(int numLeftFields, List<RexEquality> equalitySubConditions, List<RexLiteral> literalSubConditions, List<RexNode> unsupportedOnSubConditions, Set<RexInputRef> rightColumns, String errorStr) {
            this.numLeftFields = numLeftFields;
            this.equalitySubConditions = equalitySubConditions;
            this.literalSubConditions = literalSubConditions;
            this.unsupportedOnSubConditions = unsupportedOnSubConditions;
            this.rightColumns = rightColumns;
            this.errorStr = errorStr;
        }

        /**
         * Rewrites the equalities as if {@code leftProject} were removed from the left input.
         *
         * <p>Left expressions are pushed past the project. Right refs are shifted by the change in left-side width.
         */
        public ConditionAnalysis pushThroughLeftProject(Project leftProject) {
            final int newNumLeftFields = leftProject.getInput().getRowType().getFieldCount();
            final int rhsShift = newNumLeftFields - leftProject.getRowType().getFieldCount();
            final List<RexEquality> newEqualities = equalitySubConditions
                .stream()
                .map(equality -> new RexEquality(
                    RelOptUtil.pushPastProject(equality.left, leftProject),
                    (RexInputRef) RexUtil.shift(equality.right, rhsShift),
                    equality.kind
                ))
                .collect(Collectors.toList());
            return new ConditionAnalysis(newNumLeftFields, newEqualities, literalSubConditions, unsupportedOnSubConditions, rightColumns, errorStr);
        }

        /**
         * Rewrites the equalities as if {@code rightProject} were removed from the right input.
         *
         * @throws IllegalArgumentException if some equality's right column is not a plain input ref in the project
         */
        public ConditionAnalysis pushThroughRightProject(Project rightProject) {
            Preconditions.checkArgument(onlyUsesMappingsFromRightProject(rightProject), "Cannot push through");
            final List<RexEquality> newEqualities = equalitySubConditions
                .stream()
                .map(equality -> new RexEquality(
                    equality.left,
                    // Move into the right input's own coordinates, push past the project, then move back.
                    (RexInputRef) RexUtil.shift(
                        RelOptUtil.pushPastProject(RexUtil.shift(equality.right, -numLeftFields), rightProject),
                        numLeftFields
                    ),
                    equality.kind
                ))
                .collect(Collectors.toList());
            return new ConditionAnalysis(numLeftFields, newEqualities, literalSubConditions, unsupportedOnSubConditions, rightColumns, errorStr);
        }

        /** Returns whether every equality's right column is a plain input reference in {@code rightProject}. */
        public boolean onlyUsesMappingsFromRightProject(Project rightProject) {
            return equalitySubConditions
                .stream()
                .allMatch(equality -> rightProject.getProjects()
                    .get(equality.right.getIndex() - numLeftFields)
                    .isA(SqlKind.INPUT_REF));
        }

        /**
         * Returns the AND of the literal and equality conjuncts, omitting the unsupported ones.
         *
         * <p>The result is never {@code null}, because {@code nullOnEmpty} is passed as {@code false}.
         */
        public RexNode getConditionWithUnsupportedSubConditionsIgnored(RexBuilder rexBuilder) {
            final List<RexNode> equalityCalls = equalitySubConditions
                .stream()
                .map(equality -> equality.makeCall(rexBuilder))
                .collect(Collectors.toList());
            return RexUtil.composeConjunction(rexBuilder, Iterables.concat(literalSubConditions, equalityCalls), false);
        }

        /** Returns the conjuncts that cannot be part of the native join condition. */
        public List<RexNode> getUnsupportedOnSubConditions() {
            return unsupportedOnSubConditions;
        }

        @Override
        public String toString() {
            return "ConditionAnalysis{numLeftFields=" + numLeftFields + ", equalitySubConditions=" + equalitySubConditions + ", literalSubConditions=" + literalSubConditions + ", unsupportedSubConditions=" + unsupportedOnSubConditions + ", rightColumns=" + rightColumns + "}";
        }
    }
}

