package io.seata.rm.datasource.exec;

import io.seata.common.exception.NotSupportYetException;
import io.seata.common.exception.ShouldNeverHappenException;
import io.seata.common.util.StringUtils;
import io.seata.rm.datasource.PreparedStatementProxy;
import io.seata.rm.datasource.StatementProxy;
import io.seata.rm.datasource.sql.SQLInsertRecognizer;
import io.seata.rm.datasource.sql.SQLRecognizer;
import io.seata.rm.datasource.sql.struct.ColumnMeta;
import io.seata.rm.datasource.sql.struct.Null;
import io.seata.rm.datasource.sql.struct.SqlMethodExpr;
import io.seata.rm.datasource.sql.struct.SqlSequenceExpr;
import io.seata.rm.datasource.sql.struct.TableRecords;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/seata/rm/datasource/exec/InsertExecutor.class */
/**
 * Executor for SQL INSERT statements.
 *
 * <p>The before-image of an insert is always empty. The after-image is built from the primary-key
 * values of the inserted rows. Those values come from the statement text, from the bound
 * prepared-statement parameters, from the driver's generated keys, or from a sequence's
 * {@code currval}.
 *
 * @param <T> result type of the wrapped statement execution
 * @param <S> concrete JDBC statement type being proxied
 */
public class InsertExecutor<T, S extends Statement> extends AbstractDMLBaseExecutor<T, S> {
    private static final Logger LOGGER = LoggerFactory.getLogger(InsertExecutor.class);

    /**
     * SQLState that, when raised by {@code getGeneratedKeys()}, makes {@link #defaultByAuto()} fall
     * back to {@code SELECT LAST_INSERT_ID()}. Presumably this is the code the driver reports when
     * the statement was not prepared to return generated keys (TODO confirm per driver).
     */
    protected static final String ERR_SQL_STATE = "S1009";

    /** Value that marks a JDBC {@code ?} placeholder inside the rows returned by {@code getInsertRows()}. */
    private static final String PLACEHOLDER = "?";

    /**
     * Creates an insert executor.
     *
     * @param statementProxy    proxy of the statement being executed
     * @param statementCallback callback that performs the actual execution
     * @param sqlRecognizer     recognizer for the insert SQL; expected to be a {@link SQLInsertRecognizer}
     */
    public InsertExecutor(StatementProxy statementProxy, StatementCallback statementCallback, SQLRecognizer sqlRecognizer) {
        super(statementProxy, statementCallback, sqlRecognizer);
    }

    /**
     * An INSERT touches no pre-existing rows, so the before-image is always empty.
     *
     * @return an empty record set bound to the target table's metadata
     */
    @Override
    protected TableRecords beforeImage() throws SQLException {
        return TableRecords.empty(getTableMeta());
    }

    /**
     * Builds the after-image from the primary-key values of the inserted rows.
     *
     * <p>The PK values are obtained as follows:
     * <ul>
     *   <li>If the statement lists the PK column, they are read from the statement.</li>
     *   <li>If the statement lists columns but not the PK, they are fetched as generated keys.</li>
     *   <li>If the statement lists no columns, they are read positionally from the statement.</li>
     * </ul>
     *
     * @param beforeImage unused for inserts
     * @return the records identified by the resolved PK values
     * @throws SQLException if the PK values cannot be resolved or the after-image is null
     */
    @Override
    protected TableRecords afterImage(TableRecords beforeImage) throws SQLException {
        final List<Object> pkValues;
        if (containsPK()) {
            pkValues = getPkValuesByColumn();
        } else if (containsColumns()) {
            pkValues = getPkValuesByAuto();
        } else {
            pkValues = getPkValuesByColumn();
        }
        TableRecords afterImage = buildTableRecords(pkValues);
        if (afterImage == null) {
            throw new SQLException("Failed to build after-image for insert");
        }
        return afterImage;
    }

    /**
     * @return whether the explicit insert column list contains the table's primary key
     */
    protected boolean containsPK() {
        return getTableMeta().containsPK(((SQLInsertRecognizer) this.sqlRecognizer).getInsertColumns());
    }

    /**
     * @return {@code true} when the INSERT names its columns explicitly (non-null, non-empty list)
     */
    protected boolean containsColumns() {
        List<String> insertColumns = ((SQLInsertRecognizer) this.sqlRecognizer).getInsertColumns();
        return insertColumns != null && !insertColumns.isEmpty();
    }

