package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slices;
import io.trino.spi.Plugin;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.Row;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.Assignments;
import io.trino.type.UnknownType;
import java.util.Map;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPushCastIntoRow.class */
/**
 * Tests the projection-expression rewrite of {@link PushCastIntoRow}.
 *
 * <p>Each case plans a single projected expression and checks the expression produced by the
 * rewrite. The expectations below show that a cast of a {@link Row} constructor to an anonymous
 * row type is distributed into per-field casts. A cast to a row type with named fields is kept
 * around the rewritten {@code Row}, although casts beneath it are still pushed into the fields.
 */
public class TestPushCastIntoRow extends BaseRuleTest {
    // No explicit constructor: the implicit super() call resolves to BaseRuleTest's
    // variable-arity (Plugin...) constructor with no plugins, exactly as the former
    // `super(new Plugin[0])` did.

    @Test
    public void test() {
        // Shared leaf expressions; IR expressions are compared by value, so reuse is safe.
        Expression one = new Constant(IntegerType.INTEGER, 1L);
        Expression a = new Constant(VarcharType.VARCHAR, Slices.utf8Slice("a"));
        Expression nullUnknown = new Constant(UnknownType.UNKNOWN, null);

        // Target row type whose first field is named; such a cast is expected to stay in place.
        Type namedTarget = RowType.rowType(
                RowType.field("x", BigintType.BIGINT),
                RowType.field(VarcharType.VARCHAR));

        // CAST(ROW(1) AS ROW(BIGINT)) -> ROW(CAST(1 AS BIGINT))
        assertRewritten(
                new Cast(new Row(ImmutableList.of(one)), RowType.anonymousRow(BigintType.BIGINT)),
                new Row(ImmutableList.of(new Cast(one, BigintType.BIGINT))));

        // Multiple fields: each field receives its own cast, even an identity VARCHAR cast.
        assertRewritten(
                new Cast(
                        new Row(ImmutableList.of(one, a)),
                        RowType.anonymousRow(BigintType.BIGINT, VarcharType.VARCHAR)),
                new Row(ImmutableList.of(
                        new Cast(one, BigintType.BIGINT),
                        new Cast(a, VarcharType.VARCHAR))));

        // Nested anonymous-row casts are both pushed into the fields.
        assertRewritten(
                new Cast(
                        new Cast(
                                new Row(ImmutableList.of(one, a)),
                                RowType.anonymousRow(SmallintType.SMALLINT, VarcharType.VARCHAR)),
                        RowType.anonymousRow(BigintType.BIGINT, VarcharType.VARCHAR)),
                new Row(ImmutableList.of(
                        new Cast(new Cast(one, SmallintType.SMALLINT), BigintType.BIGINT),
                        new Cast(new Cast(a, VarcharType.VARCHAR), VarcharType.VARCHAR))));

        // Outer cast to a named row type is kept; the inner anonymous cast is pushed down.
        assertRewritten(
                new Cast(
                        new Cast(
                                new Row(ImmutableList.of(one, a)),
                                RowType.anonymousRow(SmallintType.SMALLINT, VarcharType.VARCHAR)),
                        namedTarget),
                new Cast(
                        new Row(ImmutableList.of(
                                new Cast(one, SmallintType.SMALLINT),
                                new Cast(a, VarcharType.VARCHAR))),
                        namedTarget));

        // Inner cast to a named row type is pushed down; the outer named cast is kept.
        assertRewritten(
                new Cast(
                        new Cast(
                                new Row(ImmutableList.of(one, a)),
                                RowType.rowType(
                                        RowType.field("a", SmallintType.SMALLINT),
                                        RowType.field("b", VarcharType.VARCHAR))),
                        namedTarget),
                new Cast(
                        new Row(ImmutableList.of(
                                new Cast(one, SmallintType.SMALLINT),
                                new Cast(a, VarcharType.VARCHAR))),
                        namedTarget));

        // The rewrite also applies beneath a field dereference.
        assertRewritten(
                new FieldReference(
                        new Cast(new Row(ImmutableList.of(one)), RowType.anonymousRow(BigintType.BIGINT)),
                        0),
                new FieldReference(
                        new Row(ImmutableList.of(new Cast(one, BigintType.BIGINT))),
                        0));

        // A NULL of UNKNOWN type cast to ROW(UNKNOWN) yields the bare row with no field cast.
        assertRewritten(
                new Cast(new Row(ImmutableList.of(nullUnknown)), RowType.anonymousRow(UnknownType.UNKNOWN)),
                new Row(ImmutableList.of(nullUnknown)));
    }

    /**
     * Plans {@code input} as the sole projection ({@code output}) over a column-less VALUES
     * source, applies {@link PushCastIntoRow#projectExpressionRewrite()}, and asserts that the
     * resulting plan projects {@code expected} over a VALUES node.
     *
     * @param input the expression to rewrite
     * @param expected the expression the rewritten projection must match
     */
    private void assertRewritten(Expression input, Expression expected) {
        tester().assertThat(new PushCastIntoRow().projectExpressionRewrite())
                .on(p -> p.project(
                        Assignments.of(p.symbol("output", input.type()), input),
                        p.values()))
                .matches(PlanMatchPattern.project(
                        Map.of("output", PlanMatchPattern.expression(expected)),
                        PlanMatchPattern.values()));
    }
}
