package io.prestosql.orc;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import io.prestosql.memory.context.AggregatedMemoryContext;
import io.prestosql.orc.checkpoint.Checkpoints;
import io.prestosql.orc.checkpoint.InvalidCheckpointException;
import io.prestosql.orc.checkpoint.StreamCheckpoint;
import io.prestosql.orc.metadata.ColumnEncoding;
import io.prestosql.orc.metadata.ColumnMetadata;
import io.prestosql.orc.metadata.MetadataReader;
import io.prestosql.orc.metadata.OrcColumnId;
import io.prestosql.orc.metadata.OrcType;
import io.prestosql.orc.metadata.PostScript;
import io.prestosql.orc.metadata.RowGroupIndex;
import io.prestosql.orc.metadata.Stream;
import io.prestosql.orc.metadata.StripeFooter;
import io.prestosql.orc.metadata.StripeInformation;
import io.prestosql.orc.metadata.statistics.BloomFilter;
import io.prestosql.orc.metadata.statistics.ColumnStatistics;
import io.prestosql.orc.stream.CheckpointInputStreamSource;
import io.prestosql.orc.stream.InputStreamSources;
import io.prestosql.orc.stream.OrcChunkLoader;
import io.prestosql.orc.stream.OrcDataReader;
import io.prestosql.orc.stream.OrcInputStream;
import io.prestosql.orc.stream.ValueInputStream;
import io.prestosql.orc.stream.ValueInputStreamSource;
import io.prestosql.orc.stream.ValueStreams;
import java.io.IOException;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;

/* loaded from: input_file:io/prestosql/orc/StripeReader.class */
public class StripeReader {
    private final OrcDataSource orcDataSource;
    private final ZoneId storageTimeZone;
    private final Optional<OrcDecompressor> decompressor;
    private final ColumnMetadata<OrcType> types;
    private final PostScript.HiveWriterVersion hiveWriterVersion;
    private final Set<OrcColumnId> includedOrcColumnIds;
    private final OptionalInt rowsInRowGroup;
    private final OrcPredicate predicate;
    private final MetadataReader metadataReader;
    private final Optional<OrcWriteValidation> writeValidation;

    /**
     * Creates a reader for the stripes of one ORC data source.
     *
     * @param orcDataSource source the stripe bytes are read from
     * @param zoneId storage time zone; used as the stripe time zone when a stripe footer carries none
     * @param optional optional decompressor handed to every chunk loader created by this reader
     * @param columnMetadata ORC types indexed by column id
     * @param set columns to read; their nested columns are included recursively
     * @param optionalInt number of rows per row group; when absent every stripe is read as a single row group
     * @param orcPredicate predicate evaluated against row-group statistics to skip row groups
     * @param hiveWriterVersion writer version passed to the row-index decoder
     * @param metadataReader decoder for stripe footers, row indexes and bloom filter indexes
     * @param optional2 optional write validation, consulted for time zone and row-group statistics
     * @throws NullPointerException if any argument is null
     */
    public StripeReader(OrcDataSource orcDataSource, ZoneId zoneId, Optional<OrcDecompressor> optional, ColumnMetadata<OrcType> columnMetadata, Set<OrcColumn> set, OptionalInt optionalInt, OrcPredicate orcPredicate, PostScript.HiveWriterVersion hiveWriterVersion, MetadataReader metadataReader, Optional<OrcWriteValidation> optional2) {
        // Objects.requireNonNull is generic, so no casts are needed and the
        // parameterized types are checked by the compiler.
        this.orcDataSource = Objects.requireNonNull(orcDataSource, "orcDataSource is null");
        this.storageTimeZone = Objects.requireNonNull(zoneId, "storageTimeZone is null");
        this.decompressor = Objects.requireNonNull(optional, "decompressor is null");
        this.types = Objects.requireNonNull(columnMetadata, "types is null");
        this.includedOrcColumnIds = getIncludeColumns(Objects.requireNonNull(set, "readColumns is null"));
        // Fail fast here instead of on the first readStripe() call.
        this.rowsInRowGroup = Objects.requireNonNull(optionalInt, "rowsInRowGroup is null");
        this.predicate = Objects.requireNonNull(orcPredicate, "predicate is null");
        this.hiveWriterVersion = Objects.requireNonNull(hiveWriterVersion, "hiveWriterVersion is null");
        this.metadataReader = Objects.requireNonNull(metadataReader, "metadataReader is null");
        this.writeValidation = Objects.requireNonNull(optional2, "writeValidation is null");
    }

