package com.databricks.client.spark.arrow;

import com.databricks.client.dsi.core.impl.DSIDriverSingleton;
import com.databricks.client.dsi.dataengine.utilities.DSITypeUtilities;
import com.databricks.client.dsi.dataengine.utilities.DataWrapper;
import com.databricks.client.hivecommon.HiveJDBCSettings;
import com.databricks.client.hivecommon.api.HiveServer2BaseBuffer;
import com.databricks.client.hivecommon.core.HiveJDBCCommonDriver;
import com.databricks.client.hivecommon.exceptions.HiveJDBCMessageKey;
import com.databricks.client.jdbc42.internal.apache.arrow.memory.BufferAllocator;
import com.databricks.client.jdbc42.internal.apache.arrow.memory.RootAllocator;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.FieldVector;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.VarCharVector;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.VectorLoader;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.VectorSchemaRoot;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.ipc.ArrowStreamReader;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.ipc.ReadChannel;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.ipc.message.MessageSerializer;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.types.DateUnit;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.types.FloatingPointPrecision;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.types.TimeUnit;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.types.pojo.ArrowType;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.types.pojo.Field;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.types.pojo.FieldType;
import com.databricks.client.jdbc42.internal.apache.arrow.vector.types.pojo.Schema;
import com.databricks.client.jdbc42.internal.apache.hive.service.rpc.thrift.TColumnDesc;
import com.databricks.client.jdbc42.internal.apache.hive.service.rpc.thrift.TSparkArrowBatch;
import com.databricks.client.jdbc42.internal.apache.hive.service.rpc.thrift.TTypeId;
import com.databricks.client.jdbc42.internal.jpountz.lz4.LZ4FrameInputStream;
import com.databricks.client.sqlengine.executor.etree.value.SqlDataIntegrityChecker;
import com.databricks.client.support.ILogger;
import com.databricks.client.support.LogUtilities;
import com.databricks.client.support.exceptions.ErrorException;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.TimeZone;

/* loaded from: input_file:com/databricks/client/spark/arrow/ArrowBuffer.class */
/**
 * Result buffer that exposes Arrow-encoded result data through the {@link HiveServer2BaseBuffer}
 * row/column API.
 *
 * <p>{@link #handleInitializeBuffer()} selects one of two data sources:
 * <ul>
 *   <li>inline {@code TSparkArrowBatch} messages carried by {@code m_hiveServer2Buffer}. Each batch is
 *       (optionally LZ4-frame decompressed and) loaded into a single, reused {@link VectorSchemaRoot}
 *       whose schema is derived from the HS2 column metadata;</li>
 *   <li>a complete Arrow IPC stream held in {@link #m_resultFile}. It is read batch by batch with an
 *       {@link ArrowStreamReader} and limited to {@link #m_totalRowsNeeded} rows.</li>
 * </ul>
 *
 * <p>Row indices passed by callers are absolute. {@code m_processedRows} counts the rows that precede
 * the batch currently held in {@code m_data}.
 */
public class ArrowBuffer extends HiveServer2BaseBuffer {
    /** Reader over {@link #m_resultFile}; {@code null} when not (or no longer) reading a result file. */
    public ArrowStreamReader m_arrowStreamReader;
    /** HS2 column descriptors used to build the Arrow schema for inline batches. */
    public List<TColumnDesc> m_columnMetadata;
    /** Complete Arrow IPC stream to read from, or {@code null} when data arrives as inline batches. */
    public byte[] m_resultFile;
    /** Allocator from which all Arrow memory of this buffer is (directly or indirectly) allocated. */
    public RootAllocator m_rootAllocator;
    /** Maximum number of rows to expose when reading {@link #m_resultFile}. */
    public long m_totalRowsNeeded;
    private Schema m_arrowSchema;
    private Iterator<TSparkArrowBatch> m_batches;
    private List<FieldVector> m_data;
    /**
     * String values of every VarChar column across all inline batches, indexed by column. An entry is
     * {@code null} when that column is not backed by a {@link VarCharVector}. Built on the first
     * {@link #getStringColumn(int)} call.
     *
     * <p>The values are copied out because every batch is loaded into the same reused vectors, so
     * keeping references to earlier batches' vectors would only ever expose the last batch.
     */
    private List<List<String>> m_cachedData;
    /** Vector class name per column, captured together with {@link #m_cachedData} for error messages. */
    private List<String> m_cachedVectorTypes;
    private ArrowDataRetrievers m_dataRetriever;
    private int m_processedRows;
    private int m_rowsInCurrBatch;
    private boolean m_isCompressed;
    private BufferAllocator m_bufferAllocator;
    private VectorSchemaRoot m_schemaRoot;
    private static final int NOTHING_TO_READ_FROM_STREAM = -1;

