package org.dflib.jdbc.connector.saver;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Objects;
import java.util.function.Supplier;
import org.dflib.DataFrame;
import org.dflib.GroupBy;
import org.dflib.Hasher;
import org.dflib.Index;
import org.dflib.IntSeries;
import org.dflib.RowToValueMapper;
import org.dflib.Series;
import org.dflib.jdbc.SaveOp;
import org.dflib.jdbc.connector.JdbcConnector;
import org.dflib.jdbc.connector.StatementBuilder;
import org.dflib.jdbc.connector.TableLoader;
import org.dflib.jdbc.connector.metadata.TableFQName;
import org.dflib.join.JoinIndicator;
import org.dflib.row.RowProxy;
import org.dflib.series.SingleValueSeries;

/**
 * A {@link TableSaveStrategy} that performs an "upsert".
 *
 * <p>Rows of the saved DataFrame whose key column values match an existing DB row are updated. All other rows are
 * inserted.
 *
 * <p>Updates are issued only for the columns whose values differ from what is currently stored. Rows are grouped by
 * their set of changed columns, so each distinct change pattern becomes a single batched UPDATE statement.
 */
public class SaveViaUpsert extends TableSaveStrategy {

    // Synthetic column labels used for intermediate results. They are deliberately odd to avoid clashing with user
    // columns.
    private static final String INDICATOR_COLUMN = "dflib_ind_%$#86AcD3";
    private static final String DIFF_COLUMN = "dflib_dif_%4$#96Ac3";

    /** Columns used to match DataFrame rows against DB rows. They appear in the UPDATE "where" clause. */
    protected String[] keyColumns;

    /**
     * @param connector  connector used to load existing rows and to run insert/update statements
     * @param tableName  table to save to
     * @param keyColumns columns that identify a DB row; they should produce unique DB rows
     * @param batchSize  forwarded unchanged to the {@link TableSaveStrategy} constructor (presumably the statement
     *                   batch size; confirm against the superclass)
     */
    public SaveViaUpsert(JdbcConnector connector, TableFQName tableName, String[] keyColumns, int batchSize) {
        super(connector, tableName, batchSize);
        this.keyColumns = keyColumns;
    }

    /**
     * Inserts the rows that have no DB match by key, and updates the rows that do.
     *
     * @return a lazy supplier of the per-row save operations that were performed
     * @throws IllegalStateException if the key columns match more than one DB row for some DataFrame row
     */
    @Override
    protected Supplier<Series<SaveOp>> doInsertOrUpdate(JdbcConnector connector, DataFrame toSave) {
        Index columns = toSave.getColumnsIndex();

        DataFrame previouslySaved = new TableLoader(connector, this.tableName)
                .cols(columns.toArray())
                .eq(keyValues(toSave))
                .load();

        // Nothing matched in the DB, so every row is an insert.
        if (previouslySaved.height() == 0) {
            doInsert(connector, toSave);
            return () -> new SingleValueSeries<>(SaveOp.insert, toSave.height());
        }

        DataFrame joined = toSave
                .leftJoin(previouslySaved)
                .on(keyHasher())
                .indicatorColumn(INDICATOR_COLUMN)
                .select();

        // Row positions of "joined" are reused below to address rows of "toSave", so the two heights must match.
        // Extra rows mean several DB rows matched one key.
        int extraRows = joined.height() - toSave.height();
        if (extraRows > 0) {
            throw new IllegalStateException(String.format("Duplicate rows in the database table %s using key columns %s. Specify key columns that produce unique DB rows.", this.tableName, Arrays.toString(this.keyColumns)));
        }
        if (extraRows < 0) {
            throw new IllegalStateException(String.format(
                    "Unexpected row count after join: expected %s rows, got %s",
                    toSave.height(),
                    joined.height()));
        }

        Series<JoinIndicator> indicator = joined.getColumn(INDICATOR_COLUMN);
        IntSeries insertPositions = indicator.index(ji -> ji == JoinIndicator.left_only);
        IntSeries updatePositions = indicator.index(ji -> ji == JoinIndicator.both);

        UpsertInfoTracker infoTracker = new UpsertInfoTracker(toSave.width(), toSave.height());
        infoTracker.insertAndUpdate(indicator);

        if (insertPositions.size() > 0) {
            doInsert(connector, toSave.rows(insertPositions).select());
        }

        if (updatePositions.size() > 0) {
            int width = columns.size();
            DataFrame toUpdate = toSave.rows(updatePositions).select();

            // The right (DB) side of the join occupies the "width" columns after the left side. Rename them back
            // to the original labels so the two frames can be compared column by column.
            DataFrame dbRows = joined
                    .cols(joined.getColumnsIndex().selectRange(width, width * 2))
                    .select()
                    .cols().as(columns.toArray())
                    .rows(updatePositions)
                    .select();

            doUpdate(connector, toUpdate, dbRows, infoTracker);
        }

        return infoTracker::getInfo;
    }

    /**
     * Returns a DataFrame containing only the key columns of the argument.
     */
    public DataFrame keyValues(DataFrame dataFrame) {
        return dataFrame.cols(this.keyColumns).select();
    }

