package org.datavec.spark.transform;

import java.util.List;
import org.apache.commons.math3.util.Pair;
import org.apache.spark.api.java.JavaRDD;
import org.datavec.api.transform.DataAction;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.join.Join;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.reduce.IReducer;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ConvertToSequence;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.Writable;
import org.datavec.spark.SequenceEmptyRecordFunction;
import org.datavec.spark.functions.EmptyRecordFunction;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;
import org.datavec.spark.transform.filter.SparkFilterFunction;
import org.datavec.spark.transform.join.ExecuteJoinFlatMapFunction;
import org.datavec.spark.transform.join.MapToJoinValuesFunction;
import org.datavec.spark.transform.misc.ColumnAsKeyPairFunction;
import org.datavec.spark.transform.rank.UnzipForCalculateSortedRankFunction;
import org.datavec.spark.transform.reduce.MapToPairForReducerFunction;
import org.datavec.spark.transform.reduce.ReducerFunction;
import org.datavec.spark.transform.sequence.SparkGroupToSequenceFunction;
import org.datavec.spark.transform.sequence.SparkMapToPairByColumnFunction;
import org.datavec.spark.transform.sequence.SparkSequenceFilterFunction;
import org.datavec.spark.transform.sequence.SparkSequenceTransformFunction;
import org.datavec.spark.transform.transform.SequenceSplitFunction;
import org.datavec.spark.transform.transform.SparkTransformFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/datavec/spark/transform/SparkTransformExecutor.class */
public class SparkTransformExecutor {
    private static final Logger log = LoggerFactory.getLogger(SparkTransformExecutor.class);

    /**
     * System property read by {@link #isTryCatch()}. When it is set to {@code "true"}
     * (case-insensitive, per {@link Boolean#getBoolean(String)}), the output of every transform
     * step is additionally filtered with {@code EmptyRecordFunction} /
     * {@code SequenceEmptyRecordFunction}. NOTE(review): presumably this drops records that a
     * failing transform emptied out; confirm against those functions.
     */
    public static final String LOG_ERROR_PROPERTY = "org.datavec.spark.transform.logerrors";

    /**
     * @deprecated All methods of this class are static; there is no need to instantiate it.
     */
    @Deprecated
    public SparkTransformExecutor() {
    }