    /**
     * Creates an empty buffer.
     *
     * @param iLogger logger passed to the base class
     * @param list HS2 column descriptors for the result set
     * @param z whether inline Arrow batches are LZ4-frame compressed
     * @param hiveJDBCSettings connection settings passed to the base class
     */
    public ArrowBuffer(ILogger iLogger, List<TColumnDesc> list, boolean z, HiveJDBCSettings hiveJDBCSettings) {
        super(iLogger, hiveJDBCSettings);
        this.m_totalRowsNeeded = 0L;
        this.m_dataRetriever = new ArrowDataRetrievers();
        this.m_processedRows = 0;
        this.m_rowsInCurrBatch = 0;
        this.m_columnMetadata = list;
        this.m_isCompressed = z;
        this.m_rootAllocator = new RootAllocator();
    }

    /**
     * Releases all Arrow resources.
     *
     * <p>The schema root and child allocator are closed before the stream reader and the root
     * allocator, so child memory is returned before its parent is closed.
     */
    @Override
    public void close() {
        if (null != this.m_schemaRoot) {
            this.m_schemaRoot.close();
            this.m_schemaRoot = null;
        }
        if (null != this.m_bufferAllocator) {
            this.m_bufferAllocator.close();
            this.m_bufferAllocator = null;
        }
        closeArrowStreamReader();
        if (null != this.m_rootAllocator) {
            this.m_rootAllocator.close();
            this.m_rootAllocator = null;
        }
    }