    /**
     * Reads one stripe: its footer, the streams of the included columns, and the row groups to scan.
     *
     * <p>When the row-group size is known and the stripe holds more rows than one row group, the
     * row indexes (and bloom filters) are read and the predicate selects the row groups to keep.
     * Otherwise, or when building the row groups fails with an {@link InvalidCheckpointException},
     * the whole stripe is exposed as a single row group.
     *
     * @param stripeInformation location and size information of the stripe
     * @param aggregatedMemoryContext memory context passed to every chunk loader; closed when the
     *        stripe is skipped
     * @return the stripe, or {@code null} when the predicate rejects every row group
     * @throws IOException if reading or decoding the stripe data fails
     */
    public Stripe readStripe(StripeInformation stripeInformation, AggregatedMemoryContext aggregatedMemoryContext) throws IOException {
        StripeFooter stripeFooter = readStripeFooter(stripeInformation, aggregatedMemoryContext);
        ColumnMetadata<ColumnEncoding> columnEncodings = stripeFooter.getColumnEncodings();
        if (this.writeValidation.isPresent()) {
            this.writeValidation.get().validateTimeZone(this.orcDataSource.getId(), stripeFooter.getTimeZone().orElse(null));
        }
        ZoneId fileTimeZone = stripeFooter.getTimeZone().orElse(this.storageTimeZone);

        // Streams of the included columns, minus bloom filter kinds not supported for the column type.
        Map<StreamId, Stream> streams = new HashMap<>();
        for (Stream stream : stripeFooter.getStreams()) {
            if (this.includedOrcColumnIds.contains(stream.getColumnId()) && isSupportedStreamType(stream, this.types.get(stream.getColumnId()).getOrcTypeKind())) {
                streams.put(new StreamId(stream), stream);
            }
        }

        boolean invalidCheckpoint = false;
        if (this.rowsInRowGroup.isPresent() && stripeInformation.getNumberOfRows() > this.rowsInRowGroup.getAsInt()) {
            // Stripe spans several row groups: use the indexes to prune them.
            Map<StreamId, DiskRange> selectedRanges = Maps.filterKeys(getDiskRanges(stripeFooter.getStreams()), Predicates.in(streams.keySet()));
            Map<StreamId, OrcChunkLoader> indexedStreamsData = readDiskRanges(stripeInformation.getOffset(), selectedRanges, aggregatedMemoryContext);
            Map<OrcColumnId, List<BloomFilter>> bloomFilters = readBloomFilterIndexes(streams, indexedStreamsData);
            Map<StreamId, List<RowGroupIndex>> columnIndexes = readColumnIndexes(streams, indexedStreamsData, bloomFilters);
            if (this.writeValidation.isPresent()) {
                this.writeValidation.get().validateRowGroupStatistics(this.orcDataSource.getId(), stripeInformation.getOffset(), columnIndexes);
            }
            Set<Integer> selectedRowGroups = selectRowGroups(stripeInformation, columnIndexes);
            if (selectedRowGroups.isEmpty()) {
                // Nothing to read: close the memory context and signal the skip with null.
                aggregatedMemoryContext.close();
                return null;
            }
            Map<StreamId, ValueInputStream<?>> indexedValueStreams = createValueStreams(streams, indexedStreamsData, columnEncodings);
            try {
                List<RowGroup> rowGroups = createRowGroups(stripeInformation.getNumberOfRows(), streams, indexedValueStreams, columnIndexes, selectedRowGroups, columnEncodings);
                InputStreamSources dictionarySources = createDictionaryStreamSources(streams, indexedValueStreams, columnEncodings);
                return new Stripe(stripeInformation.getNumberOfRows(), fileTimeZone, this.storageTimeZone, columnEncodings, rowGroups, dictionarySources);
            } catch (InvalidCheckpointException ignored) {
                // Deliberate fallback: the checkpoints cannot be used, so the stripe is re-read
                // below as one row group. The flag relaxes the single-row-index check there.
                invalidCheckpoint = true;
            }
        }

        // Single row group path: read every selected stream in full.
        ImmutableMap.Builder<StreamId, DiskRange> diskRanges = ImmutableMap.builder();
        for (Map.Entry<StreamId, DiskRange> entry : getDiskRanges(stripeFooter.getStreams()).entrySet()) {
            if (streams.containsKey(entry.getKey())) {
                diskRanges.put(entry);
            }
        }
        Map<StreamId, OrcChunkLoader> streamsData = readDiskRanges(stripeInformation.getOffset(), diskRanges.build(), aggregatedMemoryContext);

        // Sum, over all columns with a row index, of the value-weighted average value size.
        long minAverageRowBytes = 0;
        for (Map.Entry<StreamId, Stream> entry : streams.entrySet()) {
            if (entry.getKey().getStreamKind() == Stream.StreamKind.ROW_INDEX) {
                List<RowGroupIndex> rowGroupIndexes = this.metadataReader.readRowIndexes(this.hiveWriterVersion, new OrcInputStream(streamsData.get(entry.getKey())));
                Preconditions.checkState(rowGroupIndexes.size() == 1 || invalidCheckpoint, "expect a single row group or an invalid check point");
                long totalBytes = 0;
                long totalValues = 0;
                for (RowGroupIndex rowGroupIndex : rowGroupIndexes) {
                    ColumnStatistics columnStatistics = rowGroupIndex.getColumnStatistics();
                    if (columnStatistics.hasMinAverageValueSizeInBytes()) {
                        totalBytes += columnStatistics.getMinAverageValueSizeInBytes() * columnStatistics.getNumberOfValues();
                        totalValues += columnStatistics.getNumberOfValues();
                    }
                }
                if (totalValues > 0) {
                    minAverageRowBytes += totalBytes / totalValues;
                }
            }
        }

        Map<StreamId, ValueInputStream<?>> valueStreams = createValueStreams(streams, streamsData, columnEncodings);
        InputStreamSources dictionaryStreamSources = createDictionaryStreamSources(streams, valueStreams, columnEncodings);
        // NOTE(review): this builder is left raw because the element type accepted by
        // InputStreamSources is not among this file's imports; parameterize it once it is.
        ImmutableMap.Builder valueSources = ImmutableMap.builder();
        for (Map.Entry<StreamId, ValueInputStream<?>> entry : valueStreams.entrySet()) {
            valueSources.put(entry.getKey(), new ValueInputStreamSource(entry.getValue()));
        }
        RowGroup wholeStripe = new RowGroup(0, 0L, stripeInformation.getNumberOfRows(), minAverageRowBytes, new InputStreamSources(valueSources.build()));
        return new Stripe(stripeInformation.getNumberOfRows(), fileTimeZone, this.storageTimeZone, columnEncodings, ImmutableList.of(wholeStripe), dictionaryStreamSources);
    }

