Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -452,12 +452,20 @@ public String getSchema() {

@Override
public Map<String, String> getSessionConfigs() {
return this.parameters.entrySet().stream()
.filter(
e ->
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keySet().stream()
.anyMatch(allowedConf -> allowedConf.toLowerCase().equals(e.getKey())))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
Map<String, String> sessionConfigs =
this.parameters.entrySet().stream()
.filter(
e ->
ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keySet().stream()
.anyMatch(allowedConf -> allowedConf.toLowerCase().equals(e.getKey())))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

// Add metric view metadata configuration if enabled
if (getEnableMetricViewMetadata()) {
sessionConfigs.put("spark.sql.thriftserver.metadata.metricview.enabled", "true");
}

return sessionConfigs;
}

@Override
Expand Down Expand Up @@ -979,6 +987,11 @@ public boolean enableShowCommandsForGetFunctions() {
return getParameter(DatabricksJdbcUrlParams.ENABLE_SHOW_COMMAND_FOR_GET_FUNCTIONS).equals("1");
}

@Override
public boolean getEnableMetricViewMetadata() {
return getParameter(DatabricksJdbcUrlParams.ENABLE_METRIC_VIEW_METADATA).equals("1");
}

private static boolean nullOrEmptyString(String s) {
return s == null || s.isEmpty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,4 +371,7 @@ public interface IDatabricksConnectionContext {

/** Returns whether transaction-related method calls should be ignored */
boolean getIgnoreTransactions();

/* Returns whether metric view metadata is enabled */
boolean getEnableMetricViewMetadata();
}
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ public enum DatabricksJdbcUrlParams {
"EnableSQLValidationForIsValid",
"Enable SQL query execution for connection validation in isValid() method",
"0"),
IGNORE_TRANSACTIONS("IgnoreTransactions", "Ignore transaction-related method calls", "0");
IGNORE_TRANSACTIONS("IgnoreTransactions", "Ignore transaction-related method calls", "0"),
ENABLE_METRIC_VIEW_METADATA("EnableMetricViewMetadata", "Enable metric view metadata", "0");

private final String paramName;
private final String defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ public class MetadataResultConstants {
PRIMARY_KEY_NAME_COLUMN);

public static List<List<Object>> TABLE_TYPES_ROWS =
Arrays.asList(List.of("SYSTEM TABLE"), List.of("TABLE"), List.of("VIEW"));
Arrays.asList(
List.of("SYSTEM TABLE"), List.of("TABLE"), List.of("VIEW"), List.of("METRIC_VIEW"));

public static List<ResultColumn> TABLE_TYPE_COLUMNS = List.of(TABLE_TYPE_COLUMN);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public class DatabricksTypeUtil {
public static final String VARIANT = "VARIANT";
public static final String CHAR = "CHAR";
public static final String INTERVAL = "INTERVAL";
public static final String MEASURE = "measure";
private static final ArrayList<ColumnInfoTypeName> SIGNED_TYPES =
new ArrayList<>(
Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.databricks.jdbc.common.MetadataResultConstants.*;
import static com.databricks.jdbc.common.util.DatabricksTypeUtil.INTERVAL;
import static com.databricks.jdbc.common.util.DatabricksTypeUtil.MEASURE;
import static com.databricks.jdbc.common.util.WildcardUtil.isNullOrEmpty;
import static com.databricks.jdbc.dbclient.impl.common.CommandConstants.*;
import static com.databricks.jdbc.dbclient.impl.common.TypeValConstants.*;
Expand Down Expand Up @@ -96,9 +97,16 @@ public DatabricksResultSet getTablesResult(DatabricksResultSet resultSet, String
}

public DatabricksResultSet getTableTypesResult() {
List<List<Object>> tableTypesRows =
ctx.getEnableMetricViewMetadata()
? TABLE_TYPES_ROWS
: TABLE_TYPES_ROWS.stream()
.filter(row -> !"METRIC_VIEW".equals(row.get(0)))
.collect(Collectors.toList());

return buildResultSet(
TABLE_TYPE_COLUMNS,
TABLE_TYPES_ROWS,
tableTypesRows,
GET_TABLE_TYPE_STATEMENT_ID,
CommandName.LIST_TABLE_TYPES);
}
Expand Down Expand Up @@ -267,7 +275,8 @@ List<List<Object>> getRows(
// Handle TYPE_NAME separately for potential modifications
if (mappedColumn.getColumnName().equals(COLUMN_TYPE_COLUMN.getColumnName())) {
if (typeVal != null
&& (typeVal.contains(ARRAY_TYPE)
&& (typeVal.contains(MEASURE)
|| typeVal.contains(ARRAY_TYPE)
|| typeVal.contains(MAP_TYPE)
|| typeVal.contains(
STRUCT_TYPE))) { // for complex data types, do not strip type name
Expand Down Expand Up @@ -965,7 +974,8 @@ List<List<Object>> getThriftRows(List<List<Object>> rows, List<ResultColumn> col
// Handle TYPE_NAME separately for potential modifications
if (column.getColumnName().equals(COLUMN_TYPE_COLUMN.getColumnName())) {
if (typeVal != null
&& (typeVal.contains(ARRAY_TYPE)
&& (typeVal.contains(MEASURE)
|| typeVal.contains(ARRAY_TYPE)
|| typeVal.contains(MAP_TYPE)
|| typeVal.contains(STRUCT_TYPE))) {
object = typeVal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,4 +564,32 @@ public void testIgnoreTransactionsEnabled() throws Exception {

connection.close();
}

@Test
public void testMetricViewMetadataInSessionConfigs() throws SQLException {
String metricViewEnabledUrl = JDBC_URL + ";EnableMetricViewMetadata=1";
IDatabricksConnectionContext connectionContextEnabled =
DatabricksConnectionContext.parse(metricViewEnabledUrl, new Properties());

Map<String, String> sessionConfigsEnabled = connectionContextEnabled.getSessionConfigs();
assertTrue(
sessionConfigsEnabled.containsKey("spark.sql.thriftserver.metadata.metricview.enabled"));
assertEquals(
"true", sessionConfigsEnabled.get("spark.sql.thriftserver.metadata.metricview.enabled"));

String metricViewDisabledUrl = JDBC_URL + ";EnableMetricViewMetadata=0";
IDatabricksConnectionContext connectionContextDisabled =
DatabricksConnectionContext.parse(metricViewDisabledUrl, new Properties());

Map<String, String> sessionConfigsDisabled = connectionContextDisabled.getSessionConfigs();
assertFalse(
sessionConfigsDisabled.containsKey("spark.sql.thriftserver.metadata.metricview.enabled"));

IDatabricksConnectionContext connectionContextDefault =
DatabricksConnectionContext.parse(JDBC_URL, new Properties());

Map<String, String> sessionConfigsDefault = connectionContextDefault.getSessionConfigs();
assertFalse(
sessionConfigsDefault.containsKey("spark.sql.thriftserver.metadata.metricview.enabled"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ private static Stream<Arguments> getRowsColumnTypeArguments() {
Arguments.of("DECIMAL", "DECIMAL"),
Arguments.of("DECIMAL(6,2)", "DECIMAL"),
Arguments.of("MAP<STRING, ARRAY<STRING>>", "MAP<STRING, ARRAY<STRING>>"),
Arguments.of("ARRAY<DOUBLE>", "ARRAY<DOUBLE>"));
Arguments.of("ARRAY<DOUBLE>", "ARRAY<DOUBLE>"),
Arguments.of("BINGINT measure", "BINGINT measure"));
}

@ParameterizedTest
Expand Down Expand Up @@ -311,6 +312,16 @@ void testGetThriftRowsWithRowIndexOutOfBounds() {
assertNull(updatedRow.get(1));
}

@Test
void testGetThriftRowsMeasureColumn() {
List<ResultColumn> columns = List.of(COLUMN_TYPE_COLUMN);
List<Object> row = List.of("DECIMAL(6,2) measure");
List<List<Object>> updatedRows = metadataResultSetBuilder.getThriftRows(List.of(row), columns);
List<Object> updatedRow = updatedRows.get(0);
// verify that type name for measure column is not stripped
assertEquals("DECIMAL(6,2) measure", updatedRow.get(0));
}

@ParameterizedTest
@MethodSource("provideSpecialColumnsArguments")
void testGetThriftRowsSpecialColumns(List<Object> row, List<Object> expectedRow) {
Expand Down Expand Up @@ -594,4 +605,44 @@ void testDecimalDigitsColumnInGetRows(

assertEquals(expectedScale, rows.get(0).get(8), message);
}

@Test
void testGetTableTypesResultWithMetricViewEnabled() throws SQLException {
// Test when EnableMetricViewMetadata=true
when(connectionContext.getEnableMetricViewMetadata()).thenReturn(true);

DatabricksResultSet resultSet = metadataResultSetBuilder.getTableTypesResult();

// Verify we get 4 table types including METRIC_VIEW
List<String> tableTypes = new ArrayList<>();
while (resultSet.next()) {
tableTypes.add(resultSet.getString("TABLE_TYPE"));
}

assertEquals(4, tableTypes.size());
assertTrue(tableTypes.contains("SYSTEM TABLE"));
assertTrue(tableTypes.contains("TABLE"));
assertTrue(tableTypes.contains("VIEW"));
assertTrue(tableTypes.contains("METRIC_VIEW"));
}

@Test
void testGetTableTypesResultWithMetricViewDisabled() throws SQLException {
// Test when EnableMetricViewMetadata=false
when(connectionContext.getEnableMetricViewMetadata()).thenReturn(false);

DatabricksResultSet resultSet = metadataResultSetBuilder.getTableTypesResult();

// Verify we get 3 table types without METRIC_VIEW
List<String> tableTypes = new ArrayList<>();
while (resultSet.next()) {
tableTypes.add(resultSet.getString("TABLE_TYPE"));
}

assertEquals(3, tableTypes.size());
assertTrue(tableTypes.contains("SYSTEM TABLE"));
assertTrue(tableTypes.contains("TABLE"));
assertTrue(tableTypes.contains("VIEW"));
assertFalse(tableTypes.contains("METRIC_VIEW"));
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
package com.databricks.jdbc.dbclient.impl.sqlexec;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.when;

import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.api.internal.IDatabricksSession;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
public class DatabricksEmptyMetadataClientTest {

@Mock private IDatabricksSession session;
@Mock private IDatabricksConnectionContext ctx;
private final DatabricksEmptyMetadataClient mockClient = new DatabricksEmptyMetadataClient(ctx);
private DatabricksEmptyMetadataClient mockClient;

@BeforeEach
void setUp() {
mockClient = new DatabricksEmptyMetadataClient(ctx);
}

@Test
void testListTypeInfo() throws SQLException {
Expand Down Expand Up @@ -60,6 +70,8 @@ void testListTables() throws SQLException {

@Test
void testListTableTypes() throws SQLException {
when(ctx.getEnableMetricViewMetadata()).thenReturn(false);

ResultSet resultSet = mockClient.listTableTypes(session);
assertNotNull(resultSet);
assertEquals(resultSet.getMetaData().getColumnCount(), 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.databricks.jdbc.api.impl.DatabricksResultSet;
import com.databricks.jdbc.api.impl.DatabricksResultSetMetaData;
import com.databricks.jdbc.api.impl.ImmutableSqlParameter;
import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.api.internal.IDatabricksSession;
import com.databricks.jdbc.common.CommandName;
import com.databricks.jdbc.common.IDatabricksComputeResource;
Expand Down Expand Up @@ -195,6 +196,12 @@ void testListCatalogs() throws SQLException {

@Test
void testListTableTypes() throws SQLException {
// Mock the connection context for the table types test
IDatabricksConnectionContext mockConnectionContext =
org.mockito.Mockito.mock(IDatabricksConnectionContext.class);
when(mockConnectionContext.getEnableMetricViewMetadata()).thenReturn(false);
when(mockClient.getConnectionContext()).thenReturn(mockConnectionContext);

DatabricksMetadataSdkClient metadataClient = new DatabricksMetadataSdkClient(mockClient);
DatabricksResultSet actualResult = metadataClient.listTableTypes(session);
assertEquals(actualResult.getStatementStatus().getState(), StatementState.SUCCEEDED);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ void testGetResultChunksThrowsError() throws SQLException {

@Test
void testListTableTypes() throws SQLException {
// Mock connection context to disable metric view metadata by default
when(connectionContext.getEnableMetricViewMetadata()).thenReturn(false);

DatabricksThriftServiceClient client =
new DatabricksThriftServiceClient(thriftAccessor, connectionContext);
DatabricksResultSet actualResult = client.listTableTypes(session);
Expand Down
Loading