    /**
     * Executes a {@link TransformProcess} on non-sequence input and returns non-sequence output.
     *
     * @param javaRDD          input records, one {@code List<Writable>} per record
     * @param transformProcess the process to apply
     * @return the transformed records
     * @throws IllegalStateException if the process's final schema is a sequence schema, or if
     *                               the input column count does not match the initial schema
     */
    public static JavaRDD<List<Writable>> execute(JavaRDD<List<Writable>> javaRDD, TransformProcess transformProcess) {
        if (transformProcess.getFinalSchema() instanceof SequenceSchema) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }
        return execute(javaRDD, null, transformProcess).getFirst();
    }

    /**
     * Executes a {@link TransformProcess} on non-sequence input whose final schema is a sequence
     * schema (for example, one that contains a ConvertToSequence step).
     *
     * @param javaRDD          input records, one {@code List<Writable>} per record
     * @param transformProcess the process to apply
     * @return the resulting sequences
     * @throws IllegalStateException if the process's final schema is not a sequence schema, or if
     *                               the input column count does not match the initial schema
     */
    public static JavaRDD<List<List<Writable>>> executeToSequence(JavaRDD<List<Writable>> javaRDD, TransformProcess transformProcess) {
        if (!(transformProcess.getFinalSchema() instanceof SequenceSchema)) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }
        return execute(javaRDD, null, transformProcess).getSecond();
    }

    /**
     * Executes a {@link TransformProcess} on sequence input whose final schema is not a sequence
     * schema (for example, one that contains a ConvertFromSequence step).
     *
     * @param javaRDD          input sequences, one {@code List<List<Writable>>} per sequence
     * @param transformProcess the process to apply
     * @return the resulting non-sequence records
     * @throws IllegalStateException if the process's final schema is a sequence schema, or if
     *                               the input column count does not match the initial schema
     */
    public static JavaRDD<List<Writable>> executeSequenceToSeparate(JavaRDD<List<List<Writable>>> javaRDD, TransformProcess transformProcess) {
        if (transformProcess.getFinalSchema() instanceof SequenceSchema) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }
        return execute(null, javaRDD, transformProcess).getFirst();
    }

    /**
     * Executes a {@link TransformProcess} on sequence input whose final schema is also a sequence
     * schema.
     *
     * @param javaRDD          input sequences, one {@code List<List<Writable>>} per sequence
     * @param transformProcess the process to apply
     * @return the transformed sequences
     * @throws IllegalStateException if the process's final schema is not a sequence schema, or if
     *                               the input column count does not match the initial schema
     */
    public static JavaRDD<List<List<Writable>>> executeSequenceToSequence(JavaRDD<List<List<Writable>>> javaRDD, TransformProcess transformProcess) {
        if (!(transformProcess.getFinalSchema() instanceof SequenceSchema)) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }
        return execute(null, javaRDD, transformProcess).getSecond();
    }

    /**
     * Reports whether the {@link #LOG_ERROR_PROPERTY} system property is set to {@code "true"}.
     *
     * @return {@code Boolean.getBoolean(LOG_ERROR_PROPERTY)}
     */
    public static boolean isTryCatch() {
        return Boolean.getBoolean(LOG_ERROR_PROPERTY);
    }

    /**
     * Core execution loop. Exactly one of the two inputs is expected to be non-null. At each
     * step the data lives either in the "writables" RDD or in the "sequence" RDD. The other one
     * is null, and ConvertToSequence / ConvertFromSequence switch between the two.
     *
     * @param inputWritables   non-sequence input, or null when the input is sequence data
     * @param inputSequence    sequence input, or null when the input is non-sequence data
     * @param transformProcess the process whose actions are applied in order
     * @return a pair of (writables, sequence); the side not holding the final data is null
     * @throws IllegalStateException if an action is applied to the wrong data form, or the
     *                               input column count does not match the initial schema
     * @throws RuntimeException      if an action of unknown type is encountered
     */
    private static Pair<JavaRDD<List<Writable>>, JavaRDD<List<List<Writable>>>> execute(
            JavaRDD<List<Writable>> inputWritables,
            JavaRDD<List<List<Writable>>> inputSequence,
            TransformProcess transformProcess) {
        validateInputColumns(inputWritables, inputSequence, transformProcess);

        JavaRDD<List<Writable>> currentWritables = inputWritables;
        JavaRDD<List<List<Writable>>> currentSequence = inputSequence;
        // Loop-invariant: the flag is a JVM-wide system property, so read it once.
        boolean filterEmptyRecords = isTryCatch();

        for (DataAction dataAction : transformProcess.getActionList()) {
            if (dataAction.getTransform() != null) {
                Transform transform = dataAction.getTransform();
                if (currentWritables != null) {
                    currentWritables = currentWritables.map(new SparkTransformFunction(transform));
                    if (filterEmptyRecords) {
                        currentWritables = currentWritables.filter(new EmptyRecordFunction());
                    }
                } else {
                    currentSequence = currentSequence.map(new SparkSequenceTransformFunction(transform));
                    if (filterEmptyRecords) {
                        currentSequence = currentSequence.filter(new SequenceEmptyRecordFunction());
                    }
                }
            } else if (dataAction.getFilter() != null) {
                Filter filter = dataAction.getFilter();
                if (currentWritables != null) {
                    currentWritables = currentWritables.filter(new SparkFilterFunction(filter));
                } else {
                    currentSequence = currentSequence.filter(new SparkSequenceFilterFunction(filter));
                }
            } else if (dataAction.getConvertToSequence() != null) {
                ConvertToSequence convertToSequence = dataAction.getConvertToSequence();
                if (currentWritables == null) {
                    // Previously this dereferenced null and failed with a bare NullPointerException.
                    throw new IllegalStateException("Cannot execute ConvertToSequence operation: current writables are null (data is already a sequence?)");
                }
                int keyColumnIdx = convertToSequence.getInputSchema().getIndexOfColumn(convertToSequence.getKeyColumn());
                // Group records by the key column, then build one sequence per key.
                currentSequence = currentWritables
                        .mapToPair(new SparkMapToPairByColumnFunction(keyColumnIdx))
                        .groupByKey()
                        .map(new SparkGroupToSequenceFunction(convertToSequence.getComparator()));
                currentWritables = null;
            } else if (dataAction.getConvertFromSequence() != null) {
                if (currentSequence == null) {
                    throw new IllegalStateException("Cannot execute ConvertFromSequence operation: current sequence is null");
                }
                currentWritables = currentSequence.flatMap(new SequenceFlatMapFunction());
                currentSequence = null;
            } else if (dataAction.getSequenceSplit() != null) {
                SequenceSplit sequenceSplit = dataAction.getSequenceSplit();
                if (currentSequence == null) {
                    throw new IllegalStateException("Error during execution of SequenceSplit: currentSequence is null");
                }
                currentSequence = currentSequence.flatMap(new SequenceSplitFunction(sequenceSplit));
            } else if (dataAction.getReducer() != null) {
                IReducer reducer = dataAction.getReducer();
                if (currentWritables == null) {
                    throw new IllegalStateException("Error during execution of reduction: current writables are null. Trying to execute a reduce operation on a sequence?");
                }
                currentWritables = currentWritables
                        .mapToPair(new MapToPairForReducerFunction(reducer))
                        .groupByKey()
                        .map(new ReducerFunction(reducer));
            } else if (dataAction.getCalculateSortedRank() != null) {
                CalculateSortedRank calculateSortedRank = dataAction.getCalculateSortedRank();
                if (currentWritables == null) {
                    throw new IllegalStateException("Error during execution of CalculateSortedRank: current writables are null. Trying to execute a CalculateSortedRank operation on a sequence? (not currently supported)");
                }
                int sortColumnIdx = calculateSortedRank.getInputSchema().getIndexOfColumn(calculateSortedRank.getSortOnColumn());
                // Key by the sort column, sort, attach the global index, then unzip back into records.
                currentWritables = currentWritables
                        .mapToPair(new ColumnAsKeyPairFunction(sortColumnIdx))
                        .sortByKey(calculateSortedRank.getComparator(), calculateSortedRank.isAscending())
                        .zipWithIndex()
                        .map(new UnzipForCalculateSortedRankFunction());
            } else {
                throw new RuntimeException("Unknown/not implemented action: " + dataAction);
            }
        }
        return new Pair<>(currentWritables, currentSequence);
    }

    /**
     * Checks that the first record (or the first step of the first sequence) has as many
     * columns as the transform process's initial schema. Empty input is skipped. The original
     * code used {@code first()}, which throws on an empty RDD; {@code take(1)} does not.
     */
    private static void validateInputColumns(JavaRDD<List<Writable>> inputWritables,
                                             JavaRDD<List<List<Writable>>> inputSequence,
                                             TransformProcess transformProcess) {
        int expectedColumns = transformProcess.getInitialSchema().numColumns();
        if (inputWritables != null) {
            List<List<Writable>> head = inputWritables.take(1);
            if (!head.isEmpty() && head.get(0).size() != expectedColumns) {
                throw new IllegalStateException("Input data number of columns (" + head.get(0).size() + ") does not match the number of columns for the transform process (" + expectedColumns + ")");
            }
        } else {
            List<List<List<Writable>>> head = inputSequence.take(1);
            if (head.isEmpty()) {
                return;
            }
            List<List<Writable>> firstSequence = head.get(0);
            if (!firstSequence.isEmpty() && firstSequence.get(0).size() != expectedColumns) {
                throw new IllegalStateException("Input sequence data number of columns (" + firstSequence.get(0).size() + ") does not match the number of columns for the transform process (" + expectedColumns + ")");
            }
        }
    }

    /**
     * Joins two non-sequence RDDs according to {@code join}. Both sides are mapped to
     * (key, values) pairs with {@code MapToJoinValuesFunction}; the boolean argument is
     * {@code true} for the first RDD and {@code false} for the second (presumably left vs. right
     * side; confirm). The pairs are unioned and grouped by key, and then
     * {@code ExecuteJoinFlatMapFunction} emits the joined records.
     *
     * @param join    the join definition
     * @param javaRDD the first (left) input
     * @param javaRDD2 the second (right) input
     * @return the joined records
     */
    public static JavaRDD<List<Writable>> executeJoin(Join join, JavaRDD<List<Writable>> javaRDD, JavaRDD<List<Writable>> javaRDD2) {
        return javaRDD.mapToPair(new MapToJoinValuesFunction(true, join))
                .union(javaRDD2.mapToPair(new MapToJoinValuesFunction(false, join)))
                .groupByKey()
                .flatMap(new ExecuteJoinFlatMapFunction(join));
    }
}