    /**
     * Tells whether a stream should be read for a column of the given type.
     *
     * <p>Only bloom filter streams are ever rejected: {@code BLOOM_FILTER} streams are rejected for
     * STRING, VARCHAR, CHAR and TIMESTAMP columns, and {@code BLOOM_FILTER_UTF8} streams are
     * rejected for CHAR columns. Every other stream kind is accepted.
     *
     * @param stream stream described in the stripe footer
     * @param orcTypeKind type kind of the column the stream belongs to
     * @return {@code true} if the stream is supported for that column type
     */
    private static boolean isSupportedStreamType(Stream stream, OrcType.OrcTypeKind orcTypeKind) {
        Stream.StreamKind streamKind = stream.getStreamKind();
        if (streamKind == Stream.StreamKind.BLOOM_FILTER) {
            return orcTypeKind != OrcType.OrcTypeKind.STRING
                    && orcTypeKind != OrcType.OrcTypeKind.VARCHAR
                    && orcTypeKind != OrcType.OrcTypeKind.CHAR
                    && orcTypeKind != OrcType.OrcTypeKind.TIMESTAMP;
        }
        if (streamKind == Stream.StreamKind.BLOOM_FILTER_UTF8) {
            return orcTypeKind != OrcType.OrcTypeKind.CHAR;
        }
        return true;
    }