    /**
     * Updates rows that already exist in the DB, writing only the columns that actually changed.
     *
     * @param connector       connector used to run the UPDATE statements
     * @param toSave          rows to save, all of which have a DB match
     * @param previouslySaved the matching DB rows, with the same column labels and row order as {@code toSave}
     * @param infoTracker     receives per-row change information
     */
    protected void doUpdate(JdbcConnector connector, DataFrame toSave, DataFrame previouslySaved, UpsertInfoTracker infoTracker) {
        int width = toSave.width();
        if (width == this.keyColumns.length) {
            log("All DataFrame columns are key columns. Skipping update.", new Object[0]);
            return;
        }

        // Per row, a BitSet in which bit i is set when column i equals the DB value.
        Series<?> unchangedMask = toSave.eq(previouslySaved)
                .colsAppend(new String[]{DIFF_COLUMN})
                .merge(new RowToValueMapper[]{this::booleansAsBitSet})
                .getColumn(DIFF_COLUMN);

        DataFrame withDiff = toSave
                .colsAppend(new String[]{DIFF_COLUMN})
                .merge(new Series[]{unchangedMask});

        infoTracker.updatesCardinality(withDiff.getColumn(DIFF_COLUMN));

        GroupBy byPattern = withDiff.group(new String[]{DIFF_COLUMN});
        for (BitSet pattern : byPattern.getGroupKeys()) {

            // All bits set means the row is identical to the DB, so there is nothing to update.
            if (pattern.cardinality() != width) {
                updateGroup(connector, byPattern.getGroup(pattern), pattern, width);
            }
        }
    }

    // Issues one batched UPDATE for a group of rows that share the same set of changed columns.
    private void updateGroup(JdbcConnector connector, DataFrame group, BitSet unchanged, int width) {
        Index groupColumns = group.getColumnsIndex();

        String[] changed = new String[width - unchanged.cardinality()];
        int next = 0;
        for (int pos = 0; pos < width; pos++) {
            if (!unchanged.get(pos)) {
                changed[next++] = groupColumns.get(pos);
            }
        }

        // Key columns go in the "where" clause, never in "set".
        Index setColumns = Index.of(changed).selectExcept(this.keyColumns);
        if (setColumns.size() == 0) {
            // Defensive: only key columns compared as changed, so there is nothing to SET. Building the statement
            // would otherwise fail on an empty "set" list.
            return;
        }

        // Statement parameters: the SET columns followed by the key columns, in the same order as the "?"
        // placeholders produced by createUpdateStatement.
        Index paramColumns = setColumns.expand(this.keyColumns);
        StatementBuilder statement = connector
                .createStatementBuilder(createUpdateStatement(this.keyColumns, setColumns.toArray()))
                .paramDescriptors(fixedParams(paramColumns))
                .bindBatch(group.cols(paramColumns).select());

        // try-with-resources closes the connection even if the update throws.
        try (Connection connection = connector.getConnection()) {
            statement.update(connection);
        } catch (SQLException e) {
            throw new RuntimeException("Error closing DB connection", e);
        }
    }

    /**
     * Converts a row of booleans into a BitSet. Bit i is set when the value at position i is {@code Boolean.TRUE}.
     * Nulls and any other values leave the bit clear.
     */
    protected BitSet booleansAsBitSet(RowProxy row) {
        int size = row.getIndex().size();
        BitSet bits = new BitSet(size);
        for (int pos = 0; pos < size; pos++) {
            if (Boolean.TRUE.equals(row.get(pos))) {
                bits.set(pos);
            }
        }
        return bits;
    }

    /**
     * Builds a hasher over all key columns, in order, for joining DataFrame rows to DB rows.
     *
     * @throws IllegalStateException if no key columns were specified
     */
    protected Hasher keyHasher() {
        if (this.keyColumns.length == 0) {
            throw new IllegalStateException("No key columns specified for upsert into " + this.tableName);
        }

        Hasher hasher = Hasher.of(this.keyColumns[0]);
        for (int i = 1; i < this.keyColumns.length; i++) {
            hasher = hasher.and(this.keyColumns[i]);
        }
        return hasher;
    }

    /**
     * Builds a parameterized statement of the form
     * {@code update <table> set c1 = ?, c2 = ? where k1 = ? and k2 = ?}.
     * The "set" placeholders come first, followed by the "where" placeholders.
     *
     * @param whereColumns columns matched in the "where" clause (the keys)
     * @param setColumns   columns assigned in the "set" clause
     * @throws IllegalArgumentException if either array is empty
     */
    protected String createUpdateStatement(String[] whereColumns, String[] setColumns) {
        if (whereColumns.length == 0 || setColumns.length == 0) {
            throw new IllegalArgumentException("UPDATE statement requires at least one 'set' and one 'where' column");
        }

        StringBuilder sql = new StringBuilder("update ")
                .append(this.connector.quoteTableName(this.tableName))
                .append(" set ");

        for (int i = 0; i < setColumns.length; i++) {
            if (i > 0) {
                sql.append(", ");
            }
            sql.append(this.connector.quoteIdentifier(setColumns[i])).append(" = ?");
        }

        sql.append(" where ");
        for (int i = 0; i < whereColumns.length; i++) {
            if (i > 0) {
                sql.append(" and ");
            }
            sql.append(this.connector.quoteIdentifier(whereColumns[i])).append(" = ?");
        }

        return sql.toString();
    }
}
