/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shardingsphere.infra.rewrite.engine;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import lombok.Generated;
import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.type.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.unit.StorageUnit;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.rewrite.context.SQLRewriteContext;
import org.apache.shardingsphere.infra.rewrite.engine.result.RouteSQLRewriteResult;
import org.apache.shardingsphere.infra.rewrite.engine.result.SQLRewriteUnit;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.ParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.GroupedParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.parameter.builder.impl.StandardParameterBuilder;
import org.apache.shardingsphere.infra.rewrite.sql.SQLBuilderEngine;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.statement.core.util.SQLUtils;
import org.apache.shardingsphere.sqltranslator.context.SQLTranslatorContext;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;

/**
 * Rewrite engine that produces one {@link SQLRewriteUnit} per retained {@link RouteUnit} of a routed query.
 *
 * <p>For qualifying {@code SELECT} statements, route units that share the same actual data source are merged into
 * {@code UNION ALL} statements, in batches of at most
 * {@link ConfigurationPropertyKey#MAX_UNION_SIZE_PER_DATASOURCE} units. Every resulting unit is then passed through
 * the {@link SQLTranslatorRule} for the storage type of its data source.</p>
 */
public final class RouteSQLRewriteEngine {
    
    private final SQLTranslatorRule translatorRule;
    
    private final ShardingSphereDatabase database;
    
    private final RuleMetaData globalRuleMetaData;
    
    /**
     * Rewrites the routed statement into SQL units keyed by route unit.
     *
     * @param sqlRewriteContext rewrite context providing the statement context, parameters and parameter builder
     * @param routeContext route context whose route units are rewritten
     * @param queryContext query context; its metadata properties supply the maximum union size per data source
     * @return rewrite result holding the translated SQL unit of each retained route unit; when several units are
     *         merged into one {@code UNION ALL} statement, only one member of that batch appears as a key
     */
    public RouteSQLRewriteResult rewrite(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final QueryContext queryContext) {
        int maxUnionSizePerDataSource = (Integer) queryContext.getMetaData().getProps().getValue(ConfigurationPropertyKey.MAX_UNION_SIZE_PER_DATASOURCE);
        return new RouteSQLRewriteResult(translate(queryContext, createSQLRewriteUnits(sqlRewriteContext, routeContext, maxUnionSizePerDataSource)));
    }
    
    /**
     * Builds the untranslated SQL units, grouping route units by actual data source and merging each group when allowed.
     *
     * @return insertion-ordered map from (representative) route unit to its SQL unit
     */
    private Map<RouteUnit, SQLRewriteUnit> createSQLRewriteUnits(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final int maxUnionSizePerDataSource) {
        Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(routeContext.getRouteUnits().size(), 1F);
        for (Map.Entry<String, List<RouteUnit>> entry : aggregateRouteUnitGroups(routeContext.getRouteUnits()).entrySet()) {
            List<RouteUnit> routeUnits = entry.getValue();
            if (isNeedAggregateRewrite(sqlRewriteContext.getSqlStatementContext(), routeUnits, maxUnionSizePerDataSource)) {
                createAggregatedRewriteUnits(sqlRewriteContext, routeContext, routeUnits, maxUnionSizePerDataSource, result);
                continue;
            }
            for (RouteUnit each : routeUnits) {
                result.put(each, createSQLRewriteUnit(sqlRewriteContext, routeContext, each));
            }
        }
        return result;
    }
    
    /**
     * Groups route units by the actual name of their data source, preserving first-seen order of data sources and units.
     *
     * @param routeUnits route units to group
     * @return map from actual data source name to the (never empty) list of its route units
     */
    private Map<String, List<RouteUnit>> aggregateRouteUnitGroups(final Collection<RouteUnit> routeUnits) {
        Map<String, List<RouteUnit>> result = new LinkedHashMap<>(routeUnits.size(), 1F);
        for (RouteUnit each : routeUnits) {
            result.computeIfAbsent(each.getDataSourceMapper().getActualName(), unused -> new ArrayList<>()).add(each);
        }
        return result;
    }
    
