package com.facebook.presto.sql.planner;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.Optional;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/TestVariableExtractor.class */
/**
 * Checks that {@link VariablesExtractor} reports the same variables for a parsed SQL
 * {@link Expression} as it does for the {@link RowExpression} produced by translating
 * that expression.
 */
public class TestVariableExtractor {
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA);

    /** Type information for the variables {@code a}, {@code b} and {@code c}, each declared as BIGINT. */
    private static final TypeProvider SYMBOL_TYPES = TypeProvider.fromVariables(ImmutableList.of(
            bigintVariable("a"),
            bigintVariable("b"),
            bigintVariable("c")));

    /**
     * Runs the extraction comparison over a fixed list of SQL expressions.
     * Expressions are checked in order; the first failing one aborts the test.
     */
    @Test
    public void testSimple() {
        String[] expressions = {
                "a > b",
                "a + b > c",
                "sin(a) - b",
                "sin(a) + cos(a) - b",
                "sin(a) + cos(a) + a - b",
                "COALESCE(a, b, 1)",
                "a IN (a, b, c)",
                "transform(sequence(1, 5), a -> a + b)",
                "bigint '1'",
        };
        for (String sql : expressions) {
            assertVariables(sql);
        }
    }

    /**
     * Parses {@code sql}, translates it to a {@link RowExpression}, and asserts that
     * {@code extractUnique} and (after sorting) {@code extractAll} give equal results
     * for the two representations.
     *
     * @param sql the SQL expression text to parse
     */
    private static void assertVariables(String sql) {
        Expression parsed = new SqlParser().createExpression(sql, new ParsingOptions());
        Expression expression = ExpressionUtils.rewriteIdentifiersToSymbolReferences(parsed);
        RowExpression rowExpression = TRANSLATOR.translate(expression, SYMBOL_TYPES);

        // Unique variables are compared first; the extractAll calls only run if this passes.
        Assert.assertEquals(
                VariablesExtractor.extractUnique(expression, SYMBOL_TYPES),
                VariablesExtractor.extractUnique(rowExpression));

        // Both sides are sorted before comparison, so only the multiset of variables
        // has to match, not the order in which they were extracted.
        Collection<?> fromExpression = (Collection<?>) VariablesExtractor.extractAll(expression, SYMBOL_TYPES)
                .stream()
                .sorted()
                .collect(ImmutableList.toImmutableList());
        Collection<?> fromRowExpression = (Collection<?>) VariablesExtractor.extractAll(rowExpression)
                .stream()
                .sorted()
                .collect(ImmutableList.toImmutableList());
        Assert.assertEquals(fromExpression, fromRowExpression);
    }

    /** Creates a BIGINT variable reference named {@code name} with no source location. */
    private static VariableReferenceExpression bigintVariable(String name) {
        return new VariableReferenceExpression(Optional.empty(), name, BigintType.BIGINT);
    }
}
