package org.datavec.api.transform.transform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;

/**
 * Base class for transforms that derive one new column from several existing columns using a
 * {@link MathOp}. The new column is appended after all existing columns. Its per-record value is
 * produced by the subclass via {@link #doOp(Writable...)}, and its metadata via
 * {@link #derivedColumnMetaData(String)}.
 *
 * <p>The constructor enforces the column count required by each op:
 * <ul>
 *   <li>{@code Add} and {@code Multiply} need at least two columns.</li>
 *   <li>{@code Subtract}, {@code Divide} and {@code Modulus} need exactly two.</li>
 *   <li>{@code ReverseSubtract}, {@code ReverseDivide}, {@code ScalarMin} and
 *       {@code ScalarMax} are rejected.</li>
 * </ul>
 *
 * <p>{@link #setInputSchema(Schema)} must be called before {@link #map(List)} or
 * {@link #mapSequence(List)}. The resolved column indices and the input schema are excluded from
 * JSON serialization (see {@code @JsonIgnoreProperties}).
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties({"columnIdxs", "inputSchema"})
public abstract class BaseColumnsMathOpTransform implements Transform {
    /** Name of the column appended by this transform. */
    protected final String newColumnName;
    /** Operation applied across {@link #columns}. */
    protected final MathOp mathOp;
    /** Input column names, in the order their values are passed to {@link #doOp(Writable...)}. */
    protected final String[] columns;
    /** Positions of {@link #columns} in the input schema; populated by setInputSchema. */
    private int[] columnIdxs;
    private Schema inputSchema;

    /**
     * Creates the transform and validates the column count for {@code mathOp}.
     *
     * @param newColumnName name of the column to append
     * @param mathOp        operation to apply; must not be null
     * @param columns       input column names; the array is copied
     * @throws IllegalArgumentException if no columns are given, the column count does not match
     *                                  the op, or the op is not usable with this transform
     * @throws NullPointerException     if {@code mathOp} is null
     * @throws RuntimeException         if {@code mathOp} is a constant not handled here
     */
    public BaseColumnsMathOpTransform(String newColumnName, MathOp mathOp, String... columns) {
        if (columns == null || columns.length == 0) {
            throw new IllegalArgumentException("Invalid input: cannot have null/0 columns");
        }
        if (mathOp == null) {
            // The switch below would NPE anyway; fail with a descriptive message instead.
            throw new NullPointerException("mathOp must not be null");
        }
        validateArity(mathOp, columns);
        this.newColumnName = newColumnName;
        this.mathOp = mathOp;
        // Defensive copy: the (varargs) array stays owned by the caller and could be mutated later.
        this.columns = columns.clone();
    }

    /** Throws if {@code columns} does not have the number of entries {@code mathOp} requires. */
    private static void validateArity(MathOp mathOp, String[] columns) {
        switch (mathOp) {
            case Add:
                if (columns.length < 2) {
                    throw new IllegalArgumentException("Need 2 or more columns for Add op. Got: " + Arrays.toString(columns));
                }
                break;
            case Subtract:
                if (columns.length != 2) {
                    throw new IllegalArgumentException("Need exactly 2 columns for Subtract op. Got: " + Arrays.toString(columns));
                }
                break;
            case Multiply:
                if (columns.length < 2) {
                    throw new IllegalArgumentException("Need 2 or more columns for Multiply op. Got: " + Arrays.toString(columns));
                }
                break;
            case Divide:
                if (columns.length != 2) {
                    throw new IllegalArgumentException("Need exactly 2 columns for Divide op. Got: " + Arrays.toString(columns));
                }
                break;
            case Modulus:
                if (columns.length != 2) {
                    throw new IllegalArgumentException("Need exactly 2 columns for Modulus op. Got: " + Arrays.toString(columns));
                }
                break;
            case ReverseSubtract:
            case ReverseDivide:
            case ScalarMin:
            case ScalarMax:
                throw new IllegalArgumentException("Invalid MathOp: cannot use " + mathOp + " with ...ColumnsMathOpTransform");
            default:
                throw new RuntimeException("Unknown MathOp: " + mathOp);
        }
    }

    /**
     * Returns a new schema with every column of {@code schema}, followed by the derived column.
     *
     * @throws IllegalStateException if any input column is missing from {@code schema}
     */
    @Override
    public Schema transform(Schema schema) {
        for (String name : columns) {
            if (!schema.hasColumn(name)) {
                throw new IllegalStateException("Input schema does not have column with name \"" + name + "\"");
            }
        }
        List<ColumnMetaData> newMeta = new ArrayList<>(schema.getColumnMetaData());
        newMeta.add(derivedColumnMetaData(newColumnName));
        return schema.newSchema(newMeta);
    }

    /**
     * Resolves the input column positions against {@code schema} and records the schema.
     *
     * <p>The indices are resolved into a local array and published only after every column has
     * been found. A failed call therefore leaves the previously configured schema and indices
     * untouched. It can no longer leave the old schema paired with a half-filled index array.
     *
     * @throws IllegalStateException if any input column is missing from {@code schema}
     */
    @Override
    public void setInputSchema(Schema schema) {
        int[] idxs = new int[columns.length];
        for (int i = 0; i < columns.length; i++) {
            String name = columns[i];
            if (!schema.hasColumn(name)) {
                throw new IllegalStateException("Input schema does not have column with name \"" + name + "\"");
            }
            idxs[i] = schema.getIndexOfColumn(name);
        }
        this.columnIdxs = idxs;
        this.inputSchema = schema;
    }

    /** Returns the schema last passed to {@link #setInputSchema(Schema)}, or null if none. */
    @Override
    public Schema getInputSchema() {
        return inputSchema;
    }

    /**
     * Returns a copy of {@code writables} with the result of {@link #doOp(Writable...)} appended.
     * The input list is not modified.
     *
     * @throws IllegalStateException if {@link #setInputSchema(Schema)} has not been called
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        if (inputSchema == null) {
            throw new IllegalStateException("Input schema has not been set");
        }
        List<Writable> out = new ArrayList<>(writables);
        Writable[] operands = new Writable[columns.length];
        for (int i = 0; i < columnIdxs.length; i++) {
            operands[i] = out.get(columnIdxs[i]);
        }
        out.add(doOp(operands));
        return out;
    }

    /** Applies {@link #map(List)} to every step of {@code sequence} and returns the new steps. */
    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        List<List<Writable>> out = new ArrayList<>(sequence.size());
        for (List<Writable> step : sequence) {
            out.add(map(step));
        }
        return out;
    }

    /**
     * Returns the metadata for the derived column.
     *
     * @param newColumnName name of the derived column
     */
    protected abstract ColumnMetaData derivedColumnMetaData(String newColumnName);

    /**
     * Computes the derived value.
     *
     * @param input values of {@link #columns} for one record, in the same order
     */
    protected abstract Writable doOp(Writable... input);

    @Override
    public abstract String toString();

    /**
     * Two transforms are equal when their new column name, op and input column names match and
     * {@link #canEqual(Object)} accepts the comparison.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BaseColumnsMathOpTransform)) {
            return false;
        }
        BaseColumnsMathOpTransform other = (BaseColumnsMathOpTransform) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        boolean sameName = newColumnName == null
                ? other.newColumnName == null
                : newColumnName.equals(other.newColumnName);
        // Enum constants are singletons, so identity comparison is a null-safe equals.
        return sameName && mathOp == other.mathOp && Arrays.equals(columns, other.columns);
    }

    /** Hook letting subclasses restrict equality to compatible types. */
    protected boolean canEqual(Object obj) {
        return obj instanceof BaseColumnsMathOpTransform;
    }

    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + (newColumnName == null ? 43 : newColumnName.hashCode());
        result = result * 59 + (mathOp == null ? 43 : mathOp.hashCode());
        result = result * 59 + Arrays.hashCode(columns);
        return result;
    }
}