    /**
     * Decides whether the route units of one data source may be merged into {@code UNION ALL} statements.
     *
     * <p>Only {@code SELECT} statements with more than one route unit are candidates, and only when the configured
     * union size is greater than 1. {@code DISTINCT} rows, subqueries, joins, {@code ORDER BY}, pagination and lock
     * clauses all prevent merging. For candidate {@code SELECT} statements the decision is also recorded on the
     * statement context via {@code setNeedAggregateRewrite}.</p>
     *
     * @return {@code true} if the group should be merged
     */
    private boolean isNeedAggregateRewrite(final SQLStatementContext sqlStatementContext, final Collection<RouteUnit> routeUnits, final int maxUnionSizePerDataSource) {
        // A limit of 1 disables merging. Non-positive limits are treated the same way; previously 0 reached
        // partitionRouteUnits as a zero batch size and looped forever.
        if (!(sqlStatementContext instanceof SelectStatementContext) || 1 == routeUnits.size() || maxUnionSizePerDataSource <= 1) {
            return false;
        }
        SelectStatementContext selectStatementContext = (SelectStatementContext) sqlStatementContext;
        if (selectStatementContext.getProjectionsContext().isDistinctRow()) {
            selectStatementContext.setNeedAggregateRewrite(false);
            return false;
        }
        boolean containsSubqueryJoinQuery = selectStatementContext.isContainsSubquery() || selectStatementContext.isContainsJoinQuery();
        boolean containsOrderByLimitClause = !selectStatementContext.getOrderByContext().getItems().isEmpty() || selectStatementContext.getPaginationContext().isHasPagination();
        boolean containsLockClause = selectStatementContext.getSqlStatement().getLock().isPresent();
        boolean result = !containsSubqueryJoinQuery && !containsOrderByLimitClause && !containsLockClause;
        selectStatementContext.setNeedAggregateRewrite(result);
        return result;
    }
    
    /**
     * Merges the route units of one data source into {@code UNION ALL} SQL units of at most
     * {@code maxUnionSizePerDataSource} branches each, and stores them into {@code sqlRewriteUnits}.
     *
     * @param routeUnits route units that all share one actual data source
     * @param maxUnionSizePerDataSource maximum number of branches per merged statement (greater than 1)
     * @param sqlRewriteUnits target map receiving one entry per batch
     */
    private void createAggregatedRewriteUnits(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final List<RouteUnit> routeUnits,
                                              final int maxUnionSizePerDataSource, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
        // When the group fits into one statement this yields exactly one batch covering every route unit.
        for (List<RouteUnit> batch : partitionRouteUnits(routeUnits, maxUnionSizePerDataSource)) {
            // Every unit in the batch has the same actual data source, so any of them can key the merged unit;
            // one is chosen at random.
            RouteUnit representative = batch.get(ThreadLocalRandom.current().nextInt(batch.size()));
            sqlRewriteUnits.put(representative, createSQLRewriteUnit(sqlRewriteContext, routeContext, batch));
        }
    }
    
