Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/controllers/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,13 @@ export default class ConnectionManager {
this.accountStore,
getCloudProviderSettings(account.key.providerId).settings.sqlResource!,
);

connectionInfo.azureAccountToken = profile.azureAccountToken;
connectionInfo.expiresOn = profile.expiresOn;
connectionInfo.accountId = profile.accountId;
connectionInfo.tenantId = profile.tenantId;
connectionInfo.user = profile.user;
connectionInfo.email = profile.email;
} else {
throw new Error(LocalizedConstants.cannotConnect);
}
Expand Down
11 changes: 9 additions & 2 deletions src/controllers/mainController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,9 @@ export default class MainController implements vscode.Disposable {
nodeUri,
connectionCreds,
);
if (!isConnected) {
if (isConnected) {
node.updateEntraTokenInfo(connectionCreds); // may be updated Entra token after connect() call
} else {
/**
* The connection wasn't successful. Stopping scripting operation.
* Not throwing an error because the user is already notified of
Expand All @@ -754,6 +756,8 @@ export default class MainController implements vscode.Disposable {
connectionStrategy: ConnectionStrategy.CopyConnectionFromInfo,
connectionInfo: connectionCreds,
});

node.updateEntraTokenInfo(connectionCreds); // newQuery calls connect() internally, so may be updated Entra token
if (executeScript) {
const preventAutoExecute = vscode.workspace
.getConfiguration()
Expand Down Expand Up @@ -1102,7 +1106,10 @@ export default class MainController implements vscode.Disposable {
connectionUri,
connectionCreds,
);
if (!connectionResult) {

if (connectionResult) {
node.updateEntraTokenInfo(connectionCreds);
} else {
return;
}
}
Expand Down
29 changes: 23 additions & 6 deletions src/controllers/sqlDocumentService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { SqlOutputContentProvider } from "../models/sqlOutputContentProvider";
import StatusView from "../views/statusView";
import store from "../queryResult/singletonStore";
import SqlToolsServerClient from "../languageservice/serviceclient";
import { getUriKey } from "../utils/utils";
import { removeUndefinedProperties, getUriKey } from "../utils/utils";
import * as Utils from "../models/utils";
import * as Constants from "../constants/constants";
import * as LocalizedConstants from "../constants/locConstants";
Expand Down Expand Up @@ -104,30 +104,31 @@ export default class SqlDocumentService implements vscode.Disposable {
return false;
}

let connectionCreds: vscodeMssql.IConnectionInfo | undefined;
let connectionStrategy: ConnectionStrategy;
let nodeType: string | undefined;
let sourceNode: TreeNodeInfo | undefined;

if (node) {
// Case 1: User right-clicked on an OE node and selected "New Query"
connectionCreds = node.connectionProfile;
nodeType = node.nodeType;
connectionStrategy = ConnectionStrategy.CopyConnectionFromInfo;
sourceNode = node;
} else if (this._lastActiveConnectionInfo) {
// Case 2: User triggered "New Query" from command palette and the active document has a connection
connectionCreds = undefined;
nodeType = "previousEditor";
connectionStrategy = ConnectionStrategy.CopyLastActive;
} else if (this.objectExplorerTree.selection?.length === 1) {
// Case 3: User triggered "New Query" from command palette while they have a connected OE node selected
connectionCreds = this.objectExplorerTree.selection[0].connectionProfile;
nodeType = this.objectExplorerTree.selection[0].nodeType;
sourceNode = this.objectExplorerTree.selection[0];
nodeType = sourceNode.nodeType;
connectionStrategy = ConnectionStrategy.CopyConnectionFromInfo;
} else {
// Case 4: User triggered "New Query" from command palette and there's no reasonable context
connectionStrategy = ConnectionStrategy.PromptForConnection;
}

const connectionCreds = sourceNode?.connectionProfile;

if (connectionCreds) {
await this._connectionMgr.handlePasswordBasedCredentials(connectionCreds);
}
Expand All @@ -138,6 +139,11 @@ export default class SqlDocumentService implements vscode.Disposable {
connectionInfo: connectionCreds,
});

if (sourceNode && connectionCreds) {
// newQuery may refresh the Entra token, so update the OE node's connection profile
sourceNode.updateEntraTokenInfo(connectionCreds);
}

const newEditorUri = getUriKey(newEditor.document.uri);

const connectionResult = this._connectionMgr.getConnectionInfo(newEditorUri);
Expand Down Expand Up @@ -388,6 +394,17 @@ export default class SqlDocumentService implements vscode.Disposable {
connectionConfig.connectionInfo,
);
}

if (options.connectionInfo && connectionConfig.connectionInfo) {
const tokenUpdates = removeUndefinedProperties({
azureAccountToken: connectionConfig.connectionInfo.azureAccountToken,
expiresOn: connectionConfig.connectionInfo.expiresOn,
});

if (Object.keys(tokenUpdates).length > 0) {
Object.assign(options.connectionInfo, tokenUpdates);
}
}
}
}

Expand Down
25 changes: 25 additions & 0 deletions src/objectExplorer/nodes/treeNodeInfo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import * as Constants from "../../constants/constants";
import { ITreeNodeInfo, ObjectMetadata } from "vscode-mssql";
import { IConnectionProfile } from "../../models/interfaces";
import { generateGuid } from "../../models/utils";
import { removeUndefinedProperties } from "../../utils/utils";

export class TreeNodeInfo extends vscode.TreeItem implements ITreeNodeInfo {
private _nodePath: string;
Expand Down Expand Up @@ -234,10 +235,34 @@ export class TreeNodeInfo extends vscode.TreeItem implements ITreeNodeInfo {
public set loadingLabel(value: string) {
this._loadingLabel = value;
}

public updateConnectionProfile(value: IConnectionProfile): void {
this._connectionProfile = value;
}

public updateEntraTokenInfo(updatedCredentials: vscodeMssql.IConnectionInfo): void {
if (!updatedCredentials) {
return;
}

const updatedEntraTokenInfo = removeUndefinedProperties({
azureAccountToken: updatedCredentials.azureAccountToken,
expiresOn: updatedCredentials.expiresOn,
});

if (Object.keys(updatedEntraTokenInfo).length === 0) {
// no refreshed token info to persist
return;
}

const updatedProfile: IConnectionProfile = {
...this.connectionProfile,
...updatedEntraTokenInfo,
};

this.updateConnectionProfile(updatedProfile);
}

protected updateMetadata(value: ObjectMetadata): void {
this._metadata = value;
}
Expand Down
46 changes: 23 additions & 23 deletions src/reactviews/pages/PublishProject/types.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { FormContextProps } from "../../../sharedInterfaces/form";
import {
IPublishForm,
PublishDialogFormItemSpec,
PublishDialogState,
} from "../../../sharedInterfaces/publishDialog";
/**
* Extended context type used across all publish project components.
* Combines the base form context with publish-specific actions.
*/
export interface PublishFormContext
extends FormContextProps<IPublishForm, PublishDialogState, PublishDialogFormItemSpec> {
publishNow: () => void;
generatePublishScript: () => void;
selectPublishProfile: () => void;
savePublishProfile: (profileName: string) => void;
}
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import { FormContextProps } from "../../../sharedInterfaces/form";
import {
IPublishForm,
PublishDialogFormItemSpec,
PublishDialogState,
} from "../../../sharedInterfaces/publishDialog";

/**
* Extended context type used across all publish project components.
* Combines the base form context with publish-specific actions.
*/
export interface PublishFormContext
extends FormContextProps<IPublishForm, PublishDialogState, PublishDialogFormItemSpec> {
publishNow: () => void;
generatePublishScript: () => void;
selectPublishProfile: () => void;
savePublishProfile: (profileName: string) => void;
}
1 change: 1 addition & 0 deletions src/schemaDesigner/schemaDesignerWebviewManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ export class SchemaDesignerWebviewManager {
await mainController.connectionManager.createConnectionDetails(connectionInfo);

await mainController.connectionManager.confirmEntraTokenValidity(connectionInfo);
treeNode.updateConnectionProfile(connectionInfo);

connectionString = await mainController.connectionManager.getConnectionString(
connectionDetails,
Expand Down
1 change: 1 addition & 0 deletions src/tableDesigner/tableDesignerWebviewController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ export class TableDesignerWebviewController extends ReactWebviewPanelController<
);

await this._connectionManager.confirmEntraTokenValidity(connectionInfo);
this._targetNode.updateConnectionProfile(connectionInfo);
const accessToken = connectionInfo.azureAccountToken
? connectionInfo.azureAccountToken
: undefined;
Expand Down
13 changes: 13 additions & 0 deletions src/utils/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,19 @@ export function parseEnum<T extends Record<string, string | number>>(
return undefined;
}

/**
* Removes all properties with undefined values from the given object. Null values are kept.
* @returns a Partial of the original object type with only defined (including null) properties.
*/
export function removeUndefinedProperties<T extends object>(source: T): Partial<T> {
if (!source) {
return {};
}

const entries = Object.entries(source).filter(([_key, value]) => value !== undefined);
return Object.fromEntries(entries) as Partial<T>;
}

/**
* Checks if any required fields are missing values in a form.
* Used to determine if form submission buttons should be disabled.
Expand Down
2 changes: 1 addition & 1 deletion test/unit/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
- Do not edit application/source files unless the refactor demands it. Confirm before editing files outside of /test/unit, and justify why you need to make those changes.
- Use Sinon, not TypeMoq. If easily possible, replace TypeMoq mocks/stubs/helpers with Sinon equivalents.
- Use a Sinon sandbox (setup/teardown with sinon.createSandbox()); keep helper closures (e.g., createServer) inside setup where the
sandbox is created.
sandbox is created. Similarly, let the teardown handle all the stub restores wherever possible; avoid manual restore() calls in tests unless the test design needs the stub behavior changed.
- Default to chai.expect; when checking Sinon interactions, use sinon-chai.
- Avoid Object.defineProperty hacks and (if possible) fake/partial plain objects; use sandbox.createStubInstance(type) and sandbox.stub(obj, 'prop').value(...).
- Add shared Sinon helpers to test/unit/utils.ts when they’ll be reused.
Expand Down
84 changes: 72 additions & 12 deletions test/unit/sqlDocumentService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import * as vscode from "vscode";
import * as sinon from "sinon";
import sinonChai from "sinon-chai";
import * as vscode from "vscode";
import { expect } from "chai";
import * as chai from "chai";
import * as Constants from "../../src/constants/constants";
import * as LocalizedConstants from "../../src/constants/locConstants";
import MainController from "../../src/controllers/mainController";
import ConnectionManager from "../../src/controllers/connectionManager";
import ConnectionManager, { ConnectionInfo } from "../../src/controllers/connectionManager";
import SqlDocumentService, { ConnectionStrategy } from "../../src/controllers/sqlDocumentService";
import * as Telemetry from "../../src/telemetry/telemetry";
import SqlToolsServerClient from "../../src/languageservice/serviceclient";
import { IConnectionInfo } from "vscode-mssql";
import { IConnectionInfo, IServerInfo } from "vscode-mssql";
import { TreeNodeInfo } from "../../src/objectExplorer/nodes/treeNodeInfo";
import { IConnectionProfile } from "../../src/models/interfaces";
import { ConnectionStore } from "../../src/models/connectionStore";
import { stubTelemetry } from "./utils";

chai.use(sinonChai);

Expand Down Expand Up @@ -101,6 +104,7 @@ suite("SqlDocumentService Tests", () => {
});

test("handleNewQueryCommand should create a new query and update recents", async () => {
stubTelemetry(sandbox);
const editor: vscode.TextEditor = {
document: { uri: "test_uri" },
} as any;
Expand All @@ -111,19 +115,17 @@ suite("SqlDocumentService Tests", () => {
};
connectionManager.getServerInfo.returns(undefined as any);
connectionManager.handlePasswordBasedCredentials.resolves();
const sendActionStub = sandbox.stub(Telemetry, "sendActionEvent");

const node: any = { connectionProfile: {}, nodeType: "Server" };
const node: TreeNodeInfo = sandbox.createStubInstance(TreeNodeInfo);
sandbox.stub(node, "connectionProfile").get(() => ({}) as IConnectionProfile);
sandbox.stub(node, "nodeType").get(() => "Server");

await sqlDocumentService.handleNewQueryCommand(node, undefined);

expect(newQueryStub).to.have.been.calledOnce;
expect((connectionManager as any).connectionStore.removeRecentlyUsed).to.have.been
.calledOnce;
expect(connectionManager.handlePasswordBasedCredentials).to.have.been.calledOnce;
expect(sendActionStub).to.have.been.calledOnce;

newQueryStub.restore();
sendActionStub.restore();
});

test("handleNewQueryCommand should not create a new connection if new query fails", async () => {
Expand Down Expand Up @@ -168,9 +170,14 @@ suite("SqlDocumentService Tests", () => {
});

test("handleNewQueryCommand uses OE selection when exactly one node is selected", async () => {
const nodeConnection = { server: "oeServer" } as any;
const nodeConnection = { server: "oeServer" } as IConnectionProfile;

const selectedNode: TreeNodeInfo = sandbox.createStubInstance(TreeNodeInfo);
sandbox.stub(selectedNode, "connectionProfile").get(() => nodeConnection);
sandbox.stub(selectedNode, "nodeType").get(() => "Server");

mainController.objectExplorerTree = {
selection: [{ connectionProfile: nodeConnection, nodeType: "Database" }],
selection: [selectedNode],
} as any;
connectionManager.handlePasswordBasedCredentials.resolves();
connectionManager.connectionStore = {
Expand All @@ -192,6 +199,59 @@ suite("SqlDocumentService Tests", () => {
newQueryStub.restore();
});

test("handleNewQueryCommand refreshes Entra token info on source node", async () => {
stubTelemetry(sandbox);

const oldToken = {
azureAccountToken: "oldToken",
expiresOn: Date.now() / 1000 - 60, // 60 seconds in the past; not that the test actually requires this to be expired
};

const newToken = {
azureAccountToken: "refreshedToken",
expiresOn: oldToken.expiresOn + 600 + 60, // 10 minutes in the future (plus making up for the past offset)
};

const nodeConnection = {
server: "server",
...oldToken,
} as IConnectionProfile;

const node = {
connectionProfile: nodeConnection,
nodeType: "Server",
updateEntraTokenInfo: sandbox.stub(),
} as unknown as TreeNodeInfo;

connectionManager.handlePasswordBasedCredentials.resolves();

const connectionStoreStub = sandbox.createStubInstance(ConnectionStore);

connectionManager.connectionStore = connectionStoreStub;
connectionManager.getServerInfo.returns({} as IServerInfo);
connectionManager.getConnectionInfo.returns({} as ConnectionInfo);

const editor: vscode.TextEditor = {
document: { uri: vscode.Uri.parse("untitled:tokenTest") },
} as vscode.TextEditor;

sandbox.stub(sqlDocumentService, "newQuery").callsFake(async (opts) => {
expect(opts.connectionInfo).to.equal(nodeConnection);
Object.assign(opts.connectionInfo, newToken);

return editor;
});

expect(nodeConnection.azureAccountToken).to.equal(oldToken.azureAccountToken);
expect(nodeConnection.expiresOn).to.equal(oldToken.expiresOn);

await sqlDocumentService.handleNewQueryCommand(node, undefined);

expect(node.updateEntraTokenInfo).to.have.been.calledOnceWith(nodeConnection);
expect(nodeConnection.azureAccountToken).to.equal(newToken.azureAccountToken);
expect(nodeConnection.expiresOn).to.equal(newToken.expiresOn);
});

test("handleNewQueryCommand prompts for connection when no context", async () => {
// clear last active and OE selection
sqlDocumentService["_lastActiveConnectionInfo"] = undefined;
Expand Down
Loading