    /**
     * Reads the given stripe-relative byte ranges and wraps each result in a chunk loader.
     *
     * @param j absolute offset of the stripe; added to every range offset
     * @param map byte ranges keyed by stream, with offsets relative to the stripe start
     * @param aggregatedMemoryContext memory context passed to every chunk loader
     * @return one chunk loader per stream returned by the data source
     * @throws IOException if the data source fails to read a range
     */
    private Map<StreamId, OrcChunkLoader> readDiskRanges(long j, Map<StreamId, DiskRange> map, AggregatedMemoryContext aggregatedMemoryContext) throws IOException {
        // Shift the stripe-relative ranges to absolute positions.
        ImmutableMap.Builder<StreamId, DiskRange> absoluteRanges = ImmutableMap.builder();
        for (Map.Entry<StreamId, DiskRange> entry : map.entrySet()) {
            DiskRange relativeRange = entry.getValue();
            absoluteRanges.put(entry.getKey(), new DiskRange(j + relativeRange.getOffset(), relativeRange.getLength()));
        }

        Map<StreamId, OrcDataReader> dataReaders = this.orcDataSource.readFully(absoluteRanges.build());

        ImmutableMap.Builder<StreamId, OrcChunkLoader> chunkLoaders = ImmutableMap.builder();
        for (Map.Entry<StreamId, OrcDataReader> entry : dataReaders.entrySet()) {
            chunkLoaders.put(entry.getKey(), OrcChunkLoader.create(entry.getValue(), this.decompressor, aggregatedMemoryContext));
        }
        return chunkLoaders.build();
    }

    /**
     * Creates a value stream for every selected data stream.
     *
     * <p>Index streams (see {@link #isIndexStream(Stream)}) and streams of length zero get no
     * value stream.
     *
     * @param map selected streams keyed by stream id
     * @param map2 chunk loaders holding the bytes of the streams
     * @param columnMetadata column encodings of the stripe
     * @return value streams keyed by stream id
     */
    private Map<StreamId, ValueInputStream<?>> createValueStreams(Map<StreamId, Stream> map, Map<StreamId, OrcChunkLoader> map2, ColumnMetadata<ColumnEncoding> columnMetadata) {
        ImmutableMap.Builder<StreamId, ValueInputStream<?>> valueStreams = ImmutableMap.builder();
        for (Map.Entry<StreamId, Stream> entry : map.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            // Resolved for every stream, ahead of the skip check, as before.
            ColumnEncoding.ColumnEncodingKind encodingKind = columnMetadata.get(stream.getColumnId()).getColumnEncodingKind();
            if (isIndexStream(stream) || stream.getLength() == 0) {
                continue;
            }
            OrcType.OrcTypeKind typeKind = this.types.get(stream.getColumnId()).getOrcTypeKind();
            valueStreams.put(streamId, ValueStreams.createValueStreams(streamId, map2.get(streamId), typeKind, encodingKind));
        }
        return valueStreams.build();
    }

    /**
     * Builds the stream sources for the dictionary streams of the stripe.
     *
     * <p>A stream contributes a source when {@link #isDictionary} accepts it and a value stream
     * exists for it; the source is positioned at the dictionary checkpoint of that stream.
     *
     * @param map selected streams keyed by stream id
     * @param map2 value streams keyed by stream id
     * @param columnMetadata column encodings of the stripe
     * @return the dictionary stream sources
     */
    private InputStreamSources createDictionaryStreamSources(Map<StreamId, Stream> map, Map<StreamId, ValueInputStream<?>> map2, ColumnMetadata<ColumnEncoding> columnMetadata) {
        ImmutableMap.Builder dictionarySources = ImmutableMap.builder();
        for (Map.Entry<StreamId, Stream> entry : map.entrySet()) {
            StreamId streamId = entry.getKey();
            Stream stream = entry.getValue();
            ColumnEncoding.ColumnEncodingKind encodingKind = columnMetadata.get(stream.getColumnId()).getColumnEncodingKind();
            if (!isDictionary(stream, encodingKind)) {
                continue;
            }
            ValueInputStream<?> valueStream = map2.get(streamId);
            if (valueStream == null) {
                // No value stream was created for this stream, so there is nothing to expose.
                continue;
            }
            OrcType.OrcTypeKind typeKind = this.types.get(stream.getColumnId()).getOrcTypeKind();
            dictionarySources.put(streamId, CheckpointInputStreamSource.createCheckpointStreamSource(valueStream, Checkpoints.getDictionaryStreamCheckpoint(streamId, typeKind, encodingKind)));
        }
        return new InputStreamSources(dictionarySources.build());
    }

