package org.datavec.api.transform.transform.string;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Expands a single String column containing a delimited list of tokens into one categorical
 * {@code "true"}/{@code "false"} column per known category token.
 *
 * <p>Example: with {@code categoryTokens = [a, b, c]}, {@code newColumnNames = [hasA, hasB, hasC]}
 * and {@code delimiter = ","}, the input value {@code "a,c"} becomes the three values
 * {@code "true", "false", "true"}. An empty input value yields all {@code "false"}. A token not
 * present in {@code categoryTokens} causes an {@link IllegalStateException}.
 *
 * <p>Note: the delimiter is passed to {@link String#split(String)}, so it is interpreted as a
 * regular expression; metacharacters such as {@code "|"} or {@code "."} must be escaped by the
 * caller.
 */
@JsonIgnoreProperties({"inputSchema", "map", "columnIdx"})
public class StringListToCategoricalSetTransform extends BaseTransform {
    private final String columnName;
    private final List<String> newColumnNames;
    private final List<String> categoryTokens;
    private final String delimiter;
    /** Derived lookup: category token -> position in {@link #categoryTokens} (and output column). */
    private final Map<String, Integer> map;
    /** Index of {@link #columnName} in the input schema; -1 until {@link #setInputSchema} is called. */
    private int columnIdx = -1;

    /**
     * @param columnName     name of the String column to expand
     * @param newColumnNames names of the output columns, one per category token, in the same order
     * @param categoryTokens recognised tokens; token {@code i} sets output column {@code i}
     * @param delimiter      separator between tokens in the input value (treated as a regex)
     * @throws IllegalArgumentException if {@code newColumnNames} and {@code categoryTokens} differ in size
     */
    public StringListToCategoricalSetTransform(@JsonProperty("columnName") String columnName,
                    @JsonProperty("newColumnNames") List<String> newColumnNames,
                    @JsonProperty("categoryTokens") List<String> categoryTokens,
                    @JsonProperty("delimiter") String delimiter) {
        if (newColumnNames.size() != categoryTokens.size()) {
            throw new IllegalArgumentException("Names/tokens sizes cannot differ");
        }
        this.columnName = columnName;
        this.newColumnNames = newColumnNames;
        this.categoryTokens = categoryTokens;
        this.delimiter = delimiter;
        this.map = new HashMap<>();
        // NOTE(review): duplicate tokens are not rejected; the later index silently wins, leaving
        // the earlier output column permanently "false".
        for (int i = 0; i < categoryTokens.size(); i++) {
            this.map.put(categoryTokens.get(i), i);
        }
    }

    /**
     * Replaces the source column with one categorical ("true"/"false") column per new column name,
     * keeping every other column in place.
     *
     * @throws IllegalStateException if the source column is not of type {@link ColumnType#String}
     */
    @Override
    public Schema transform(Schema schema) {
        int sourceIdx = schema.getIndexOfColumn(this.columnName);
        List<ColumnMetaData> oldMeta = schema.getColumnMetaData();
        List<ColumnMetaData> newMeta = new ArrayList<>(oldMeta.size() + this.newColumnNames.size() - 1);

        int i = 0;
        for (ColumnMetaData meta : oldMeta) {
            if (i++ != sourceIdx) {
                newMeta.add(meta);
                continue;
            }
            if (meta.getColumnType() != ColumnType.String) {
                throw new IllegalStateException("Cannot convert non-string type");
            }
            for (String newName : this.newColumnNames) {
                newMeta.add(new CategoricalMetaData(newName, "true", "false"));
            }
        }
        return schema.newSchema(newMeta);
    }

    /** Stores the input schema and caches the index of the source column. */
    @Override
    public void setInputSchema(Schema schema) {
        this.inputSchema = schema;
        this.columnIdx = schema.getIndexOfColumn(this.columnName);
    }

    @Override
    public String toString() {
        return "StringListToCategoricalSetTransform(columnName=" + this.columnName + ",newColumnNames="
                        + this.newColumnNames + ",categoryTokens=" + this.categoryTokens + ",delimiter=\""
                        + this.delimiter + "\")";
    }

    /**
     * Expands the source column of one record into per-token "true"/"false" {@link Text} values.
     * Requires {@link #setInputSchema} to have been called first.
     *
     * @throws IllegalStateException if the record length does not match the input schema, or if the
     *                               source value contains a token not in {@code categoryTokens}
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        if (writables.size() != this.inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length ("
                            + writables.size() + ") does not match expected number of elements (schema: "
                            + this.inputSchema.numColumns() + "). Transform = " + toString());
        }
        List<Writable> out = new ArrayList<>(writables.size());
        int i = 0;
        for (Writable w : writables) {
            if (i++ != this.columnIdx) {
                out.add(w);
                continue;
            }
            String value = w.toString();
            boolean[] present = new boolean[this.categoryTokens.size()];
            if (value != null && !value.isEmpty()) {
                for (String token : value.split(this.delimiter)) {
                    Integer pos = this.map.get(token);
                    if (pos == null) {
                        throw new IllegalStateException("Encountered unknown String: \"" + token + "\"");
                    }
                    present[pos] = true;
                }
            }
            for (boolean flag : present) {
                out.add(new Text(flag ? "true" : "false"));
            }
        }
        return out;
    }

    /** Single-value mapping is not implemented for this transform; always returns {@code null}. */
    @Override
    public Object map(Object obj) {
        return null;
    }

    /** Sequence mapping is not implemented for this transform; always returns {@code null}. */
    @Override
    public Object mapSequence(Object obj) {
        return null;
    }

    /** Always throws: this transform produces more than one output column. */
    @Override
    public String outputColumnName() {
        throw new UnsupportedOperationException("New column names is always more than 1 in length");
    }

    @Override
    public String[] outputColumnNames() {
        return this.newColumnNames.toArray(new String[0]);
    }

    @Override
    public String[] columnNames() {
        return new String[] {this.columnName};
    }

    /** Returns the name of the source column (previously recursed into itself forever). */
    @Override
    public String columnName() {
        return this.columnName;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof StringListToCategoricalSetTransform)) {
            return false;
        }
        StringListToCategoricalSetTransform other = (StringListToCategoricalSetTransform) obj;
        return other.canEqual(this)
                        && Objects.equals(this.columnName, other.columnName)
                        && Objects.equals(this.newColumnNames, other.newColumnNames)
                        && Objects.equals(this.categoryTokens, other.categoryTokens)
                        && Objects.equals(this.delimiter, other.delimiter)
                        && Objects.equals(this.map, other.map);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof StringListToCategoricalSetTransform;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.columnName, this.newColumnNames, this.categoryTokens, this.delimiter,
                        this.map);
    }
}
