package org.apache.hadoop.hive.ql.exec.tez;

import io.trino.hive.$internal.com.google.common.collect.ArrayListMultimap;
import io.trino.hive.$internal.com.google.common.collect.Lists;
import io.trino.hive.$internal.com.google.common.collect.Multimap;
import io.trino.hive.$internal.org.slf4j.Logger;
import io.trino.hive.$internal.org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.io.HiveFileFormatUtils;
import org.apache.hadoop.hive.ql.io.HiveInputFormat;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.split.SplitLocationProvider;
import org.apache.hadoop.mapred.split.TezGroupedSplit;
import org.apache.hadoop.mapred.split.TezMapredSplitsGrouper;
import org.apache.tez.dag.api.TaskLocationHint;

/* loaded from: input_file:org/apache/hadoop/hive/ql/exec/tez/SplitGrouper.class */
/**
 * Partitions input splits into source groups ("buckets"), estimates how many tasks each
 * bucket should receive, and delegates the actual grouping to {@link TezMapredSplitsGrouper}.
 * Also derives {@link TaskLocationHint}s from (grouped) splits.
 */
public class SplitGrouper {
    private static final Logger LOG = LoggerFactory.getLogger(SplitGrouper.class);

    /** Lookup cache handed to {@link HiveFileFormatUtils#getFromPathRecursively}; shared by all instances. */
    private static final Map<Map<Path, PartitionDesc>, Map<Path, PartitionDesc>> cache = new ConcurrentHashMap<>();

    private final TezMapredSplitsGrouper tezGrouper = new TezMapredSplitsGrouper();

    /**
     * Groups the splits of every bucket independently with the Tez grouper.
     *
     * @param configuration configuration passed through to the Tez grouper
     * @param multimap bucket id to the splits belonging to that bucket
     * @param i multiplied with {@code f} to obtain the task budget that is distributed over the buckets
     * @param f multiplier applied to {@code i}
     * @param splitLocationProvider location provider passed through to the Tez grouper
     * @return bucket id to the grouped splits produced for that bucket
     * @throws IOException if the Tez grouper fails
     */
    public Multimap<Integer, InputSplit> group(Configuration configuration, Multimap<Integer, InputSplit> multimap, int i, float f, SplitLocationProvider splitLocationProvider) throws IOException {
        Map<Integer, Integer> tasksPerBucket = estimateBucketSizes(i, f, multimap.asMap());
        Multimap<Integer, InputSplit> groupedByBucket = ArrayListMultimap.create();
        for (Integer bucketId : multimap.keySet()) {
            InputSplit[] rawSplits = multimap.get(bucketId).toArray(new InputSplit[0]);
            InputSplit[] groupedSplits = this.tezGrouper.getGroupedSplits(
                    configuration,
                    rawSplits,
                    tasksPerBucket.get(bucketId),
                    HiveInputFormat.class.getName(),
                    new ColumnarSplitSizeEstimator(),
                    splitLocationProvider);
            LOG.info("Original split count is {} grouped split count is {}, for bucket: {}", rawSplits.length, groupedSplits.length, bucketId);
            groupedByBucket.putAll(bucketId, Arrays.asList(groupedSplits));
        }
        return groupedByBucket;
    }