    /**
     * Builds a {@link RowGroup} for every selected row group of a stripe.
     *
     * @param i total number of rows in the stripe
     * @param map selected streams keyed by stream id
     * @param map2 value streams keyed by stream id
     * @param map3 row-group indexes per row-index stream
     * @param set ids of the row groups to create
     * @param columnMetadata column encodings of the stripe
     * @return the row groups, in the iteration order of {@code set}
     * @throws InvalidCheckpointException propagated from the checkpoint lookup
     * @throws IllegalStateException if the row-group size is unknown
     */
    private List<RowGroup> createRowGroups(int i, Map<StreamId, Stream> map, Map<StreamId, ValueInputStream<?>> map2, Map<StreamId, List<RowGroupIndex>> map3, Set<Integer> set, ColumnMetadata<ColumnEncoding> columnMetadata) throws InvalidCheckpointException {
        int rowsPerGroup = this.rowsInRowGroup.orElseThrow(() -> new IllegalStateException("Cannot create row groups if row group info is missing"));
        ImmutableList.Builder<RowGroup> rowGroups = ImmutableList.builder();
        for (int rowGroupId : set) {
            Map<StreamId, StreamCheckpoint> checkpoints = Checkpoints.getStreamCheckpoints(this.includedOrcColumnIds, this.types, this.decompressor.isPresent(), rowGroupId, columnMetadata, map, map3);
            int rowOffset = rowGroupId * rowsPerGroup;
            // The last row group of the stripe may hold fewer rows than a full group.
            int rowCount = Math.min(i - rowOffset, rowsPerGroup);
            // Sum of the per-column minimum average value sizes for this row group.
            long minAverageRowBytes = map3.values().stream()
                    .mapToLong(rowGroupIndexes -> rowGroupIndexes.get(rowGroupId).getColumnStatistics().getMinAverageValueSizeInBytes())
                    .sum();
            rowGroups.add(createRowGroup(rowGroupId, rowOffset, rowCount, minAverageRowBytes, map2, checkpoints));
        }
        return rowGroups.build();
    }

    /**
     * Builds one {@link RowGroup} whose stream sources start at the given checkpoints.
     *
     * <p>Checkpoints without a matching value stream are skipped.
     *
     * @param i id of the row group
     * @param i2 offset of the first row of the group within the stripe
     * @param i3 number of rows in the group
     * @param j minimum average row size in bytes
     * @param map value streams keyed by stream id
     * @param map2 checkpoints keyed by stream id
     * @return the row group
     */
    private static RowGroup createRowGroup(int i, int i2, int i3, long j, Map<StreamId, ValueInputStream<?>> map, Map<StreamId, StreamCheckpoint> map2) {
        ImmutableMap.Builder checkpointSources = ImmutableMap.builder();
        for (Map.Entry<StreamId, StreamCheckpoint> checkpointEntry : map2.entrySet()) {
            StreamId streamId = checkpointEntry.getKey();
            ValueInputStream<?> valueStream = map.get(streamId);
            if (valueStream == null) {
                continue;
            }
            checkpointSources.put(streamId, CheckpointInputStreamSource.createCheckpointStreamSource(valueStream, checkpointEntry.getValue()));
        }
        InputStreamSources streamSources = new InputStreamSources(checkpointSources.build());
        return new RowGroup(i, i2, i3, j, streamSources);
    }

