/*
 * Decompiled with CFR 0.152.
 */
package apoc.export.arrow;

import apoc.Extended;
import apoc.Pools;
import apoc.export.util.BatchTransaction;
import apoc.export.util.ProgressReporter;
import apoc.export.util.Reporter;
import apoc.result.ImportProgressInfo;
import apoc.result.ProgressInfo;
import apoc.util.ExtendedUtil;
import apoc.util.FileUtils;
import apoc.util.Util;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.channels.SeekableByteChannel;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Label;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.RelationshipType;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.graphdb.security.URLAccessValidationError;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedure that imports nodes and relationships from Arrow data.
 *
 * <p>The input is either a String, which is resolved through {@code FileUtils.inputStreamFor}
 * and read with the Arrow <em>file</em> format reader, or a {@code byte[]}, which is read with
 * the Arrow <em>stream</em> format reader.
 *
 * <p>Each Arrow row becomes one entity:
 * <ul>
 *   <li>a row that has a {@link #FIELD_TYPE} value is a relationship. Its endpoints are
 *       resolved through the {@link #FIELD_SOURCE_ID}/{@link #FIELD_TARGET_ID} columns;</li>
 *   <li>any other row is a node with the labels from {@link #FIELD_LABELS}. It is keyed by
 *       {@link #FIELD_ID} so that later relationship rows can refer to it.</li>
 * </ul>
 * All remaining non-null columns are written as properties.
 */
@Extended
public class ImportArrow {
    /** Column holding a node's id as recorded in the input (used to wire relationships). */
    public static final String FIELD_ID = "<id>";
    /** Column holding a node's labels. */
    public static final String FIELD_LABELS = "labels";
    /** Column holding the input id of a relationship's start node. */
    public static final String FIELD_SOURCE_ID = "<source.id>";
    /** Column holding the input id of a relationship's end node. */
    public static final String FIELD_TARGET_ID = "<target.id>";
    /** Column holding a relationship's type; a non-null value marks the row as a relationship. */
    public static final String FIELD_TYPE = "<type>";

    @Context
    public Pools pools;
    @Context
    public GraphDatabaseService db;
    @Context
    public URLAccessChecker urlAccessChecker;

    /**
     * Imports the entities contained in {@code input}.
     *
     * @param input  a file location (String) read as an Arrow file, or a {@code byte[]}
     *               read as an Arrow IPC stream
     * @param config optional settings, see {@link ArrowConfig}
     * @return a single progress record summarising the import
     * @throws Exception if reading the input or writing to the database fails
     */
    @Procedure(name="apoc.import.arrow", mode=Mode.WRITE)
    @Description(value="Imports entities from the provided Arrow file or byte array")
    public Stream<ImportProgressInfo> importFile(@Name(value="input") Object input, @Name(value="config", defaultValue="{}") Map<String, Object> config) throws Exception {
        ImportProgressInfo result = (ImportProgressInfo) Util.inThread(this.pools, () -> this.doImport(input, config));
        return Stream.of(result);
    }

    /** Runs the whole import and returns the accumulated progress information. */
    private ImportProgressInfo doImport(Object input, Map<String, Object> config) throws Exception {
        String file = input instanceof String path ? path : null;
        String sourceInfo = file != null ? "file" : "binary";
        ArrowConfig conf = new ArrowConfig(config);
        // Maps the id recorded in the input to the id of the node created in this database.
        Map<Long, Long> idMapping = new HashMap<>();
        ProgressReporter reporter = new ProgressReporter(null, null, new ImportProgressInfo(file, sourceInfo, "arrow"));

        // The allocator is owned here so it is released once the reader is done with it;
        // closing an ArrowReader does not close the allocator it was given.
        try (BufferAllocator allocator = new RootAllocator();
             ArrowReader reader = this.getReader(input, allocator);
             VectorSchemaRoot schemaRoot = reader.getVectorSchemaRoot()) {
            try (BatchTransaction btx = new BatchTransaction(this.db, conf.getBatchSize(), reporter)) {
                // Each loadNextBatch() refills schemaRoot; batches with zero rows are simply skipped.
                while (reader.loadNextBatch()) {
                    int rowCount = schemaRoot.getRowCount();
                    for (int index = 0; index < rowCount; index++) {
                        Map<String, Object> row = readRow(schemaRoot, index, conf);
                        this.importRow(btx, row, idMapping, reporter);
                        btx.increment();
                    }
                }
                btx.doCommit();
            }
            return (ImportProgressInfo) reporter.getTotal();
        }
    }

    /** Collects the non-null values of row {@code index} keyed by column name. */
    private static Map<String, Object> readRow(VectorSchemaRoot schemaRoot, int index, ArrowConfig conf) {
        Map<String, Object> row = new HashMap<>();
        for (FieldVector fieldVector : schemaRoot.getFieldVectors()) {
            Object value = read(fieldVector, index, conf);
            if (value != null) {
                row.put(fieldVector.getName(), value);
            }
        }
        return row;
    }

    /**
     * Creates the node or relationship described by {@code row}.
     * Reserved columns are removed from {@code row}, and the rest become properties.
     */
    private void importRow(BatchTransaction btx, Map<String, Object> row, Map<Long, Long> idMapping, ProgressReporter reporter) {
        String relType = (String) row.remove(FIELD_TYPE);
        if (relType == null) {
            String[] stringLabels = (String[]) row.remove(FIELD_LABELS);
            Label[] labels = stringLabels == null
                    ? new Label[0]
                    : Arrays.stream(stringLabels).map(Label::label).toArray(Label[]::new);
            Node node = btx.getTransaction().createNode(labels);
            long id = requireId(row, FIELD_ID);
            idMapping.put(id, node.getId());
            this.addProps(row, node);
            reporter.update(1L, 0L, row.size());
        } else {
            Node source = mappedNode(btx, idMapping, requireId(row, FIELD_SOURCE_ID));
            Node target = mappedNode(btx, idMapping, requireId(row, FIELD_TARGET_ID));
            Relationship rel = source.createRelationshipTo(target, RelationshipType.withName(relType));
            this.addProps(row, rel);
            reporter.update(0L, 1L, row.size());
        }
    }

    /**
     * Removes {@code field} from {@code row} and returns it as a long.
     *
     * @throws IllegalArgumentException if the column is absent/null or not numeric
     */
    private static long requireId(Map<String, Object> row, String field) {
        Object value = row.remove(field);
        if (!(value instanceof Number number)) {
            throw new IllegalArgumentException("Missing or non-numeric '" + field + "' value in Arrow row: " + value);
        }
        return number.longValue();
    }

    /**
     * Looks up the node created for the given input id.
     *
     * @throws IllegalStateException if no node row with that id has been imported yet
     */
    private static Node mappedNode(BatchTransaction btx, Map<Long, Long> idMapping, long inputId) {
        Long nodeId = idMapping.get(inputId);
        if (nodeId == null) {
            // idMapping is filled in row order, so endpoints must appear before the relationship.
            throw new IllegalStateException("No node with " + FIELD_ID + " " + inputId
                    + " has been imported before the relationship that references it");
        }
        return btx.getTransaction().getNodeById(nodeId);
    }

    /**
     * Opens a reader over {@code input} using the caller-owned {@code allocator}.
     * A String is treated as a file location (Arrow file format). A byte array is
     * treated as an Arrow IPC stream.
     *
     * @throws IllegalArgumentException if {@code input} is neither a String nor a byte array
     */
    private ArrowReader getReader(Object input, BufferAllocator allocator) throws IOException, URISyntaxException, URLAccessValidationError {
        if (input instanceof String) {
            SeekableByteChannel channel = FileUtils.inputStreamFor(input, null, null, null, this.urlAccessChecker).asChannel();
            return new ArrowFileReader(channel, allocator);
        }
        if (input instanceof byte[] bytes) {
            return new ArrowStreamReader(new ByteArrayInputStream(bytes), allocator);
        }
        throw new IllegalArgumentException("Arrow input must be a String (file location) or a byte array, got: "
                + (input == null ? "null" : input.getClass().getName()));
    }

    /**
     * Reads one cell.
     *
     * @return {@code null} for null cells and for empty collections (both are skipped as
     *         properties). Bit vectors are returned as booleans. Anything else is converted
     *         by {@code ExtendedUtil.toValidValue} using the configured mapping.
     */
    private static Object read(FieldVector fieldVector, int index, ArrowConfig conf) {
        if (fieldVector.isNull(index)) {
            return null;
        }
        if (fieldVector instanceof BitVector bitVector) {
            return bitVector.get(index) == 1;
        }
        Object object = fieldVector.getObject(index);
        if (object instanceof Collection<?> collection && collection.isEmpty()) {
            return null;
        }
        return ExtendedUtil.toValidValue(object, fieldVector.getName(), conf.getMapping());
    }

    /** Writes every entry of {@code row} as a property of {@code entity}. */
    private void addProps(Map<String, Object> row, Entity entity) {
        row.forEach(entity::setProperty);
    }

    /**
     * Configuration for {@code apoc.import.arrow}.
     * <ul>
     *   <li>{@code batchSize}: rows per transaction batch (default 2000)</li>
     *   <li>{@code mapping}: passed to {@code ExtendedUtil.toValidValue} for value conversion
     *       (default empty)</li>
     * </ul>
     */
    public static class ArrowConfig {
        private final int batchSize;
        private final Map<String, Object> mapping;

        public ArrowConfig(Map<String, Object> config) {
            if (config == null) {
                config = Collections.emptyMap();
            }
            Object rawMapping = config.get("mapping");
            @SuppressWarnings("unchecked") // user-supplied procedure config; expected to be a map
            Map<String, Object> typedMapping = rawMapping == null ? Map.of() : (Map<String, Object>) rawMapping;
            this.mapping = typedMapping;
            this.batchSize = Util.toInteger(config.getOrDefault("batchSize", 2000));
        }

        public int getBatchSize() {
            return this.batchSize;
        }

        public Map<String, Object> getMapping() {
            return this.mapping;
        }
    }
}