    /**
     * Builds the Arrow schema for inline batches from {@link #m_columnMetadata}.
     *
     * @throws ErrorException if a column type is unsupported, or wrapping any other failure as
     *     {@code ARROW_DESERIALIZATION_ERROR}
     */
    public void convertHS2MetadataToArrowSchema() throws ErrorException {
        try {
            LogUtilities.logFunctionEntrance(this.m_logger, new Object[0]);
            List<Field> fields = new ArrayList<>();
            for (TColumnDesc columnDesc : this.m_columnMetadata) {
                fields.add(getField(columnDesc));
            }
            this.m_arrowSchema = new Schema(fields);
        } catch (ErrorException e) {
            throw e;
        } catch (Exception e2) {
            ErrorException createGeneralException = HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.ARROW_DESERIALIZATION_ERROR.name(), new String[]{e2.getMessage()});
            createGeneralException.initCause(e2);
            throw createGeneralException;
        }
    }

    /**
     * Copies one cell of the current batch into {@code dataWrapper}.
     *
     * @param i zero-based column index
     * @param j offset forwarded to the DSITypeUtilities character/binary output helpers
     * @param j2 length limit forwarded to the DSITypeUtilities character/binary output helpers
     * @param dataWrapper destination for the value
     * @param i2 absolute row index; it must fall inside the batch currently loaded
     * @param s SQL type code (the case values match the java.sql.Types constants)
     * @param str unused here
     * @param hiveJDBCSettings settings consulted for {@code m_arrowTimestampAsString}
     * @return the result of the DSITypeUtilities output helper for binary/char/varchar values
     *     (presumably "more data remains" — confirm); {@code false} for every other type
     * @throws ErrorException for unsupported SQL types or retrieval failures, wrapped as
     *     {@code HIVE_RESULTSET_DATA_RETRIEVING_ERR}
     */
    @Override
    public boolean getData(int i, long j, long j2, DataWrapper dataWrapper, int i2, short s, String str, HiveJDBCSettings hiveJDBCSettings) throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, Integer.valueOf(i), Long.valueOf(j), Long.valueOf(j2), Integer.valueOf(i2), Short.valueOf(s), hiveJDBCSettings);
        FieldVector fieldVector = this.m_data.get(i);
        // Translate the absolute row index into an index inside the current batch.
        int i3 = i2 - this.m_processedRows;
        try {
            switch (s) {
                case -6: // TINYINT
                    this.m_dataRetriever.setTinyInt(s, fieldVector, i3, dataWrapper);
                    return false;
                case -5: // BIGINT
                    this.m_dataRetriever.setBigInt(s, fieldVector, i3, dataWrapper);
                    return false;
                case -2: // BINARY
                    return DSITypeUtilities.outputBinary(this.m_dataRetriever.getBinary(fieldVector, i3), dataWrapper, j, j2);
                case 1: // CHAR
                    String string = this.m_dataRetriever.getString(fieldVector, i3);
                    if (string != null) {
                        return DSITypeUtilities.outputCharStringData(string, dataWrapper, j, j2);
                    }
                    dataWrapper.setNull(s);
                    return false;
                case 3: // DECIMAL: transported as a string column (see getArrowType)
                    String string2 = this.m_dataRetriever.getString(fieldVector, i3);
                    if (string2 == null) {
                        dataWrapper.setNull(s);
                        return false;
                    }
                    dataWrapper.setDecimal(new BigDecimal(string2));
                    return false;
                case 4: // INTEGER
                    this.m_dataRetriever.setInteger(s, fieldVector, i3, dataWrapper);
                    return false;
                case 5: // SMALLINT
                    this.m_dataRetriever.setSmallInt(s, fieldVector, i3, dataWrapper);
                    return false;
                case 7: // REAL
                    this.m_dataRetriever.setReal(s, fieldVector, i3, dataWrapper);
                    return false;
                case 8: // DOUBLE
                    this.m_dataRetriever.setDouble(s, fieldVector, i3, dataWrapper);
                    return false;
                case 12: // VARCHAR
                    String string3 = this.m_dataRetriever.getString(fieldVector, i3);
                    if (string3 != null) {
                        return DSITypeUtilities.outputVarCharStringData(string3, dataWrapper, j, j2);
                    }
                    dataWrapper.setNull(s);
                    return false;
                case 16: // BOOLEAN
                    this.m_dataRetriever.setBoolean(s, fieldVector, i3, dataWrapper);
                    return false;
                case 91: // DATE
                    this.m_dataRetriever.setDate(s, fieldVector, i3, dataWrapper);
                    return false;
                case 93: // TIMESTAMP: native Arrow timestamp, or a string when configured so
                    if (!hiveJDBCSettings.m_arrowTimestampAsString) {
                        this.m_dataRetriever.setTimestamp(s, fieldVector, i3, dataWrapper);
                        return false;
                    }
                    String string4 = this.m_dataRetriever.getString(fieldVector, i3);
                    if (null == string4) {
                        dataWrapper.setNull(s);
                        return false;
                    }
                    dataWrapper.setTimestamp(convertTimestamp(string4));
                    return false;
                default:
                    throw HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.HIVE_RESULTSET_DATA_RETRIEVING_ERR.name(), new String[]{"Data type is not supported. SqlType: " + Integer.toString(s)});
            }
        } catch (ErrorException e) {
            e.loadMessage(DSIDriverSingleton.getInstance().getMessageSource(), DSIDriverSingleton.getInstance().getLocale());
            ErrorException createGeneralException = HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.HIVE_RESULTSET_DATA_RETRIEVING_ERR.name(), new String[]{"Column" + String.valueOf(i) + ":" + e.getMessage()});
            createGeneralException.initCause(e);
            throw createGeneralException;
        }
    }

    /**
     * Not supported by this buffer.
     *
     * @throws ErrorException always, as {@code UNSUPPORTED_OPERATION_ERR}
     */
    @Override
    public List<Integer> getIntColumn(int i) throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, Integer.valueOf(i));
        throw HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.UNSUPPORTED_OPERATION_ERR.name(), new String[]{"ArrowBuffer's getIntColumn() should not be called"});
    }

    /**
     * Returns the number of columns in the currently loaded batch.
     *
     * @return the number of columns, or 0 if no batch has been loaded yet
     */
    @Override
    public int getNumColumns() {
        LogUtilities.logFunctionEntrance(this.m_logger, new Object[0]);
        if (null != this.m_data) {
            return this.m_data.size();
        }
        return 0;
    }

    /**
     * Returns every value of a VarChar column across all remaining inline batches.
     *
     * <p>The first call consumes all remaining inline batches. While doing so it caches the values of
     * every VarChar column, and later calls (for any column) are answered from that cache. Note that
     * afterwards {@code m_data} holds the last batch.
     *
     * @param i zero-based column index
     * @return a fresh, mutable list of the column's values; empty if no batch was ever loaded
     * @throws ErrorException {@code ARROW_INCORRECT_VECTOR_TYPE} if the column is not VarChar, or a
     *     deserialization error from a later batch
     */
    @Override
    public List<String> getStringColumn(int i) throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, Integer.valueOf(i));
        if (null == this.m_cachedData) {
            if (null == this.m_data) {
                // Nothing has been loaded (e.g. an empty result), so there are no values to return.
                return new ArrayList<>();
            }
            cacheStringColumns(i);
        }
        List<String> column = this.m_cachedData.get(i);
        if (null == column) {
            throw incorrectVectorType(i, this.m_cachedVectorTypes.get(i));
        }
        return new ArrayList<>(column);
    }

    /**
     * Makes sure the batch containing absolute row {@code i} is loaded.
     *
     * @param i absolute row index about to be read
     * @return {@code false} if row {@code i} is now available in {@code m_data}; {@code true} if this
     *     buffer has no more rows to offer and the caller should fetch the next buffer
     * @throws ErrorException on deserialization or stream-read failures
     */
    @Override
    public boolean isGetNextBuffer(int i) throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, Integer.valueOf(i));
        if (this.m_resultFile != null && i >= this.m_totalRowsNeeded) {
            return true;
        }
        if (this.m_processedRows + this.m_rowsInCurrBatch > i) {
            return false;
        }
        if (this.m_batches != null && this.m_batches.hasNext()) {
            this.m_processedRows += this.m_rowsInCurrBatch;
            deserializeBatch(this.m_batches.next());
            return false;
        }
        if (this.m_resultFile == null || this.m_arrowStreamReader == null) {
            return true;
        }
        try {
            if (this.m_processedRows + this.m_rowsInCurrBatch >= this.m_totalRowsNeeded || !this.m_arrowStreamReader.loadNextBatch()) {
                return true;
            }
            this.m_processedRows += this.m_rowsInCurrBatch;
            this.m_schemaRoot = this.m_arrowStreamReader.getVectorSchemaRoot();
            this.m_data = this.m_schemaRoot.getFieldVectors();
            this.m_rowsInCurrBatch = this.m_schemaRoot.getRowCount();
            return false;
        } catch (IOException e) {
            close();
            ErrorException createGeneralException = HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.FILE_PARSE_ARROW_ERROR.name(), new String[]{e.getMessage()});
            createGeneralException.initCause(e);
            throw createGeneralException;
        }
    }

    /**
     * Maps an HS2 column type to the Arrow type used to decode inline batches.
     *
     * <p>Decimal and all complex/user-defined types are decoded as UTF-8 strings.
     *
     * @param tTypeId HS2 type id of the column
     * @return the Arrow type for that column
     * @throws ErrorException {@code ARROW_DESERIALIZATION_ERROR} for unmapped types
     */
    protected ArrowType getArrowType(TTypeId tTypeId) throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, tTypeId);
        switch (tTypeId) {
            case NULL_TYPE:
                return ArrowType.Null.INSTANCE;
            case BOOLEAN_TYPE:
                return ArrowType.Bool.INSTANCE;
            case TINYINT_TYPE:
                return new ArrowType.Int(8, true);
            case SMALLINT_TYPE:
                return new ArrowType.Int(16, true);
            case INT_TYPE:
                return new ArrowType.Int(32, true);
            case BIGINT_TYPE:
                return new ArrowType.Int(64, true);
            case FLOAT_TYPE:
                return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
            case DOUBLE_TYPE:
                return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
            case DATE_TYPE:
                return new ArrowType.Date(DateUnit.DAY);
            case BINARY_TYPE:
                return ArrowType.Binary.INSTANCE;
            case TIMESTAMP_TYPE:
                // NOTE(review): Arrow expects the timestamp zone to be a tz-database ID or an offset,
                // but getDisplayName() yields e.g. "Pacific Standard Time". Left unchanged because the
                // retriever that reads this value is not visible here; consider getID() after confirming.
                return this.m_settings.m_arrowTimestampAsString ? ArrowType.Utf8.INSTANCE : new ArrowType.Timestamp(TimeUnit.MICROSECOND, TimeZone.getDefault().getDisplayName());
            case DECIMAL_TYPE:
            case VARCHAR_TYPE:
            case CHAR_TYPE:
            case STRING_TYPE:
            case ARRAY_TYPE:
            case MAP_TYPE:
            case STRUCT_TYPE:
            case UNION_TYPE:
            case USER_DEFINED_TYPE:
                return ArrowType.Utf8.INSTANCE;
            default:
                throw HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.ARROW_DESERIALIZATION_ERROR.name(), new String[]{"Unsupported data type: " + tTypeId.toString()});
        }
    }

    /**
     * Loads one inline batch into the shared schema root and points {@code m_data} at its vectors.
     *
     * <p>The schema root is created on first use and reused afterwards. Each call therefore overwrites
     * the vectors' previous contents.
     *
     * @param tSparkArrowBatch serialized Arrow record batch plus its row count
     * @throws ErrorException {@code ARROW_DESERIALIZATION_ERROR} wrapping any decoding failure
     */
    protected void deserializeBatch(TSparkArrowBatch tSparkArrowBatch) throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, new Object[0]);
        if (null == this.m_arrowSchema) {
            convertHS2MetadataToArrowSchema();
        }
        try {
            if (null == this.m_bufferAllocator && null == this.m_schemaRoot) {
                this.m_bufferAllocator = this.m_rootAllocator.newChildAllocator("fromBatchList", 0L, Long.MAX_VALUE);
                this.m_schemaRoot = VectorSchemaRoot.create(this.m_arrowSchema, this.m_bufferAllocator);
            }
            this.m_rowsInCurrBatch = (int) tSparkArrowBatch.getRowCount();
            byte[] payload = tSparkArrowBatch.getBatch();
            ByteArrayInputStream ipcStream = this.m_isCompressed ? decompressLz4Frame(payload) : new ByteArrayInputStream(payload);
            // try-with-resources releases the record batch's buffers even if loading fails.
            try (ArrowRecordBatch recordBatch = MessageSerializer.deserializeRecordBatch(new ReadChannel(Channels.newChannel(ipcStream)), this.m_bufferAllocator)) {
                new VectorLoader(this.m_schemaRoot).load(recordBatch);
            }
            this.m_data = this.m_schemaRoot.getFieldVectors();
        } catch (Exception e) {
            ErrorException createGeneralException = HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.ARROW_DESERIALIZATION_ERROR.name(), new String[]{e.getMessage()});
            createGeneralException.initCause(e);
            throw createGeneralException;
        }
    }

    /**
     * Appends every value of VarChar column {@code i} of a batch to {@code list2}.
     *
     * @param i zero-based column index; the vector must be a {@link VarCharVector}
     * @param list the batch's field vectors
     * @param list2 destination list
     * @throws ErrorException if a value cannot be retrieved
     */
    protected void getStringColumnFromBatch(int i, List<FieldVector> list, List<String> list2) throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, Integer.valueOf(i));
        VarCharVector varCharVector = (VarCharVector) list.get(i);
        for (int i2 = 0; i2 < varCharVector.getValueCount(); i2++) {
            list2.add(this.m_dataRetriever.getString(varCharVector, i2));
        }
    }

    /**
     * Prepares the buffer for a newly fetched result chunk.
     *
     * <p>If the fetched results carry inline Arrow batches, the first batch is deserialized. Otherwise,
     * if {@link #m_resultFile} is set, a fresh stream reader is opened and its first batch is loaded.
     *
     * @throws ErrorException on deserialization or stream-read failures, or
     *     {@code FILE_CONTAINS_OVERFLOWN_ROWS} when {@link #m_totalRowsNeeded} exceeds the int range
     */
    @Override
    protected void handleInitializeBuffer() throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, new Object[0]);
        if (this.m_hiveServer2Buffer != null && this.m_hiveServer2Buffer.getResults().isSetArrowBatches()) {
            this.m_batches = this.m_hiveServer2Buffer.getResults().getArrowBatchesIterator();
            if (this.m_batches.hasNext()) {
                if (this.m_arrowSchema == null) {
                    convertHS2MetadataToArrowSchema();
                }
                deserializeBatch(this.m_batches.next());
            } else {
                // An empty fetch must not keep advertising the previous fetch's rows.
                this.m_rowsInCurrBatch = 0;
            }
            this.m_processedRows = 0;
            this.m_cachedData = null;
            this.m_cachedVectorTypes = null;
            return;
        }
        if (null != this.m_resultFile) {
            if (null != this.m_arrowStreamReader) {
                closeArrowStreamReader();
            }
            this.m_arrowStreamReader = new ArrowStreamReader(new ByteArrayInputStream(this.m_resultFile), this.m_rootAllocator);
            try {
                if (this.m_arrowStreamReader.loadNextBatch()) {
                    if (null != this.m_schemaRoot) {
                        this.m_schemaRoot.close();
                    }
                    this.m_schemaRoot = this.m_arrowStreamReader.getVectorSchemaRoot();
                    this.m_rowsInCurrBatch = this.m_schemaRoot.getRowCount();
                    if (this.m_totalRowsNeeded > SqlDataIntegrityChecker.SIGNED_INT_MAX) {
                        throw HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.FILE_CONTAINS_OVERFLOWN_ROWS.name(), new String[]{String.valueOf(this.m_totalRowsNeeded)});
                    }
                    if (this.m_totalRowsNeeded < this.m_schemaRoot.getRowCount()) {
                        // The first batch already holds every needed row. Drop the reader so that no
                        // further batches are read. It is deliberately not closed: closing it would
                        // also close m_schemaRoot, which still backs m_data (close() releases it later).
                        this.m_rowsInCurrBatch = (int) this.m_totalRowsNeeded;
                        this.m_arrowStreamReader = null;
                    }
                    this.m_data = this.m_schemaRoot.getFieldVectors();
                }
                this.m_processedRows = 0;
            } catch (IOException e) {
                close();
                ErrorException createGeneralException = HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.FILE_PARSE_ARROW_ERROR.name(), new String[]{e.getMessage()});
                createGeneralException.initCause(e);
                throw createGeneralException;
            }
        }
    }

    /**
     * Sets {@code m_numRows}.
     *
     * <p>For inline batches this is the sum of their row counts; otherwise it is
     * {@link #m_totalRowsNeeded}. Both are truncated to int.
     */
    @Override
    protected void setNumRows() throws ErrorException {
        LogUtilities.logFunctionEntrance(this.m_logger, new Object[0]);
        int i = 0;
        if (this.m_hiveServer2Buffer == null || !this.m_hiveServer2Buffer.getResults().isSetArrowBatches()) {
            i = (int) this.m_totalRowsNeeded;
        } else {
            for (TSparkArrowBatch batch : this.m_hiveServer2Buffer.getResults().getArrowBatches()) {
                i = (int) (i + batch.getRowCount());
            }
        }
        this.m_numRows = i;
    }

    /** Closes and forgets the stream reader, logging (not propagating) any close failure. */
    private void closeArrowStreamReader() {
        if (null != this.m_arrowStreamReader) {
            try {
                this.m_arrowStreamReader.close();
            } catch (IOException e) {
                LogUtilities.logWarning("An exception happened when closing ArrowStreamReader: " + e.getMessage(), this.m_logger);
            }
            this.m_arrowStreamReader = null;
        }
    }

    /**
     * Consumes all remaining inline batches and fills {@link #m_cachedData} and
     * {@link #m_cachedVectorTypes}.
     *
     * <p>The caches are assigned only after every batch has been read, so a failure leaves them unset.
     *
     * @param requestedColumn column whose type is validated before anything is consumed
     * @throws ErrorException {@code ARROW_INCORRECT_VECTOR_TYPE} if the requested column is not
     *     VarChar, or a deserialization error from a later batch
     */
    private void cacheStringColumns(int requestedColumn) throws ErrorException {
        FieldVector requested = this.m_data.get(requestedColumn);
        if (!(requested instanceof VarCharVector)) {
            throw incorrectVectorType(requestedColumn, requested.getClass().getName());
        }
        int columnCount = this.m_data.size();
        List<List<String>> columns = new ArrayList<>(columnCount);
        List<String> vectorTypes = new ArrayList<>(columnCount);
        for (FieldVector vector : this.m_data) {
            vectorTypes.add(vector.getClass().getName());
            columns.add(vector instanceof VarCharVector ? new ArrayList<String>() : null);
        }
        appendStringColumns(columns);
        while (null != this.m_batches && this.m_batches.hasNext()) {
            deserializeBatch(this.m_batches.next());
            appendStringColumns(columns);
        }
        this.m_cachedData = columns;
        this.m_cachedVectorTypes = vectorTypes;
    }

    /** Appends the current batch's values to each non-null (VarChar) entry of {@code columns}. */
    private void appendStringColumns(List<List<String>> columns) throws ErrorException {
        for (int col = 0; col < columns.size(); col++) {
            List<String> values = columns.get(col);
            if (null != values) {
                getStringColumnFromBatch(col, this.m_data, values);
            }
        }
    }

    /** Builds the {@code ARROW_INCORRECT_VECTOR_TYPE} error for a column that is not VarChar. */
    private static ErrorException incorrectVectorType(int column, String actualType) {
        return HiveJDBCCommonDriver.s_HiveMessages.createGeneralException(HiveJDBCMessageKey.ARROW_INCORRECT_VECTOR_TYPE.name(), new String[]{String.valueOf(column), "VarChar", actualType});
    }

    /**
     * Decompresses a single LZ4-frame payload fully into memory.
     *
     * @param compressed LZ4-frame encoded bytes
     * @return a stream over exactly the decompressed bytes, with no trailing buffer padding
     * @throws IOException if the payload is not a valid LZ4 frame
     */
    private static ByteArrayInputStream decompressLz4Frame(byte[] compressed) throws IOException {
        // Start at twice the compressed size and double whenever the buffer fills up.
        byte[] buffer = new byte[Math.max(1, compressed.length) * 2];
        int filled = 0;
        try (LZ4FrameInputStream lz4 = new LZ4FrameInputStream(new ByteArrayInputStream(compressed))) {
            int read;
            while (NOTHING_TO_READ_FROM_STREAM != (read = lz4.read(buffer, filled, buffer.length - filled))) {
                filled += read;
                if (filled >= buffer.length) {
                    buffer = Arrays.copyOf(buffer, buffer.length * 2);
                }
            }
        }
        return new ByteArrayInputStream(buffer, 0, filled);
    }

    /** Creates the nullable, childless Arrow field for one HS2 column, named after its type id. */
    private Field getField(TColumnDesc tColumnDesc) throws ErrorException {
        TTypeId type = tColumnDesc.getTypeDesc().getTypes().get(0).getPrimitiveEntry().getType();
        return new Field(type.name(), new FieldType(true, getArrowType(type), null), new ArrayList<Field>());
    }
}