    /**
     * Reads and decodes the footer of a stripe.
     *
     * <p>The footer is located after the index and data sections of the stripe, i.e. at
     * {@code offset + indexLength + dataLength}.
     *
     * @param stripeInformation location and section lengths of the stripe
     * @param aggregatedMemoryContext memory context passed to the chunk loader
     * @return the decoded stripe footer
     * @throws IOException if reading or decoding the footer fails
     * @throws ArithmeticException if the footer length does not fit in an {@code int}
     */
    private StripeFooter readStripeFooter(StripeInformation stripeInformation, AggregatedMemoryContext aggregatedMemoryContext) throws IOException {
        long footerOffset = stripeInformation.getOffset() + stripeInformation.getIndexLength() + stripeInformation.getDataLength();
        int footerLength = Math.toIntExact(stripeInformation.getFooterLength());
        // try-with-resources closes the stream on both paths and records a failing close()
        // as a suppressed exception, which is what the hand-written version did.
        try (OrcInputStream footerStream = new OrcInputStream(OrcChunkLoader.create(this.orcDataSource.getId(), this.orcDataSource.readFully(footerOffset, footerLength), this.decompressor, aggregatedMemoryContext))) {
            return this.metadataReader.readStripeFooter(this.types, footerStream);
        }
    }

    /**
     * Tells whether a stream is one of the kinds this reader treats as index streams:
     * {@code ROW_INDEX}, {@code DICTIONARY_COUNT}, {@code BLOOM_FILTER} or {@code BLOOM_FILTER_UTF8}.
     *
     * @param stream stream to classify
     * @return {@code true} if the stream kind is one of the four listed kinds
     */
    static boolean isIndexStream(Stream stream) {
        Stream.StreamKind kind = stream.getStreamKind();
        if (kind == Stream.StreamKind.ROW_INDEX || kind == Stream.StreamKind.DICTIONARY_COUNT) {
            return true;
        }
        return kind == Stream.StreamKind.BLOOM_FILTER || kind == Stream.StreamKind.BLOOM_FILTER_UTF8;
    }

    /**
     * Reads the bloom filter indexes of the selected streams, one list per column.
     *
     * <p>{@code BLOOM_FILTER_UTF8} streams are read first; a {@code BLOOM_FILTER} stream is only
     * read for columns that did not get filters from a UTF8 stream.
     *
     * @param map selected streams keyed by stream id
     * @param map2 chunk loaders holding the bytes of the streams
     * @return immutable map from column id to the bloom filters read for that column
     * @throws IOException if decoding a bloom filter index fails
     */
    private Map<OrcColumnId, List<BloomFilter>> readBloomFilterIndexes(Map<StreamId, Stream> map, Map<StreamId, OrcChunkLoader> map2) throws IOException {
        Map<OrcColumnId, List<BloomFilter>> bloomFilters = new HashMap<>();

        // First pass: UTF8 bloom filters take precedence.
        for (Map.Entry<StreamId, Stream> entry : map.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == Stream.StreamKind.BLOOM_FILTER_UTF8) {
                OrcInputStream inputStream = new OrcInputStream(map2.get(entry.getKey()));
                bloomFilters.put(stream.getColumnId(), this.metadataReader.readBloomFilterIndexes(inputStream));
            }
        }