    /**
     * Builds one location hint per split, in the order of {@code inputSplitArr}.
     *
     * <ul>
     *   <li>A {@link TezGroupedSplit} with a non-null rack yields a rack-only hint.</li>
     *   <li>A split without locations yields a hint with neither hosts nor racks.</li>
     *   <li>If {@code z} is set and the split is a {@link FileSplit} with more than one location,
     *       the locations are sorted and rotated by a hash of the file path and start offset, so
     *       the resulting host order depends only on the split identity and its set of hosts.</li>
     *   <li>Otherwise the hosts are used in the order reported by the split.</li>
     * </ul>
     *
     * @param inputSplitArr splits to create hints for
     * @param z whether to produce the hash-rotated host ordering described above
     * @return list with exactly one hint per input split
     * @throws IOException if a split fails to report its locations
     */
    public List<TaskLocationHint> createTaskLocationHints(InputSplit[] inputSplitArr, boolean z) throws IOException {
        List<TaskLocationHint> hints = Lists.newArrayListWithCapacity(inputSplitArr.length);
        for (InputSplit split : inputSplitArr) {
            String rack = split instanceof TezGroupedSplit ? ((TezGroupedSplit) split).getRack() : null;
            if (rack != null) {
                hints.add(TaskLocationHint.createTaskLocationHint((Set<String>) null, Collections.singleton(rack)));
                continue;
            }

            String[] locations = split.getLocations();
            if (locations == null || locations.length == 0) {
                hints.add(TaskLocationHint.createTaskLocationHint((Set<String>) null, (Set<String>) null));
            } else if (z && locations.length > 1 && split instanceof FileSplit) {
                Arrays.sort(locations);
                FileSplit fileSplit = (FileSplit) split;
                // Objects.hash may be negative and Java's % keeps the sign of the dividend;
                // floorMod guarantees an index in [0, locations.length).
                int startIndex = Math.floorMod(Objects.hash(fileSplit.getPath(), fileSplit.getStart()), locations.length);
                Set<String> rotatedHosts = new LinkedHashSet<>(locations.length);
                for (int offset = 0; offset < locations.length; offset++) {
                    rotatedHosts.add(locations[(startIndex + offset) % locations.length]);
                }
                hints.add(TaskLocationHint.createTaskLocationHint(rotatedHosts, (Set<String>) null));
            } else {
                Set<String> hosts = new LinkedHashSet<>(Arrays.asList(locations));
                hints.add(TaskLocationHint.createTaskLocationHint(hosts, (Set<String>) null));
            }
        }
        return hints;
    }

    /**
     * Convenience overload of
     * {@link #generateGroupedSplits(JobConf, Configuration, InputSplit[], float, int, String, boolean, SplitLocationProvider)}
     * that passes {@code null} for {@code str} and {@code true} for {@code z}.
     */
    public Multimap<Integer, InputSplit> generateGroupedSplits(JobConf jobConf, Configuration configuration, InputSplit[] inputSplitArr, float f, int i, SplitLocationProvider splitLocationProvider) throws Exception {
        return generateGroupedSplits(jobConf, configuration, inputSplitArr, f, i, null, true, splitLocationProvider);
    }

    /**
     * Assigns consecutive splits to source groups and then groups each source group via
     * {@link #group(Configuration, Multimap, int, float, SplitLocationProvider)}.
     *
     * <p>A new source group is opened whenever {@code schemaEvolved} reports a change relative to
     * the split that opened the current group. All splits are expected to be {@link FileSplit}s;
     * anything else fails with a {@link ClassCastException}.
     *
     * @param jobConf job configuration; used to look up the map work and passed to {@code group}
     * @param configuration not read by this method
     * @param inputSplitArr splits to partition, in the order they should be compared
     * @param f multiplier forwarded to {@code group}
     * @param i task budget base forwarded to {@code group}
     * @param str if non-null, a merge work with this name is looked up before falling back to the map work
     * @param z if {@code true}, groups are split on input format / deserializer changes;
     *          if {@code false}, on file path changes
     * @param splitLocationProvider location provider forwarded to {@code group}
     * @return source group id to the grouped splits of that group
     * @throws Exception if partition lookup or grouping fails
     */
    public Multimap<Integer, InputSplit> generateGroupedSplits(JobConf jobConf, Configuration configuration, InputSplit[] inputSplitArr, float f, int i, String str, boolean z, SplitLocationProvider splitLocationProvider) throws Exception {
        MapWork work = populateMapWork(jobConf, str);
        Multimap<Integer, InputSplit> splitsBySourceGroup = ArrayListMultimap.create();
        int groupId = 0;
        InputSplit groupLeader = null;
        for (InputSplit split : inputSplitArr) {
            if (schemaEvolved(split, groupLeader, z, work)) {
                groupId++;
                groupLeader = split;
            }
            splitsBySourceGroup.put(groupId, split);
        }
        // NOTE(review): the first split always opens group 1, so the number of groups actually
        // used is groupId; the logged value is kept as-is to preserve the existing log output.
        LOG.info("# Src groups for split generation: {}", groupId + 1);
        return group(jobConf, splitsBySourceGroup, i, f, splitLocationProvider);
    }