    /**
     * Splits route units into consecutive batches of at most {@code batchSize} elements.
     *
     * @param routeUnits route units to split
     * @param batchSize maximum batch size
     * @return ordered {@link List#subList} views over {@code routeUnits}; empty when {@code routeUnits} is empty
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    private List<List<RouteUnit>> partitionRouteUnits(final List<RouteUnit> routeUnits, final int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, but was " + batchSize);
        }
        int total = routeUnits.size();
        List<List<RouteUnit>> result = new ArrayList<>();
        int from = 0;
        while (from < total) {
            // Computed as from + min(...) so the end index never overflows, even for very large batch sizes.
            int to = from + Math.min(batchSize, total - from);
            result.add(routeUnits.subList(from, to));
            from = to;
        }
        return result;
    }
    
    /**
     * Builds one SQL unit joining the SQL of all given route units with {@code UNION ALL}.
     *
     * @param routeUnits route units to merge, in branch order
     * @return merged SQL unit with the concatenated parameters of the branches
     */
    private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final Collection<RouteUnit> routeUnits) {
        Collection<String> sql = new LinkedList<>();
        List<Object> params = new LinkedList<>();
        boolean containsDollarMarker = sqlRewriteContext.getSqlStatementContext() instanceof SelectStatementContext
                && ((SelectStatementContext) sqlRewriteContext.getSqlStatementContext()).isContainsDollarParameterMarker();
        for (RouteUnit each : routeUnits) {
            // Each branch is stripped of its trailing semicolon so the branches can be joined into one statement.
            sql.add(SQLUtils.trimSemicolon(new SQLBuilderEngine(sqlRewriteContext, each).buildSQL()));
            // NOTE(review): with dollar-style parameter markers the branches presumably reference one shared parameter
            // list, so parameters are only collected until the list is non-empty — confirm.
            if (containsDollarMarker && !params.isEmpty()) {
                continue;
            }
            params.addAll(getParameters(sqlRewriteContext, routeContext, each));
        }
        return new SQLRewriteUnit(String.join(" UNION ALL ", sql), params);
    }
    
    /**
     * Builds the SQL unit for a single, unmerged route unit.
     */
    private SQLRewriteUnit createSQLRewriteUnit(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final RouteUnit routeUnit) {
        return new SQLRewriteUnit(getActualSQL(sqlRewriteContext, routeUnit), getParameters(sqlRewriteContext, routeContext, routeUnit));
    }
    
    /**
     * Renders the SQL text of the statement for one route unit.
     */
    private String getActualSQL(final SQLRewriteContext sqlRewriteContext, final RouteUnit routeUnit) {
        return new SQLBuilderEngine(sqlRewriteContext, routeUnit).buildSQL();
    }
    
    /**
     * Resolves the parameters to bind for one route unit.
     *
     * <ul>
     *     <li>If the statement has no parameters, the result is an empty list.</li>
     *     <li>A {@link StandardParameterBuilder} contributes all of its parameters.</li>
     *     <li>A {@link GroupedParameterBuilder} contributes all of its parameters when no original data nodes are known.</li>
     *     <li>Otherwise only the groups matching the route unit are used.</li>
     * </ul>
     */
    private List<Object> getParameters(final SQLRewriteContext sqlRewriteContext, final RouteContext routeContext, final RouteUnit routeUnit) {
        if (sqlRewriteContext.getParameters().isEmpty()) {
            return Collections.emptyList();
        }
        ParameterBuilder parameterBuilder = sqlRewriteContext.getParameterBuilder();
        if (parameterBuilder instanceof StandardParameterBuilder) {
            return parameterBuilder.getParameters();
        }
        GroupedParameterBuilder groupedParameterBuilder = (GroupedParameterBuilder) parameterBuilder;
        return routeContext.getOriginalDataNodes().isEmpty()
                ? groupedParameterBuilder.getParameters()
                : buildRouteParameters(groupedParameterBuilder, routeContext, routeUnit);
    }
    
    /**
     * Assembles route-specific parameters from three sources, in this order:
     * <ol>
     *     <li>the "before generic" parameters;</li>
     *     <li>every parameter group whose original data nodes match the route unit;</li>
     *     <li>the "after generic" parameters.</li>
     * </ol>
     *
     * <p>The i-th parameter group is paired with the i-th entry of {@code routeContext.getOriginalDataNodes()}.</p>
     */
    private List<Object> buildRouteParameters(final GroupedParameterBuilder paramBuilder, final RouteContext routeContext, final RouteUnit routeUnit) {
        List<Object> result = new LinkedList<>(paramBuilder.getBeforeGenericParameterBuilder().getParameters());
        int count = 0;
        for (Collection<DataNode> each : routeContext.getOriginalDataNodes()) {
            if (isInSameDataNode(each, routeUnit)) {
                result.addAll(paramBuilder.getParameters(count));
            }
            count++;
        }
        result.addAll(paramBuilder.getAfterGenericParameterBuilder().getParameters());
        return result;
    }
    
    /**
     * Tells whether the route unit maps at least one of the given data nodes.
     *
     * @return {@code true} if {@code dataNodes} is empty or any node has a table mapper in {@code routeUnit}
     */
    private boolean isInSameDataNode(final Collection<DataNode> dataNodes, final RouteUnit routeUnit) {
        if (dataNodes.isEmpty()) {
            return true;
        }
        for (DataNode each : dataNodes) {
            if (routeUnit.findTableMapper(each.getDataSourceName(), each.getTableName()).isPresent()) {
                return true;
            }
        }
        return false;
    }
    
    /**
     * Applies the SQL translator rule to each SQL unit, using the storage type of the unit's actual data source.
     * Units for which the translator yields nothing keep their original SQL and parameters.
     *
     * @return insertion-ordered map with the same keys as {@code sqlRewriteUnits}
     */
    private Map<RouteUnit, SQLRewriteUnit> translate(final QueryContext queryContext, final Map<RouteUnit, SQLRewriteUnit> sqlRewriteUnits) {
        Map<RouteUnit, SQLRewriteUnit> result = new LinkedHashMap<>(sqlRewriteUnits.size(), 1F);
        Map<String, StorageUnit> storageUnits = database.getResourceMetaData().getStorageUnits();
        for (Map.Entry<RouteUnit, SQLRewriteUnit> entry : sqlRewriteUnits.entrySet()) {
            // NOTE(review): assumes every routed actual data source name has a storage unit; a missing one surfaces
            // as a NullPointerException here.
            DatabaseType storageType = storageUnits.get(entry.getKey().getDataSourceMapper().getActualName()).getStorageType();
            String sql = entry.getValue().getSql();
            List<Object> parameters = entry.getValue().getParameters();
            Optional<SQLTranslatorContext> sqlTranslatorContext = translatorRule.translate(sql, parameters, queryContext, storageType, database, globalRuleMetaData);
            String translatedSQL = sqlTranslatorContext.map(SQLTranslatorContext::getSql).orElse(sql);
            List<Object> translatedParameters = sqlTranslatorContext.map(SQLTranslatorContext::getParameters).orElse(parameters);
            result.put(entry.getKey(), new SQLRewriteUnit(translatedSQL, translatedParameters));
        }
        return result;
    }
    
    /**
     * Creates the engine.
     *
     * @param translatorRule SQL translator rule applied to every rewritten unit
     * @param database database whose storage units supply the target storage types
     * @param globalRuleMetaData global rule metadata forwarded to the translator
     */
    @Generated
    public RouteSQLRewriteEngine(final SQLTranslatorRule translatorRule, final ShardingSphereDatabase database, final RuleMetaData globalRuleMetaData) {
        this.translatorRule = translatorRule;
        this.database = database;
        this.globalRuleMetaData = globalRuleMetaData;
    }
}