        // Second pass: plain bloom filters fill in the columns that are still missing.
        for (Map.Entry<StreamId, Stream> entry : map.entrySet()) {
            Stream stream = entry.getValue();
            if (stream.getStreamKind() == Stream.StreamKind.BLOOM_FILTER && !bloomFilters.containsKey(stream.getColumnId())) {
                OrcInputStream inputStream = new OrcInputStream(map2.get(entry.getKey()));
                bloomFilters.put(entry.getKey().getColumnId(), this.metadataReader.readBloomFilterIndexes(inputStream));
            }
        }
        return ImmutableMap.copyOf(bloomFilters);
    }

    /**
     * Reads the row-group indexes of every {@code ROW_INDEX} stream and attaches the bloom
     * filters of the stream's column to the per-row-group statistics.
     *
     * @param map selected streams keyed by stream id
     * @param map2 chunk loaders holding the bytes of the streams
     * @param map3 bloom filters per column; columns without an entry keep their statistics as read
     * @return row-group indexes keyed by the id of the row-index stream
     * @throws IOException if decoding a row index fails
     */
    private Map<StreamId, List<RowGroupIndex>> readColumnIndexes(Map<StreamId, Stream> map, Map<StreamId, OrcChunkLoader> map2, Map<OrcColumnId, List<BloomFilter>> map3) throws IOException {
        ImmutableMap.Builder<StreamId, List<RowGroupIndex>> columnIndexes = ImmutableMap.builder();
        for (Map.Entry<StreamId, Stream> entry : map.entrySet()) {
            if (entry.getValue().getStreamKind() != Stream.StreamKind.ROW_INDEX) {
                continue;
            }
            OrcInputStream inputStream = new OrcInputStream(map2.get(entry.getKey()));
            List<BloomFilter> bloomFilters = map3.get(entry.getKey().getColumnId());
            List<RowGroupIndex> rowGroupIndexes = this.metadataReader.readRowIndexes(this.hiveWriterVersion, inputStream);
            if (bloomFilters != null && !bloomFilters.isEmpty()) {
                // NOTE(review): assumes one bloom filter per row group; a shorter list makes
                // bloomFilters.get(...) throw IndexOutOfBoundsException. Confirm with the writer.
                ImmutableList.Builder<RowGroupIndex> withBloomFilters = ImmutableList.builder();
                for (int rowGroup = 0; rowGroup < rowGroupIndexes.size(); rowGroup++) {
                    RowGroupIndex rowGroupIndex = rowGroupIndexes.get(rowGroup);
                    withBloomFilters.add(new RowGroupIndex(rowGroupIndex.getPositions(), rowGroupIndex.getColumnStatistics().withBloomFilter(bloomFilters.get(rowGroup))));
                }
                rowGroupIndexes = withBloomFilters.build();
            }
            columnIndexes.put(entry.getKey(), rowGroupIndexes);
        }
        return columnIndexes.build();
    }

    /**
     * Selects the row groups of a stripe whose statistics match the predicate.
     *
     * @param stripeInformation stripe whose row count determines the number of row groups
     * @param map row-group indexes per row-index stream
     * @return ids of the matching row groups, in ascending order
     * @throws IllegalStateException if the row-group size is unknown
     */
    private Set<Integer> selectRowGroups(StripeInformation stripeInformation, Map<StreamId, List<RowGroupIndex>> map) {
        int rowsPerGroup = this.rowsInRowGroup.orElseThrow(() -> new IllegalStateException("Cannot create row groups if row group info is missing"));
        int rowsInStripe = stripeInformation.getNumberOfRows();
        int rowGroupCount = ceil(rowsInStripe, rowsPerGroup);

        ImmutableSet.Builder<Integer> selectedRowGroups = ImmutableSet.builder();
        int remainingRows = rowsInStripe;
        for (int rowGroup = 0; rowGroup < rowGroupCount; rowGroup++) {
            // Every group is full except possibly the last one.
            int rowsInGroup = Math.min(remainingRows, rowsPerGroup);
            ColumnMetadata<ColumnStatistics> statistics = getRowGroupStatistics(this.types, map, rowGroup);
            if (this.predicate.matches(rowsInGroup, statistics)) {
                selectedRowGroups.add(rowGroup);
            }
            remainingRows -= rowsInGroup;
        }
        return selectedRowGroups.build();
    }

    /**
     * Collects the statistics of one row group for every column.
     *
     * @param columnMetadata ORC types; its size determines the number of entries in the result
     * @param map row-group indexes per row-index stream
     * @param i index of the row group
     * @return statistics indexed by column id; {@code null} for columns without a row index
     * @throws NullPointerException if {@code map} is null
     * @throws IllegalArgumentException if {@code i} is negative, or (from the immutable map
     *         collector) if two streams share a column id
     */
    private static ColumnMetadata<ColumnStatistics> getRowGroupStatistics(ColumnMetadata<OrcType> columnMetadata, Map<StreamId, List<RowGroupIndex>> map, int i) {
        Objects.requireNonNull(map, "columnIndexes is null");
        Preconditions.checkArgument(i >= 0, "rowGroup is negative");

        // Re-key the indexes by plain column id for positional lookup below.
        Map<Integer, List<RowGroupIndex>> indexesByColumn = map.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getColumnId().getId(), Map.Entry::getValue));

        List<ColumnStatistics> statistics = new ArrayList<>(columnMetadata.size());
        for (int columnId = 0; columnId < columnMetadata.size(); columnId++) {
            List<RowGroupIndex> rowGroupIndexes = indexesByColumn.get(columnId);
            if (rowGroupIndexes == null) {
                statistics.add(null);
            } else {
                statistics.add(rowGroupIndexes.get(i).getColumnStatistics());
            }
        }
        return new ColumnMetadata<>(statistics);
    }

    /**
     * Tells whether a stream holds dictionary content: either a {@code DICTIONARY_DATA} stream,
     * or a {@code LENGTH} stream of a column encoded as {@code DICTIONARY} or {@code DICTIONARY_V2}.
     *
     * @param stream stream to classify
     * @param columnEncodingKind encoding kind of the stream's column
     * @return {@code true} if the stream is a dictionary stream
     */
    private static boolean isDictionary(Stream stream, ColumnEncoding.ColumnEncodingKind columnEncodingKind) {
        if (stream.getStreamKind() == Stream.StreamKind.DICTIONARY_DATA) {
            return true;
        }
        if (stream.getStreamKind() != Stream.StreamKind.LENGTH) {
            return false;
        }
        return columnEncodingKind == ColumnEncoding.ColumnEncodingKind.DICTIONARY
                || columnEncodingKind == ColumnEncoding.ColumnEncodingKind.DICTIONARY_V2;
    }

    /**
     * Computes the byte range of every non-empty stream, relative to the start of the stripe.
     *
     * <p>Offsets accumulate the lengths of all preceding streams in list order. Streams whose
     * length is not positive get no entry.
     *
     * @param list streams in the order they are laid out in the stripe
     * @return disk ranges keyed by stream id
     */
    private static Map<StreamId, DiskRange> getDiskRanges(List<Stream> list) {
        ImmutableMap.Builder<StreamId, DiskRange> diskRanges = ImmutableMap.builder();
        long stripeRelativeOffset = 0;
        for (Stream stream : list) {
            int streamLength = stream.getLength();
            if (streamLength > 0) {
                diskRanges.put(new StreamId(stream), new DiskRange(stripeRelativeOffset, streamLength));
            }
            stripeRelativeOffset += streamLength;
        }
        return diskRanges.build();
    }

    /**
     * Collects the ids of the given columns and of all their nested columns.
     *
     * @param set top-level columns to read
     * @return column ids in the order they are first visited (parent before its nested columns)
     */
    private static Set<OrcColumnId> getIncludeColumns(Set<OrcColumn> set) {
        Set<OrcColumnId> includedColumnIds = new LinkedHashSet<>();
        includeColumnsRecursive(includedColumnIds, set);
        return includedColumnIds;
    }

    /**
     * Adds the id of every given column to {@code set}, then descends into its nested columns.
     *
     * @param set accumulator receiving the column ids
     * @param collection columns to visit
     */
    private static void includeColumnsRecursive(Set<OrcColumnId> set, Collection<OrcColumn> collection) {
        collection.forEach(column -> {
            set.add(column.getColumnId());
            includeColumnsRecursive(set, column.getNestedColumns());
        });
    }

    /**
     * Divides {@code i} by {@code i2}, rounding up (for a non-negative dividend and a positive divisor).
     *
     * <p>The intermediate sum is computed in {@code long}: in {@code int} arithmetic
     * {@code i + i2 - 1} overflows for large dividends and produces a negative result.
     *
     * @param i dividend
     * @param i2 divisor
     * @return the quotient rounded up
     * @throws ArithmeticException if {@code i2} is zero
     */
    private static int ceil(int i, int i2) {
        // For a positive divisor the quotient never exceeds the dividend, so the cast is lossless.
        return (int) (((long) i + i2 - 1) / i2);
    }
}