    /**
     * Decides how many tasks each bucket should be grouped into.
     *
     * <p>If every split is a {@link FileSplit}, the budget {@code i * f} is distributed
     * proportionally to the summed split lengths of each bucket, with a minimum of one task per
     * bucket. If any split is not a {@link FileSplit}, sizes cannot be compared and every bucket
     * is assigned {@code (int) (i * f)}.
     *
     * @param i task budget base
     * @param f multiplier applied to {@code i}
     * @param map bucket id to the splits of that bucket
     * @return bucket id to desired task count; contains an entry for every key of {@code map}
     */
    private Map<Integer, Integer> estimateBucketSizes(int i, float f, Map<Integer, Collection<InputSplit>> map) {
        Map<Integer, Long> bytesPerBucket = new HashMap<>();
        Map<Integer, Integer> tasksPerBucket = new HashMap<>();
        long totalBytes = 0;
        boolean sawNonFileSplit = false;

        for (Map.Entry<Integer, Collection<InputSplit>> bucket : map.entrySet()) {
            long bucketBytes = 0;
            for (InputSplit split : bucket.getValue()) {
                if (!(split instanceof FileSplit)) {
                    sawNonFileSplit = true;
                    continue;
                }
                long length = ((FileSplit) split).getLength();
                bucketBytes += length;
                totalBytes += length;
            }
            bytesPerBucket.put(bucket.getKey(), bucketBytes);
        }

        if (sawNonFileSplit) {
            // Every bucket must get an entry: the caller unboxes the value for each bucket, so a
            // bucket consisting solely of FileSplits must not be left out of the result.
            int uniformTasks = (int) (i * f);
            for (Integer bucketId : map.keySet()) {
                tasksPerBucket.put(bucketId, uniformTasks);
            }
            return tasksPerBucket;
        }

        for (Map.Entry<Integer, Long> bucket : bytesPerBucket.entrySet()) {
            int tasks = 0;
            if (totalBytes != 0) {
                tasks = (int) (((i * f) * ((float) bucket.getValue().longValue())) / ((float) totalBytes));
            }
            LOG.info("Estimated number of tasks: {} for bucket {}", tasks, bucket.getKey());
            // Never schedule zero tasks for a bucket that exists.
            if (tasks == 0) {
                tasks = 1;
            }
            tasksPerBucket.put(bucket.getKey(), tasks);
        }
        return tasksPerBucket;
    }

    /**
     * Resolves the work description for the job: the merge work named {@code str} when
     * {@code str} is non-null and such a work is found, otherwise the job's map work.
     */
    private static MapWork populateMapWork(JobConf jobConf, String str) {
        MapWork work = null;
        if (str != null) {
            work = (MapWork) Utilities.getMergeWork(jobConf, str);
        }
        if (work == null) {
            work = Utilities.getMapWork(jobConf);
        }
        return work;
    }

    /**
     * Tells whether {@code inputSplit} must start a new source group relative to {@code inputSplit2}.
     *
     * <p>With no previous split the result is always {@code true}. Otherwise, when {@code z} is
     * {@code false} the decision is based solely on the file paths being different; when
     * {@code z} is {@code true} it is based on the input format class or the deserializer class
     * name of the two partitions being different.
     *
     * @param inputSplit current split; must be a {@link FileSplit}
     * @param inputSplit2 split that opened the current group, or {@code null} if there is none;
     *                    must be a {@link FileSplit} when non-null
     * @param z selects schema-based ({@code true}) or path-based ({@code false}) comparison
     * @param mapWork supplies the path-to-partition mapping
     * @throws IOException if the partition lookup fails
     */
    private boolean schemaEvolved(InputSplit inputSplit, InputSplit inputSplit2, boolean z, MapWork mapWork) throws IOException {
        Path currentPath = ((FileSplit) inputSplit).getPath();
        PartitionDesc currentPartition = HiveFileFormatUtils.getFromPathRecursively(mapWork.getPathToPartitionInfo(), currentPath, cache);
        String currentDeserializer = currentPartition.getDeserializerClassName();
        Class<?> currentInputFormat = currentPartition.getInputFileFormatClass();

        Class<?> previousInputFormat = null;
        String previousDeserializer = null;
        if (inputSplit2 != null) {
            Path previousPath = ((FileSplit) inputSplit2).getPath();
            if (!z) {
                return !currentPath.equals(previousPath);
            }
            PartitionDesc previousPartition = HiveFileFormatUtils.getFromPathRecursively(mapWork.getPathToPartitionInfo(), previousPath, cache);
            previousDeserializer = previousPartition.getDeserializerClassName();
            previousInputFormat = previousPartition.getInputFileFormatClass();
        }

        boolean startsNewGroup = currentInputFormat != previousInputFormat || !currentDeserializer.equals(previousDeserializer);
        LOG.debug("Adding split {} to src new group? {}", currentPath, startsNewGroup);
        return startsNewGroup;
    }
}
