/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.d2.balancer.simple;

import com.linkedin.common.callback.Callback;
import com.linkedin.common.util.None;
import com.linkedin.d2.balancer.KeyMapper;
import com.linkedin.d2.balancer.LoadBalancer;
import com.linkedin.d2.balancer.LoadBalancerState;
import com.linkedin.d2.balancer.LoadBalancerStateItem;
import com.linkedin.d2.balancer.ServiceUnavailableException;
import com.linkedin.d2.balancer.clients.RewriteClient;
import com.linkedin.d2.balancer.clients.TrackerClient;
import com.linkedin.d2.balancer.properties.ClusterProperties;
import com.linkedin.d2.balancer.properties.PartitionData;
import com.linkedin.d2.balancer.properties.ServiceProperties;
import com.linkedin.d2.balancer.properties.UriProperties;
import com.linkedin.d2.balancer.strategies.LoadBalancerStrategy;
import com.linkedin.d2.balancer.util.ClientFactoryProvider;
import com.linkedin.d2.balancer.util.HostToKeyMapper;
import com.linkedin.d2.balancer.util.KeysAndHosts;
import com.linkedin.d2.balancer.util.LoadBalancerUtil;
import com.linkedin.d2.balancer.util.MapKeyResult;
import com.linkedin.d2.balancer.util.hashing.HashRingProvider;
import com.linkedin.d2.balancer.util.hashing.Ring;
import com.linkedin.d2.balancer.util.partitions.PartitionAccessException;
import com.linkedin.d2.balancer.util.partitions.PartitionAccessor;
import com.linkedin.d2.balancer.util.partitions.PartitionInfoProvider;
import com.linkedin.d2.discovery.event.PropertyEventThread;
import com.linkedin.d2.discovery.util.LogUtil;
import com.linkedin.d2.discovery.util.Stats;
import com.linkedin.r2.message.Request;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.transport.common.TransportClientFactory;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link LoadBalancer} that resolves {@code d2://} URIs to concrete hosts using the service,
 * cluster, URI and partition information held in a {@link LoadBalancerState}.
 *
 * <p>When a positive timeout is configured, the first lookup of a service or cluster waits (up to
 * that timeout) for the state to report that it has registered the requested item. With a zero
 * timeout, the listen request is issued without waiting.
 */
public class SimpleLoadBalancer
implements LoadBalancer,
HashRingProvider,
ClientFactoryProvider,
PartitionInfoProvider {
    private static final Logger _log = LoggerFactory.getLogger(SimpleLoadBalancer.class);
    private static final String D2_SCHEME_NAME = "d2";
    private final LoadBalancerState _state;
    private final Stats _serviceUnavailableStats;
    private final Stats _serviceAvailableStats;
    private final long _timeout;
    private final TimeUnit _unit;
    private final Random _random = new Random();

    /** Creates a balancer that never waits for service/cluster registration. */
    public SimpleLoadBalancer(LoadBalancerState state) {
        this(state, new Stats(1000L), new Stats(1000L), 0L, TimeUnit.SECONDS);
    }

    /** Creates a balancer that waits up to {@code timeout} milliseconds for registration. */
    public SimpleLoadBalancer(LoadBalancerState state, long timeout) {
        this(state, new Stats(1000L), new Stats(1000L), timeout, TimeUnit.MILLISECONDS);
    }

    /** Creates a balancer that waits up to {@code timeout} {@code unit}s for registration. */
    public SimpleLoadBalancer(LoadBalancerState state, long timeout, TimeUnit unit) {
        this(state, new Stats(1000L), new Stats(1000L), timeout, unit);
    }

    /** Creates a balancer with caller-supplied stats that never waits for registration. */
    public SimpleLoadBalancer(LoadBalancerState state, Stats serviceAvailableStats, Stats serviceUnavailableStats) {
        this(state, serviceAvailableStats, serviceUnavailableStats, 0L, TimeUnit.SECONDS);
    }

    /**
     * Creates a balancer.
     *
     * @param state the load balancer state to read services, clusters, URIs and clients from
     * @param serviceAvailableStats incremented each time a tracker client is successfully chosen
     * @param serviceUnavailableStats incremented each time a lookup fails via {@link #die}
     * @param timeout how long to wait for service/cluster registration; {@code <= 0} means do not wait
     * @param unit unit of {@code timeout}
     */
    public SimpleLoadBalancer(LoadBalancerState state, Stats serviceAvailableStats, Stats serviceUnavailableStats, long timeout, TimeUnit unit) {
        this._state = state;
        this._serviceUnavailableStats = serviceUnavailableStats;
        this._serviceAvailableStats = serviceAvailableStats;
        this._timeout = timeout;
        this._unit = unit;
    }

    public Stats getServiceUnavailableStats() {
        return this._serviceUnavailableStats;
    }

    public Stats getServiceAvailableStats() {
        return this._serviceAvailableStats;
    }

    @Override
    public void start(Callback<None> callback) {
        this._state.start(callback);
    }

    @Override
    public void shutdown(PropertyEventThread.PropertyEventShutdownCallback shutdown) {
        this._state.shutdown(shutdown);
    }

    /**
     * Returns a client for the d2 request URI.
     *
     * <p>If the request context carries a target-service hint, a generic client for the hinted
     * URI's scheme is returned without consulting the load balancing strategies. Otherwise, a
     * tracker client is chosen from the service's strategies. It is wrapped so that requests are
     * rewritten to the chosen host's URI plus the service path.
     *
     * @throws ServiceUnavailableException if the service, cluster, URIs or a usable host cannot be found
     * @throws IllegalArgumentException if the request URI does not use the {@code d2} scheme
     */
    @Override
    public TransportClient getClient(Request request, RequestContext requestContext) throws ServiceUnavailableException {
        URI uri = request.getURI();
        LogUtil.debug(_log, "get client for uri: ", uri);
        ServiceProperties service = this.listenToServiceAndCluster(uri);
        String serviceName = service.getServiceName();
        String clusterName = service.getClusterName();
        ClusterProperties cluster = this.getClusterProperties(serviceName, clusterName);
        URI targetService = LoadBalancerUtil.TargetHints.getRequestContextTargetService(requestContext);
        if (targetService != null) {
            _log.debug("service hint found, using generic client for target: {}", targetService);
            TransportClient transportClient = this._state.getClient(serviceName, targetService.getScheme());
            return new RewriteClient(serviceName, targetService, transportClient);
        }
        LoadBalancerStateItem<UriProperties> uriItem = this.getUriItem(serviceName, clusterName, cluster);
        UriProperties uris = uriItem.getProperty();
        List<LoadBalancerState.SchemeStrategyPair> orderedStrategies = this._state.getStrategiesForService(serviceName, service.getPrioritizedSchemes());
        TrackerClient trackerClient = this.chooseTrackerClient(request, requestContext, serviceName, clusterName, cluster, uriItem, uris, orderedStrategies, service);
        String clusterAndServiceUriString = trackerClient.getUri() + service.getPath();
        RewriteClient client = new RewriteClient(serviceName, URI.create(clusterAndServiceUriString), trackerClient);
        this._serviceAvailableStats.inc();
        return client;
    }

    /**
     * Groups {@code keys} by partition and returns, for each non-empty partition, the hash ring
     * built by the service's first (highest priority) strategy.
     *
     * <p>Keys whose partition cannot be determined are reported as unmapped with
     * {@link MapKeyResult.ErrorType#FAIL_TO_FIND_PARTITION}.
     *
     * @throws ServiceUnavailableException if the service has no load balancer strategy, or lookup fails
     */
    @Override
    public <K> MapKeyResult<Ring<URI>, K> getRings(URI serviceUri, Iterable<K> keys) throws ServiceUnavailableException {
        ServiceProperties service = this.listenToServiceAndCluster(serviceUri);
        String serviceName = service.getServiceName();
        String clusterName = service.getClusterName();
        ClusterProperties cluster = this.getClusterProperties(serviceName, clusterName);
        LoadBalancerStateItem<UriProperties> uriItem = this.getUriItem(serviceName, clusterName, cluster);
        UriProperties uris = uriItem.getProperty();
        List<LoadBalancerState.SchemeStrategyPair> orderedStrategies = this._state.getStrategiesForService(serviceName, service.getPrioritizedSchemes());
        if (orderedStrategies.isEmpty()) {
            throw new ServiceUnavailableException(serviceName, "Unable to find a load balancer strategy");
        }
        LoadBalancerState.SchemeStrategyPair pair = orderedStrategies.get(0);
        PartitionAccessor accessor = this.getPartitionAccessor(serviceName, clusterName);

        // partitionId -> keys that map to it
        Map<Integer, Set<K>> partitionSet = new HashMap<Integer, Set<K>>();
        List<MapKeyResult.UnmappedKey<K>> unmappedKeys = new ArrayList<MapKeyResult.UnmappedKey<K>>();
        for (K key : keys) {
            int partitionId;
            try {
                partitionId = accessor.getPartitionId(key.toString());
            }
            catch (PartitionAccessException e) {
                unmappedKeys.add(new MapKeyResult.UnmappedKey<K>(key, MapKeyResult.ErrorType.FAIL_TO_FIND_PARTITION));
                continue;
            }
            Set<K> set = partitionSet.get(partitionId);
            if (set == null) {
                set = new HashSet<K>();
                partitionSet.put(partitionId, set);
            }
            set.add(key);
        }

        // Identity semantics: each partition is expected to produce a distinct Ring instance.
        Map<Ring<URI>, Collection<K>> ringMap = new IdentityHashMap<Ring<URI>, Collection<K>>(partitionSet.size() * 2);
        for (Map.Entry<Integer, Set<K>> entry : partitionSet.entrySet()) {
            int partitionId = entry.getKey();
            List<TrackerClient> clients = this.getPotentialClients(serviceName, service, uris, pair.getScheme(), partitionId);
            Ring<URI> ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, clients);
            Collection<K> oldValue = ringMap.put(ring, entry.getValue());
            assert oldValue == null;
        }
        return new MapKeyResult<Ring<URI>, K>(ringMap, unmappedKeys);
    }

    /**
     * Delegates to the underlying state.
     *
     * @throws ClassCastException if the configured {@link LoadBalancerState} is not a {@link ClientFactoryProvider}
     */
    @Override
    public TransportClientFactory getClientFactory(String scheme) {
        return ((ClientFactoryProvider) this._state).getClientFactory(scheme);
    }

    /**
     * Returns the ring for every partition {@code 0..maxPartitionId}, built by the service's first
     * (highest priority) strategy.
     *
     * @throws ServiceUnavailableException if the service has no load balancer strategy, or lookup fails
     */
    @Override
    public Map<Integer, Ring<URI>> getRings(URI serviceUri) throws ServiceUnavailableException {
        ServiceProperties service = this.listenToServiceAndCluster(serviceUri);
        String serviceName = service.getServiceName();
        String clusterName = service.getClusterName();
        ClusterProperties cluster = this.getClusterProperties(serviceName, clusterName);
        LoadBalancerStateItem<UriProperties> uriItem = this.getUriItem(serviceName, clusterName, cluster);
        UriProperties uris = uriItem.getProperty();
        List<LoadBalancerState.SchemeStrategyPair> orderedStrategies = this._state.getStrategiesForService(serviceName, service.getPrioritizedSchemes());
        if (orderedStrategies.isEmpty()) {
            throw new ServiceUnavailableException(serviceName, "Unable to find a load balancer strategy");
        }
        LoadBalancerState.SchemeStrategyPair pair = orderedStrategies.get(0);
        PartitionAccessor accessor = this.getPartitionAccessor(serviceName, clusterName);
        int maxPartitionId = accessor.getMaxPartitionId();
        Map<Integer, Ring<URI>> ringMap = new HashMap<Integer, Ring<URI>>((maxPartitionId + 1) * 2);
        for (int partitionId = 0; partitionId <= maxPartitionId; ++partitionId) {
            Set<URI> possibleUris = uris.getUriBySchemeAndPartition(pair.getScheme(), partitionId);
            List<TrackerClient> trackerClients = this.getPotentialClients(serviceName, service, possibleUris);
            Ring<URI> ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, trackerClients);
            ringMap.put(partitionId, ring);
        }
        return ringMap;
    }

    /**
     * Asks the state to listen to {@code serviceName}. With a positive timeout, blocks until the
     * state signals completion or the timeout elapses (a timeout only logs a warning).
     */
    private void listenToService(String serviceName) throws ServiceUnavailableException {
        if (this._timeout <= 0L) {
            this._state.listenToService(serviceName, new LoadBalancerState.NullStateListenerCallback());
            _log.info("No timeout for service {}", serviceName);
            return;
        }
        CountDownLatch latch = new CountDownLatch(1);
        this._state.listenToService(serviceName, new SimpleLoadBalancerCountDownCallback(latch));
        try {
            if (!latch.await(this._timeout, this._unit)) {
                LogUtil.warn(_log, "timed out during wait while trying to add service: ", serviceName);
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            _log.error("got interrupt while waiting for a service to be registered", e);
            this.die(serviceName, "got interrupt while waiting for a service to be registered");
        }
    }

    /**
     * Asks the state to listen to {@code clusterName}. With a positive timeout, blocks until the
     * state signals completion or the timeout elapses (a timeout only logs a warning).
     */
    private void listenToCluster(String serviceName, String clusterName) throws ServiceUnavailableException {
        if (this._timeout <= 0L) {
            this._state.listenToCluster(clusterName, new LoadBalancerState.NullStateListenerCallback());
            return;
        }
        CountDownLatch latch = new CountDownLatch(1);
        this._state.listenToCluster(clusterName, new SimpleLoadBalancerCountDownCallback(latch));
        try {
            if (!latch.await(this._timeout, this._unit)) {
                LogUtil.warn(_log, "timed out during wait while trying to add cluster: ", clusterName);
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            String message = "got interrupt while waiting for a cluster to be registered: " + clusterName;
            _log.error(message, e);
            this.die(serviceName, message);
        }
    }

    /**
     * Validates the d2 scheme, resolves the service named by {@code uri}, and starts listening to
     * its cluster.
     *
     * @throws IllegalArgumentException if {@code uri} is not a {@code d2} URI (case-insensitive)
     */
    private ServiceProperties listenToServiceAndCluster(URI uri) throws ServiceUnavailableException {
        if (!D2_SCHEME_NAME.equalsIgnoreCase(uri.getScheme())) {
            throw new IllegalArgumentException("Unsupported scheme in URI " + uri);
        }
        String serviceName = LoadBalancerUtil.getServiceNameFromUri(uri);
        ServiceProperties service = this.getLoadBalancedServiceProperties(serviceName);
        String clusterName = service.getClusterName();
        this.listenToCluster(serviceName, clusterName);
        return service;
    }

    /** Returns the URI item for {@code clusterName}; fails via {@link #die} if it or its property is missing. */
    private LoadBalancerStateItem<UriProperties> getUriItem(String serviceName, String clusterName, ClusterProperties cluster) throws ServiceUnavailableException {
        LoadBalancerStateItem<UriProperties> uriItem = this._state.getUriProperties(clusterName);
        if (uriItem == null || uriItem.getProperty() == null) {
            LogUtil.warn(_log, "unable to find uris: ", clusterName);
            this.die(serviceName, "no uri properties in lb state");
        }
        // Log the URIs that were actually fetched (previously this logged only the cluster properties).
        LogUtil.debug(_log, "got uris: ", uriItem, " for cluster: ", cluster);
        return uriItem;
    }

    /** Returns the cluster properties for {@code clusterName}; fails via {@link #die} if missing. */
    private ClusterProperties getClusterProperties(String serviceName, String clusterName) throws ServiceUnavailableException {
        LoadBalancerStateItem<ClusterProperties> clusterItem = this._state.getClusterProperties(clusterName);
        if (clusterItem == null || clusterItem.getProperty() == null) {
            LogUtil.warn(_log, "unable to find cluster: ", clusterName);
            this.die(serviceName, "no cluster properties in lb state");
        }
        return clusterItem.getProperty();
    }

    /**
     * For each partition touched by {@code keys}, returns up to {@code limitHostPerPartition}
     * distinct host URIs, in the order produced by the partition ring's iterator for {@code hash}.
     * Partitions that yield fewer hosts than requested are recorded with the shortfall.
     *
     * @param keys keys to map; {@code null} means "every partition, with no keys"
     * @param limitHostPerPartition maximum hosts per partition; must be positive
     * @param hash passed to {@code Ring.getIterator} to choose the iteration order
     * @throws IllegalArgumentException if {@code limitHostPerPartition <= 0}
     * @throws ServiceUnavailableException if the service has no load balancer strategy, or lookup fails
     */
    @Override
    public <K> HostToKeyMapper<K> getPartitionInformation(URI serviceUri, Collection<K> keys, int limitHostPerPartition, int hash) throws ServiceUnavailableException {
        if (limitHostPerPartition <= 0) {
            throw new IllegalArgumentException("limitHostPartition cannot be 0 or less");
        }
        ServiceProperties service = this.listenToServiceAndCluster(serviceUri);
        String serviceName = service.getServiceName();
        String clusterName = service.getClusterName();
        ClusterProperties cluster = this.getClusterProperties(serviceName, clusterName);
        LoadBalancerStateItem<UriProperties> uriItem = this.getUriItem(serviceName, clusterName, cluster);
        UriProperties uris = uriItem.getProperty();
        List<LoadBalancerState.SchemeStrategyPair> orderedStrategies = this._state.getStrategiesForService(serviceName, service.getPrioritizedSchemes());
        if (orderedStrategies.isEmpty()) {
            throw new ServiceUnavailableException(serviceName, "Unable to find a load balancer strategy");
        }
        PartitionAccessor accessor = this.getPartitionAccessor(serviceName, clusterName);
        int maxPartitionId = accessor.getMaxPartitionId();
        List<K> unmappedKeys = new ArrayList<K>();
        Map<Integer, Set<K>> partitionSet = this.getPartitionSet(keys, accessor, unmappedKeys);
        LoadBalancerState.SchemeStrategyPair pair = orderedStrategies.get(0);
        Map<Integer, Integer> partitionWithoutEnoughHost = new HashMap<Integer, Integer>();
        Map<Integer, KeysAndHosts<K>> partitionDataMap = new HashMap<Integer, KeysAndHosts<K>>();
        for (Map.Entry<Integer, Set<K>> entry : partitionSet.entrySet()) {
            int partitionId = entry.getKey();
            Set<URI> possibleUris = uris.getUriBySchemeAndPartition(pair.getScheme(), partitionId);
            List<TrackerClient> trackerClients = this.getPotentialClients(serviceName, service, possibleUris);
            int size = Math.min(trackerClients.size(), limitHostPerPartition);
            List<URI> rankedUri = new ArrayList<URI>(size);
            Ring<URI> ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, trackerClients);
            Iterator<URI> iterator = ring.getIterator(hash);
            // The ring may return the same host repeatedly; keep only distinct hosts.
            while (iterator.hasNext() && rankedUri.size() < size) {
                URI uri = iterator.next();
                if (!rankedUri.contains(uri)) {
                    rankedUri.add(uri);
                }
            }
            if (rankedUri.size() < limitHostPerPartition) {
                partitionWithoutEnoughHost.put(partitionId, limitHostPerPartition - rankedUri.size());
            }
            partitionDataMap.put(partitionId, new KeysAndHosts<K>(entry.getValue(), rankedUri));
        }
        return new HostToKeyMapper<K>(unmappedKeys, partitionDataMap, limitHostPerPartition, maxPartitionId + 1, partitionWithoutEnoughHost);
    }

    /**
     * Groups {@code keys} by partition id (ascending, via a {@link TreeMap}). Keys whose partition
     * cannot be determined are appended to {@code unmappedKeys}. When {@code keys} is {@code null},
     * every partition {@code 0..maxPartitionId} is present with an empty key set.
     */
    private <K> Map<Integer, Set<K>> getPartitionSet(Collection<K> keys, PartitionAccessor accessor, Collection<K> unmappedKeys) {
        Map<Integer, Set<K>> partitionSet = new TreeMap<Integer, Set<K>>();
        if (keys == null) {
            for (int i = 0; i <= accessor.getMaxPartitionId(); ++i) {
                partitionSet.put(i, new HashSet<K>());
            }
            return partitionSet;
        }
        for (K key : keys) {
            int partitionId;
            try {
                partitionId = accessor.getPartitionId(key.toString());
            }
            catch (PartitionAccessException e) {
                unmappedKeys.add(key);
                continue;
            }
            Set<K> set = partitionSet.get(partitionId);
            if (set == null) {
                set = new HashSet<K>();
                partitionSet.put(partitionId, set);
            }
            set.add(key);
        }
        return partitionSet;
    }

    @Override
    public PartitionAccessor getPartitionAccessor(URI serviceUri) throws ServiceUnavailableException {
        ServiceProperties service = this.listenToServiceAndCluster(serviceUri);
        String serviceName = service.getServiceName();
        String clusterName = service.getClusterName();
        return this.getPartitionAccessor(serviceName, clusterName);
    }

    /** Returns the partition accessor for {@code clusterName}; fails via {@link #die} if missing. */
    private PartitionAccessor getPartitionAccessor(String serviceName, String clusterName) throws ServiceUnavailableException {
        LoadBalancerStateItem<PartitionAccessor> partitionAccessorItem = this._state.getPartitionAccessor(clusterName);
        if (partitionAccessorItem == null || partitionAccessorItem.getProperty() == null) {
            LogUtil.warn(_log, "unable to find partition accessor for cluster: ", clusterName);
            this.die(serviceName, "No partition accessor available for cluster: " + clusterName);
        }
        return partitionAccessorItem.getProperty();
    }

    /**
     * Starts listening to {@code serviceName} (waiting if a timeout is configured) and returns its
     * properties.
     *
     * @throws ServiceUnavailableException if the state has no properties for the service
     */
    @Override
    public ServiceProperties getLoadBalancedServiceProperties(String serviceName) throws ServiceUnavailableException {
        this.listenToService(serviceName);
        LoadBalancerStateItem<ServiceProperties> serviceItem = this._state.getServiceProperties(serviceName);
        if (serviceItem == null || serviceItem.getProperty() == null) {
            LogUtil.warn(_log, "unable to find service: ", serviceName);
            this.die(serviceName, "no service properties in lb state");
        }
        LogUtil.debug(_log, "got service: ", serviceItem);
        return serviceItem.getProperty();
    }

    /** Returns non-banned tracker clients for one scheme/partition, logging at info when there are none. */
    private List<TrackerClient> getPotentialClients(String serviceName, ServiceProperties serviceProperties, UriProperties uris, String scheme, int partitionId) {
        Set<URI> possibleUris = uris.getUriBySchemeAndPartition(scheme, partitionId);
        List<TrackerClient> clientsToBalance = this.getPotentialClients(serviceName, serviceProperties, possibleUris);
        if (clientsToBalance.isEmpty()) {
            LogUtil.info(_log, "Can not find a host for service: ", serviceName, ", scheme: ", scheme, ", partition: ", partitionId);
        }
        return clientsToBalance;
    }

    /**
     * Maps each URI in {@code possibleUris} (may be {@code null}) to its tracker client, skipping
     * banned URIs and URIs for which the state has no client.
     */
    private List<TrackerClient> getPotentialClients(String serviceName, ServiceProperties serviceProperties, Set<URI> possibleUris) {
        List<TrackerClient> clientsToLoadBalance = new ArrayList<TrackerClient>();
        if (possibleUris != null) {
            for (URI possibleUri : possibleUris) {
                if (serviceProperties.isBanned(possibleUri)) {
                    LogUtil.warn(_log, "skipping banned uri: ", possibleUri);
                    continue;
                }
                TrackerClient possibleTrackerClient = this._state.getClient(serviceName, possibleUri);
                if (possibleTrackerClient != null) {
                    clientsToLoadBalance.add(possibleTrackerClient);
                }
            }
        }
        LogUtil.debug(_log, "got clients to load balancer for ", serviceName, ": ", clientsToLoadBalance);
        return clientsToLoadBalance;
    }

    /**
     * Picks a tracker client for the request.
     *
     * <p>The partition comes from the request URI. If a target-host hint is present, it is instead
     * chosen at random from the partitions that host serves. Strategies are then tried in order
     * until one returns a non-null client.
     *
     * @throws ServiceUnavailableException if no partition can be determined or no strategy yields a client
     */
    private TrackerClient chooseTrackerClient(Request request, RequestContext requestContext, String serviceName, String clusterName, ClusterProperties cluster, LoadBalancerStateItem<UriProperties> uriItem, UriProperties uris, List<LoadBalancerState.SchemeStrategyPair> orderedStrategies, ServiceProperties serviceProperties) throws ServiceUnavailableException {
        URI targetHost = KeyMapper.TargetHostHints.getRequestContextTargetHost(requestContext);
        int partitionId = -1;
        URI requestUri = request.getURI();
        if (targetHost == null) {
            PartitionAccessor accessor = this.getPartitionAccessor(serviceName, clusterName);
            try {
                partitionId = accessor.getPartitionId(requestUri);
            }
            catch (PartitionAccessException e) {
                this.die(serviceName, "Error in finding the partition for URI: " + requestUri + ", " + e.getMessage());
            }
        } else {
            Map<Integer, PartitionData> partitionDataMap = uris.getPartitionDataMap(targetHost);
            if (partitionDataMap == null || partitionDataMap.isEmpty()) {
                this.die(serviceName, "There is no partition data for server host: " + targetHost + ". URI: " + requestUri);
            }
            // Uniformly pick one of the partitions served by the hinted host.
            Set<Integer> partitions = partitionDataMap.keySet();
            Iterator<Integer> iterator = partitions.iterator();
            int index = this._random.nextInt(partitions.size());
            for (int i = 0; i <= index; ++i) {
                partitionId = iterator.next();
            }
        }
        TrackerClient trackerClient = null;
        List<TrackerClient> clientsToLoadBalance = null;
        for (LoadBalancerState.SchemeStrategyPair pair : orderedStrategies) {
            LoadBalancerStrategy strategy = pair.getStrategy();
            String scheme = pair.getScheme();
            clientsToLoadBalance = this.getPotentialClients(serviceName, serviceProperties, uris, scheme, partitionId);
            trackerClient = strategy.getTrackerClient(request, requestContext, uriItem.getVersion(), partitionId, clientsToLoadBalance);
            LogUtil.debug(_log, "load balancer strategy for ", serviceName, " returned: ", trackerClient);
            if (trackerClient != null) {
                break;
            }
        }
        if (trackerClient == null) {
            // Distinguish "no hosts at all" from "hosts exist but the strategy refused all of them".
            if (clientsToLoadBalance == null || clientsToLoadBalance.isEmpty()) {
                this.die(serviceName, "Service: " + serviceName + " unable to find a host to route the request in partition: " + partitionId + " cluster: " + clusterName + ". Check what cluster your servers are announcing to.");
            } else {
                this.die(serviceName, "Service: " + serviceName + " is in a bad state (high latency/high error). Dropping request. Cluster: " + clusterName + ", partitionId:" + partitionId + " (" + clientsToLoadBalance.size() + " hosts)");
            }
        }
        return trackerClient;
    }

    /** Records a service-unavailable event and always throws. */
    private void die(String serviceName, String message) throws ServiceUnavailableException {
        this._serviceUnavailableStats.inc();
        throw new ServiceUnavailableException(serviceName, message);
    }

    /** State listener callback that counts down a latch when the state reports completion. */
    public static class SimpleLoadBalancerCountDownCallback
    implements LoadBalancerState.LoadBalancerStateListenerCallback {
        private final CountDownLatch _latch;

        public SimpleLoadBalancerCountDownCallback(CountDownLatch latch) {
            this._latch = latch;
        }

        @Override
        public void done(int type, String name) {
            this._latch.countDown();
        }
    }
}