    /**
     * Resolves the primary-key values from the PK column of each insert row.
     *
     * <p>For a {@link PreparedStatementProxy}, a {@code ?} in the PK position is replaced by the values
     * bound to that placeholder. Literal values are used as-is. If the resolved values turn out to be a
     * sequence expression, a method expression (single row) or {@link Null}, the real keys are fetched
     * from the database instead.
     *
     * @return the primary-key values of the inserted rows
     * @throws ShouldNeverHappenException if no insert rows could be read
     * @throws NotSupportYetException     if the PK values mix kinds that cannot be resolved together
     * @throws SQLException               if fetching generated keys or sequence values fails
     */
    protected List<Object> getPkValuesByColumn() throws SQLException {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer) this.sqlRecognizer;
        final int pkIndex = getPkIndex();
        List<List<Object>> insertRows = recognizer.getInsertRows();
        List<Object> pkValues = null;
        if (this.statementProxy instanceof PreparedStatementProxy) {
            if (insertRows != null && !insertRows.isEmpty()) {
                ArrayList<Object>[] parameters = ((PreparedStatementProxy) this.statementProxy).getParameters();
                pkValues = resolvePreparedPkValues(insertRows, parameters, pkIndex);
            }
        } else if (insertRows != null) {
            // Plain statement: every value is a literal taken straight from the SQL text.
            pkValues = new ArrayList<>(insertRows.size());
            for (List<Object> row : insertRows) {
                pkValues.add(row.get(pkIndex));
            }
        }
        if (pkValues == null) {
            throw new ShouldNeverHappenException();
        }
        if (!checkPkValues(pkValues)) {
            throw new NotSupportYetException("not support sql [" + this.sqlRecognizer.getOriginalSQL() + "]");
        }
        if (!pkValues.isEmpty()) {
            Object first = pkValues.get(0);
            if (first instanceof SqlSequenceExpr) {
                return getPkValuesBySequence(first);
            }
            if (pkValues.size() == 1 && first instanceof SqlMethodExpr) {
                return getPkValuesByAuto();
            }
            if (first instanceof Null) {
                return getPkValuesByAuto();
            }
        }
        return pkValues;
    }

    /**
     * Collects the PK value of every row of a prepared INSERT, following placeholders into the bound
     * parameters.
     *
     * <p>The parameter index of a {@code ?} in the PK column is the number of placeholders that occur
     * before it. That count covers all earlier rows plus the earlier columns of the same row. It is
     * not the PK's column position, because literal columns do not consume parameters.
     * Assumes {@code parameters[i]} holds the value(s) bound to the (i+1)-th placeholder; TODO confirm
     * against {@code PreparedStatementProxy}.
     *
     * @param insertRows parsed VALUES rows; {@link #PLACEHOLDER} marks a {@code ?}
     * @param parameters bound parameter values indexed by placeholder ordinal
     * @param pkIndex    position of the PK column within each row
     * @return a new list with the PK values in row order
     */
    private static List<Object> resolvePreparedPkValues(List<List<Object>> insertRows,
                                                        ArrayList<Object>[] parameters, int pkIndex) {
        List<Object> pkValues = new ArrayList<>(insertRows.size());
        // Placeholders consumed by the rows preceding the current one.
        int placeholdersBeforeRow = 0;
        for (List<Object> row : insertRows) {
            int placeholdersBeforePk = 0;
            int placeholdersInRow = 0;
            for (int col = 0; col < row.size(); col++) {
                if (PLACEHOLDER.equals(row.get(col))) {
                    if (col < pkIndex) {
                        placeholdersBeforePk++;
                    }
                    placeholdersInRow++;
                }
            }
            Object pkValue = row.get(pkIndex);
            if (PLACEHOLDER.equals(pkValue)) {
                pkValues.addAll(parameters[placeholdersBeforeRow + placeholdersBeforePk]);
            } else {
                pkValues.add(pkValue);
            }
            placeholdersBeforeRow += placeholdersInRow;
        }
        return pkValues;
    }

    /**
     * Resolves PK values for an insert whose PK is a sequence expression.
     *
     * <p>It first tries the driver's generated keys. If that fails, it queries
     * {@code SELECT <sequence>.currval FROM DUAL} on a separate, properly closed statement.
     *
     * @param expr the PK expression; the fallback requires it to be a {@link SqlSequenceExpr}
     * @return the resolved PK values
     * @throws NotSupportYetException if generated keys are unavailable and {@code expr} is not a sequence
     * @throws SQLException           if the fallback query fails
     */
    protected List<Object> getPkValuesBySequence(Object expr) throws SQLException {
        try {
            return oracleByAuto();
        } catch (NotSupportYetException | SQLException e) {
            if (!(expr instanceof SqlSequenceExpr)) {
                throw new NotSupportYetException(String.format("not support expr [%s]", expr.getClass().getName()));
            }
            String sql = "SELECT " + ((SqlSequenceExpr) expr).getSequence() + ".currval FROM DUAL";
            LOGGER.warn("Fail to get auto-generated keys, use '{}' instead. Be cautious, statement could be polluted. Recommend you set the statement to return generated keys.", sql);
            List<Object> pkValues = new ArrayList<>();
            // Close the helper statement (and with it the result set) so it does not leak.
            try (Statement stmt = this.statementProxy.getConnection().createStatement();
                 ResultSet rs = stmt.executeQuery(sql)) {
                while (rs.next()) {
                    pkValues.add(rs.getObject(1));
                }
            }
            return pkValues;
        }
    }

    /**
     * Fetches auto-generated PK values. The Oracle-specific strategy is used when the database type
     * is {@code oracle} (case-insensitive); otherwise the default strategy is used.
     *
     * @return the generated PK values
     * @throws SQLException if the keys cannot be read
     */
    protected List<Object> getPkValuesByAuto() throws SQLException {
        return StringUtils.equalsIgnoreCase("oracle", getDbType()) ? oracleByAuto() : defaultByAuto();
    }

    /**
     * Locates the primary-key column within an insert row.
     *
     * <p>If the INSERT names its columns, this is the PK's position in that list. Otherwise it is the
     * PK's position in the iteration order of the table's column map. That order is assumed to match
     * the table definition; TODO confirm.
     *
     * @return the zero-based PK position, or {@code -1} if the PK column is not found
     */
    protected int getPkIndex() {
        SQLInsertRecognizer recognizer = (SQLInsertRecognizer) this.sqlRecognizer;
        String pkName = getTableMeta().getPkName();
        List<String> insertColumns = recognizer.getInsertColumns();
        if (insertColumns != null && !insertColumns.isEmpty()) {
            for (int i = 0; i < insertColumns.size(); i++) {
                if (insertColumns.get(i).equalsIgnoreCase(pkName)) {
                    return i;
                }
            }
            return -1;
        }
        int index = 0;
        for (ColumnMeta column : getTableMeta().getAllColumns().values()) {
            if (column.getColumnName().equalsIgnoreCase(pkName)) {
                return index;
            }
            index++;
        }
        // Previously this returned the last column index, silently picking a wrong column.
        return -1;
    }

    /**
     * Checks whether a set of PK values can be resolved together.
     *
     * <p>A single value is always accepted. With several values, any {@link SqlMethodExpr} is
     * rejected, as is a mix of {@link Null} and non-{@code Null} values.
     *
     * @param pkValues the candidate PK values
     * @return {@code true} if the values are supported
     */
    private boolean checkPkValues(List<Object> pkValues) {
        if (pkValues.size() == 1) {
            return true;
        }
        boolean hasNull = false;
        boolean hasNonNull = false;
        for (Object value : pkValues) {
            if (value instanceof Null) {
                hasNull = true;
            } else {
                if (value instanceof SqlMethodExpr) {
                    return false;
                }
                hasNonNull = true;
            }
        }
        return !(hasNull && hasNonNull);
    }

    /**
     * Reads auto-increment PK values from the driver's generated keys.
     *
     * <p>If the driver rejects {@code getGeneratedKeys()} with {@link #ERR_SQL_STATE}, it falls back
     * to {@code SELECT LAST_INSERT_ID()} on the target statement. Afterwards the result-set cursor is
     * rewound so the keys can be read again.
     *
     * @return the generated PK values
     * @throws NotSupportYetException     if the table does not have exactly one PK column
     * @throws ShouldNeverHappenException if that PK column is not auto-increment
     * @throws SQLException               if reading the keys fails for any other reason
     */
    private List<Object> defaultByAuto() throws SQLException {
        Map<String, ColumnMeta> primaryKeyMap = getTableMeta().getPrimaryKeyMap();
        if (primaryKeyMap.size() != 1) {
            throw new NotSupportYetException();
        }
        if (!primaryKeyMap.values().iterator().next().isAutoincrement()) {
            throw new ShouldNeverHappenException();
        }
        ResultSet generatedKeys;
        try {
            generatedKeys = this.statementProxy.getTargetStatement().getGeneratedKeys();
        } catch (SQLException e) {
            if (!ERR_SQL_STATE.equalsIgnoreCase(e.getSQLState())) {
                throw e;
            }
            LOGGER.warn("Fail to get auto-generated keys, use 'SELECT LAST_INSERT_ID()' instead. Be cautious, statement could be polluted. Recommend you set the statement to return generated keys.");
            generatedKeys = this.statementProxy.getTargetStatement().executeQuery("SELECT LAST_INSERT_ID()");
        }
        List<Object> pkValues = new ArrayList<>();
        while (generatedKeys.next()) {
            pkValues.add(generatedKeys.getObject(1));
        }
        // Deliberately left open and rewound so the keys can be read again later.
        try {
            generatedKeys.beforeFirst();
        } catch (SQLException e) {
            LOGGER.warn("Fail to reset ResultSet cursor. can not get primary key value", e);
        }
        return pkValues;
    }

    /**
     * Reads PK values from the driver's generated keys (the strategy used for Oracle).
     *
     * @return the generated PK values (never empty)
     * @throws NotSupportYetException if the table does not have exactly one PK column, or if no
     *                                generated keys were returned
     * @throws SQLException           if reading the generated keys fails
     */
    private List<Object> oracleByAuto() throws SQLException {
        if (getTableMeta().getPrimaryKeyMap().size() != 1) {
            throw new NotSupportYetException();
        }
        ResultSet generatedKeys = this.statementProxy.getTargetStatement().getGeneratedKeys();
        List<Object> pkValues = new ArrayList<>();
        while (generatedKeys.next()) {
            pkValues.add(generatedKeys.getObject(1));
        }
        if (pkValues.isEmpty()) {
            throw new NotSupportYetException("not support sql [" + this.sqlRecognizer.getOriginalSQL() + "]");
        }
        return pkValues;
    }
}
