diff --git a/src/controllers/connectionManager.ts b/src/controllers/connectionManager.ts index 393c282519..f4ff97e9c2 100644 --- a/src/controllers/connectionManager.ts +++ b/src/controllers/connectionManager.ts @@ -1072,6 +1072,13 @@ export default class ConnectionManager { this.accountStore, getCloudProviderSettings(account.key.providerId).settings.sqlResource!, ); + + connectionInfo.azureAccountToken = profile.azureAccountToken; + connectionInfo.expiresOn = profile.expiresOn; + connectionInfo.accountId = profile.accountId; + connectionInfo.tenantId = profile.tenantId; + connectionInfo.user = profile.user; + connectionInfo.email = profile.email; } else { throw new Error(LocalizedConstants.cannotConnect); } diff --git a/src/controllers/mainController.ts b/src/controllers/mainController.ts index d309c4cb23..f6c735f777 100644 --- a/src/controllers/mainController.ts +++ b/src/controllers/mainController.ts @@ -731,7 +731,9 @@ export default class MainController implements vscode.Disposable { nodeUri, connectionCreds, ); - if (!isConnected) { + if (isConnected) { + node.updateEntraTokenInfo(connectionCreds); // may be updated Entra token after connect() call + } else { /** * The connection wasn't successful. Stopping scripting operation. * Not throwing an error because the user is already notified of @@ -754,6 +756,8 @@ export default class MainController implements vscode.Disposable { connectionStrategy: ConnectionStrategy.CopyConnectionFromInfo, connectionInfo: connectionCreds, }); + + node.updateEntraTokenInfo(connectionCreds); // newQuery calls connect() internally, so may be updated Entra token if (executeScript) { const preventAutoExecute = vscode.workspace .getConfiguration() @@ -1102,7 +1106,10 @@ export default class MainController implements vscode.Disposable { connectionUri, connectionCreds, ); - if (!connectionResult) { + + if (connectionResult) { + node.updateEntraTokenInfo(connectionCreds); + } else { return; } } diff --git a/src/controllers/sqlDocumentService.ts b/src/controllers/sqlDocumentService.ts index e261134421..6d065500e8 100644 --- a/src/controllers/sqlDocumentService.ts +++ b/src/controllers/sqlDocumentService.ts @@ -9,7 +9,7 @@ import { SqlOutputContentProvider } from "../models/sqlOutputContentProvider"; import StatusView from "../views/statusView"; import store from "../queryResult/singletonStore"; import SqlToolsServerClient from "../languageservice/serviceclient"; -import { getUriKey } from "../utils/utils"; +import { removeUndefinedProperties, getUriKey } from "../utils/utils"; import * as Utils from "../models/utils"; import * as Constants from "../constants/constants"; import * as LocalizedConstants from "../constants/locConstants"; @@ -104,30 +104,31 @@ export default class SqlDocumentService implements vscode.Disposable { return false; } - let connectionCreds: vscodeMssql.IConnectionInfo | undefined; let connectionStrategy: ConnectionStrategy; let nodeType: string | undefined; + let sourceNode: TreeNodeInfo | undefined; if (node) { // Case 1: User right-clicked on an OE node and selected "New Query" - connectionCreds = node.connectionProfile; nodeType = node.nodeType; connectionStrategy = ConnectionStrategy.CopyConnectionFromInfo; + sourceNode = node; } else if (this._lastActiveConnectionInfo) { // Case 2: User triggered "New Query" from command palette and the active document has a connection - connectionCreds = undefined; nodeType = "previousEditor"; connectionStrategy = ConnectionStrategy.CopyLastActive; } else if (this.objectExplorerTree.selection?.length === 1) { // Case 3: User triggered "New Query" from command palette while they have a connected OE node selected - connectionCreds = this.objectExplorerTree.selection[0].connectionProfile; - nodeType = this.objectExplorerTree.selection[0].nodeType; + sourceNode = this.objectExplorerTree.selection[0]; + nodeType = sourceNode.nodeType; connectionStrategy = ConnectionStrategy.CopyConnectionFromInfo; } else { // Case 4: User triggered "New Query" from command palette and there's no reasonable context connectionStrategy = ConnectionStrategy.PromptForConnection; } + const connectionCreds = sourceNode?.connectionProfile; + if (connectionCreds) { await this._connectionMgr.handlePasswordBasedCredentials(connectionCreds); } @@ -138,6 +139,11 @@ export default class SqlDocumentService implements vscode.Disposable { connectionInfo: connectionCreds, }); + if (sourceNode && connectionCreds) { + // newQuery may refresh the Entra token, so update the OE node's connection profile + sourceNode.updateEntraTokenInfo(connectionCreds); + } + const newEditorUri = getUriKey(newEditor.document.uri); const connectionResult = this._connectionMgr.getConnectionInfo(newEditorUri); @@ -388,6 +394,17 @@ export default class SqlDocumentService implements vscode.Disposable { connectionConfig.connectionInfo, ); } + + if (options.connectionInfo && connectionConfig.connectionInfo) { + const tokenUpdates = removeUndefinedProperties({ + azureAccountToken: connectionConfig.connectionInfo.azureAccountToken, + expiresOn: connectionConfig.connectionInfo.expiresOn, + }); + + if (Object.keys(tokenUpdates).length > 0) { + Object.assign(options.connectionInfo, tokenUpdates); + } + } } } diff --git a/src/objectExplorer/nodes/treeNodeInfo.ts b/src/objectExplorer/nodes/treeNodeInfo.ts index dc0f9b1c1e..7d259a6e80 100644 --- a/src/objectExplorer/nodes/treeNodeInfo.ts +++ b/src/objectExplorer/nodes/treeNodeInfo.ts @@ -11,6 +11,7 @@ import * as Constants from "../../constants/constants"; import { ITreeNodeInfo, ObjectMetadata } from "vscode-mssql"; import { IConnectionProfile } from "../../models/interfaces"; import { generateGuid } from "../../models/utils"; +import { removeUndefinedProperties } from "../../utils/utils"; export class TreeNodeInfo extends vscode.TreeItem implements ITreeNodeInfo { private _nodePath: string; @@ -234,10 +235,34 @@ export class TreeNodeInfo extends vscode.TreeItem implements ITreeNodeInfo { public set loadingLabel(value: string) { this._loadingLabel = value; } + public updateConnectionProfile(value: IConnectionProfile): void { this._connectionProfile = value; } + public updateEntraTokenInfo(updatedCredentials: vscodeMssql.IConnectionInfo): void { + if (!updatedCredentials) { + return; + } + + const updatedEntraTokenInfo = removeUndefinedProperties({ + azureAccountToken: updatedCredentials.azureAccountToken, + expiresOn: updatedCredentials.expiresOn, + }); + + if (Object.keys(updatedEntraTokenInfo).length === 0) { + // no refreshed token info to persist + return; + } + + const updatedProfile: IConnectionProfile = { + ...this.connectionProfile, + ...updatedEntraTokenInfo, + }; + + this.updateConnectionProfile(updatedProfile); + } + protected updateMetadata(value: ObjectMetadata): void { this._metadata = value; } diff --git a/src/reactviews/pages/PublishProject/types.ts b/src/reactviews/pages/PublishProject/types.ts index fac24956d8..3c57d6040d 100644 --- a/src/reactviews/pages/PublishProject/types.ts +++ b/src/reactviews/pages/PublishProject/types.ts @@ -1,23 +1,23 @@ -/*--------------------------------------------------------------------------------------------- - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See License.txt in the project root for license information. - *--------------------------------------------------------------------------------------------*/ - -import { FormContextProps } from "../../../sharedInterfaces/form"; -import { - IPublishForm, - PublishDialogFormItemSpec, - PublishDialogState, -} from "../../../sharedInterfaces/publishDialog"; - -/** - * Extended context type used across all publish project components. - * Combines the base form context with publish-specific actions. - */ -export interface PublishFormContext - extends FormContextProps { - publishNow: () => void; - generatePublishScript: () => void; - selectPublishProfile: () => void; - savePublishProfile: (profileName: string) => void; -} +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import { FormContextProps } from "../../../sharedInterfaces/form"; +import { + IPublishForm, + PublishDialogFormItemSpec, + PublishDialogState, +} from "../../../sharedInterfaces/publishDialog"; + +/** + * Extended context type used across all publish project components. + * Combines the base form context with publish-specific actions. + */ +export interface PublishFormContext + extends FormContextProps { + publishNow: () => void; + generatePublishScript: () => void; + selectPublishProfile: () => void; + savePublishProfile: (profileName: string) => void; +} diff --git a/src/schemaDesigner/schemaDesignerWebviewManager.ts b/src/schemaDesigner/schemaDesignerWebviewManager.ts index df66d71277..865435ad1e 100644 --- a/src/schemaDesigner/schemaDesignerWebviewManager.ts +++ b/src/schemaDesigner/schemaDesignerWebviewManager.ts @@ -64,6 +64,7 @@ export class SchemaDesignerWebviewManager { await mainController.connectionManager.createConnectionDetails(connectionInfo); await mainController.connectionManager.confirmEntraTokenValidity(connectionInfo); + treeNode.updateConnectionProfile(connectionInfo); connectionString = await mainController.connectionManager.getConnectionString( connectionDetails, diff --git a/src/tableDesigner/tableDesignerWebviewController.ts b/src/tableDesigner/tableDesignerWebviewController.ts index c955a97e07..dafa4038f3 100644 --- a/src/tableDesigner/tableDesignerWebviewController.ts +++ b/src/tableDesigner/tableDesignerWebviewController.ts @@ -157,6 +157,7 @@ export class TableDesignerWebviewController extends ReactWebviewPanelController< ); await this._connectionManager.confirmEntraTokenValidity(connectionInfo); + this._targetNode.updateConnectionProfile(connectionInfo); const accessToken = connectionInfo.azureAccountToken ? connectionInfo.azureAccountToken : undefined; diff --git a/src/utils/utils.ts b/src/utils/utils.ts index 1b962f1fbb..4f779a5faa 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -143,6 +143,19 @@ export function parseEnum>( return undefined; } +/** + * Removes all properties with undefined values from the given object. Null values are kept. + * @returns a Partial of the original object type with only defined (including null) properties. + */ +export function removeUndefinedProperties(source: T): Partial { + if (!source) { + return {}; + } + + const entries = Object.entries(source).filter(([_key, value]) => value !== undefined); + return Object.fromEntries(entries) as Partial; +} + /** * Checks if any required fields are missing values in a form. * Used to determine if form submission buttons should be disabled. diff --git a/test/unit/AGENTS.md b/test/unit/AGENTS.md index 5fa0da47ce..6dc753d451 100644 --- a/test/unit/AGENTS.md +++ b/test/unit/AGENTS.md @@ -5,7 +5,7 @@ - Do not edit application/source files unless the refactor demands it. Confirm before editing files outside of /test/unit, and justify why you need to make those changes. - Use Sinon, not TypeMoq. If easily possible, replace TypeMoq mocks/stubs/helpers with Sinon equivalents. - Use a Sinon sandbox (setup/teardown with sinon.createSandbox()); keep helper closures (e.g., createServer) inside setup where the - sandbox is created. + sandbox is created. Similarly, let the teardown handle all the stub restores wherever possible; avoid manual restore() calls in tests unless the test design needs the stub behavior changed. - Default to chai.expect; when checking Sinon interactions, use sinon-chai. - Avoid Object.defineProperty hacks and (if possible) fake/partial plain objects; use sandbox.createStubInstance(type) and sandbox.stub(obj, 'prop').value(...). - Add shared Sinon helpers to test/unit/utils.ts when they’ll be reused. diff --git a/test/unit/sqlDocumentService.test.ts b/test/unit/sqlDocumentService.test.ts index a4e24a6755..178928f02f 100644 --- a/test/unit/sqlDocumentService.test.ts +++ b/test/unit/sqlDocumentService.test.ts @@ -3,19 +3,22 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +import * as vscode from "vscode"; import * as sinon from "sinon"; import sinonChai from "sinon-chai"; -import * as vscode from "vscode"; import { expect } from "chai"; import * as chai from "chai"; import * as Constants from "../../src/constants/constants"; import * as LocalizedConstants from "../../src/constants/locConstants"; import MainController from "../../src/controllers/mainController"; -import ConnectionManager from "../../src/controllers/connectionManager"; +import ConnectionManager, { ConnectionInfo } from "../../src/controllers/connectionManager"; import SqlDocumentService, { ConnectionStrategy } from "../../src/controllers/sqlDocumentService"; -import * as Telemetry from "../../src/telemetry/telemetry"; import SqlToolsServerClient from "../../src/languageservice/serviceclient"; -import { IConnectionInfo } from "vscode-mssql"; +import { IConnectionInfo, IServerInfo } from "vscode-mssql"; +import { TreeNodeInfo } from "../../src/objectExplorer/nodes/treeNodeInfo"; +import { IConnectionProfile } from "../../src/models/interfaces"; +import { ConnectionStore } from "../../src/models/connectionStore"; +import { stubTelemetry } from "./utils"; chai.use(sinonChai); @@ -101,6 +104,7 @@ suite("SqlDocumentService Tests", () => { }); test("handleNewQueryCommand should create a new query and update recents", async () => { + stubTelemetry(sandbox); const editor: vscode.TextEditor = { document: { uri: "test_uri" }, } as any; @@ -111,19 +115,17 @@ suite("SqlDocumentService Tests", () => { }; connectionManager.getServerInfo.returns(undefined as any); connectionManager.handlePasswordBasedCredentials.resolves(); - const sendActionStub = sandbox.stub(Telemetry, "sendActionEvent"); - const node: any = { connectionProfile: {}, nodeType: "Server" }; + const node: TreeNodeInfo = sandbox.createStubInstance(TreeNodeInfo); + sandbox.stub(node, "connectionProfile").get(() => ({}) as IConnectionProfile); + sandbox.stub(node, "nodeType").get(() => "Server"); + await sqlDocumentService.handleNewQueryCommand(node, undefined); expect(newQueryStub).to.have.been.calledOnce; expect((connectionManager as any).connectionStore.removeRecentlyUsed).to.have.been .calledOnce; expect(connectionManager.handlePasswordBasedCredentials).to.have.been.calledOnce; - expect(sendActionStub).to.have.been.calledOnce; - - newQueryStub.restore(); - sendActionStub.restore(); }); test("handleNewQueryCommand should not create a new connection if new query fails", async () => { @@ -168,9 +170,14 @@ suite("SqlDocumentService Tests", () => { }); test("handleNewQueryCommand uses OE selection when exactly one node is selected", async () => { - const nodeConnection = { server: "oeServer" } as any; + const nodeConnection = { server: "oeServer" } as IConnectionProfile; + + const selectedNode: TreeNodeInfo = sandbox.createStubInstance(TreeNodeInfo); + sandbox.stub(selectedNode, "connectionProfile").get(() => nodeConnection); + sandbox.stub(selectedNode, "nodeType").get(() => "Server"); + mainController.objectExplorerTree = { - selection: [{ connectionProfile: nodeConnection, nodeType: "Database" }], + selection: [selectedNode], } as any; connectionManager.handlePasswordBasedCredentials.resolves(); connectionManager.connectionStore = { @@ -192,6 +199,59 @@ suite("SqlDocumentService Tests", () => { newQueryStub.restore(); }); + test("handleNewQueryCommand refreshes Entra token info on source node", async () => { + stubTelemetry(sandbox); + + const oldToken = { + azureAccountToken: "oldToken", + expiresOn: Date.now() / 1000 - 60, // 60 seconds in the past; not that the test actually requires this to be expired + }; + + const newToken = { + azureAccountToken: "refreshedToken", + expiresOn: oldToken.expiresOn + 600 + 60, // 10 minutes in the future (plus making up for the past offset) + }; + + const nodeConnection = { + server: "server", + ...oldToken, + } as IConnectionProfile; + + const node = { + connectionProfile: nodeConnection, + nodeType: "Server", + updateEntraTokenInfo: sandbox.stub(), + } as unknown as TreeNodeInfo; + + connectionManager.handlePasswordBasedCredentials.resolves(); + + const connectionStoreStub = sandbox.createStubInstance(ConnectionStore); + + connectionManager.connectionStore = connectionStoreStub; + connectionManager.getServerInfo.returns({} as IServerInfo); + connectionManager.getConnectionInfo.returns({} as ConnectionInfo); + + const editor: vscode.TextEditor = { + document: { uri: vscode.Uri.parse("untitled:tokenTest") }, + } as vscode.TextEditor; + + sandbox.stub(sqlDocumentService, "newQuery").callsFake(async (opts) => { + expect(opts.connectionInfo).to.equal(nodeConnection); + Object.assign(opts.connectionInfo, newToken); + + return editor; + }); + + expect(nodeConnection.azureAccountToken).to.equal(oldToken.azureAccountToken); + expect(nodeConnection.expiresOn).to.equal(oldToken.expiresOn); + + await sqlDocumentService.handleNewQueryCommand(node, undefined); + + expect(node.updateEntraTokenInfo).to.have.been.calledOnceWith(nodeConnection); + expect(nodeConnection.azureAccountToken).to.equal(newToken.azureAccountToken); + expect(nodeConnection.expiresOn).to.equal(newToken.expiresOn); + }); + test("handleNewQueryCommand prompts for connection when no context", async () => { // clear last active and OE selection sqlDocumentService["_lastActiveConnectionInfo"] = undefined; diff --git a/test/unit/treeNodeInfo.test.ts b/test/unit/treeNodeInfo.test.ts index fcfc193d1e..5f70be6db3 100644 --- a/test/unit/treeNodeInfo.test.ts +++ b/test/unit/treeNodeInfo.test.ts @@ -3,14 +3,33 @@ * Licensed under the MIT License. See License.txt in the project root for license information. *--------------------------------------------------------------------------------------------*/ +import * as vscode from "vscode"; +import * as sinon from "sinon"; +import sinonChai from "sinon-chai"; import { expect } from "chai"; +import * as chai from "chai"; + +import type { IConnectionInfo } from "vscode-mssql"; import { TreeNodeInfo } from "../../src/objectExplorer/nodes/treeNodeInfo"; import { initializeIconUtils } from "./utils"; +import { IConnectionProfile } from "../../src/models/interfaces"; +import { azureMfa } from "../../src/constants/constants"; +import { removeUndefinedProperties } from "../../src/utils/utils"; + +chai.use(sinonChai); suite("TreeNodeInfo", () => { + let sandbox: sinon.SinonSandbox; + setup(() => { + sandbox = sinon.createSandbox(); initializeIconUtils(); }); + + teardown(() => { + sandbox.restore(); + }); + test("When creating multiple TreeNodeInfo in quick succession, the nodePath should be unique", () => { const node1 = new TreeNodeInfo( "node_label", @@ -39,6 +58,111 @@ suite("TreeNodeInfo", () => { undefined, undefined, ); + expect(node1.id).to.not.equal(node2.id, "Node IDs should be unique"); }); + + test("loadingLabel can be set and retrieved", () => { + const newLabel = "TestValue"; + const node = createTreeNode(); + expect(node.loadingLabel).to.not.equal(newLabel); + + node.loadingLabel = newLabel; + expect(node.loadingLabel).to.equal(newLabel); + }); + + suite("updateEntraTokenInfo", () => { + test("updates only Entra token fields when refreshed token is provided", () => { + const oldToken = { + azureAccountToken: "oldToken", + expiresOn: Date.now() / 1000 - 60, // 60 seconds in the past; not that the test actually requires this to be expired + }; + + const newToken = { + azureAccountToken: "refreshedToken", + expiresOn: oldToken.expiresOn + 600 + 60, // 10 minutes in the future (plus making up for the past offset) + }; + + const node = createTreeNode({ + server: "testServer", + authenticationType: azureMfa, + user: "test@contoso.com", + ...oldToken, + }); + const updateSpy = sandbox.spy(node, "updateConnectionProfile"); + + expect(node.connectionProfile.azureAccountToken).to.equal(oldToken.azureAccountToken); + expect(node.connectionProfile.expiresOn).to.equal(oldToken.expiresOn); + + node.updateEntraTokenInfo({ + ...newToken, + } as IConnectionProfile); + + expect(updateSpy).to.have.been.calledOnce; + expect(node.connectionProfile.azureAccountToken).to.equal(newToken.azureAccountToken); + expect(node.connectionProfile.expiresOn).to.equal(newToken.expiresOn); + expect(node.connectionProfile.server, "Server should not be changed").to.equal( + "testServer", + ); + expect( + node.connectionProfile.authenticationType, + "authenticationType should not be changed", + ).to.equal(azureMfa); + expect(node.connectionProfile.user, "user should not be changed").to.equal( + "test@contoso.com", + ); + }); + + test("ignores Entra token update when both fields are undefined", () => { + const node = createTreeNode(); + const updateSpy = sandbox.spy(node, "updateConnectionProfile"); + + node.updateEntraTokenInfo({} as IConnectionProfile); + + expect(updateSpy).to.not.have.been.called; + const profile = node.connectionProfile; + expect(profile.azureAccountToken).to.equal("oldToken"); + expect(profile.expiresOn).to.equal(111); + }); + + test("no op when no credentials are passed", () => { + const node = createTreeNode(); + const removedUndefinedSpy = sandbox.spy(removeUndefinedProperties); + node.updateEntraTokenInfo(undefined); + + expect(removedUndefinedSpy).to.not.have.been.called; + }); + }); }); +function createTreeNode(overrides: Partial = {}): TreeNodeInfo { + const baseProfile: IConnectionProfile = { + id: "id", + profileName: "profile", + groupId: "group", + savePassword: false, + emptyPasswordInput: false, + azureAuthType: 0, + accountStore: undefined, + server: "server", + database: "db", + azureAccountToken: "oldToken", + expiresOn: 111, + ...overrides, + } as IConnectionProfile; + + return new TreeNodeInfo( + "label", + { type: "Server", filterable: false, hasFilters: false, subType: undefined }, + vscode.TreeItemCollapsibleState.None, + "nodePath", + "ready", + "Server", + "session", + baseProfile, + undefined as unknown as TreeNodeInfo, + [], + undefined, + undefined, + undefined, + ); +} diff --git a/test/unit/utils.test.ts b/test/unit/utils.test.ts index ed256831f9..da7c1944ff 100644 --- a/test/unit/utils.test.ts +++ b/test/unit/utils.test.ts @@ -174,6 +174,40 @@ suite("Utility tests - parseEnum", () => { }); }); +type Sample = { + token?: string; + expiresOn?: number; + notes?: string | null; +}; + +suite("removeUndefinedProperties", () => { + test("removes only undefined properties", () => { + /* eslint-disable no-restricted-syntax */ + const input: Sample = { + token: "abc", + expiresOn: undefined, + notes: null, + }; + + const result = utilUtils.removeUndefinedProperties(input); + + expect( + result, + "removeUndefinedValues should remove undefined properties, but leave null", + ).to.deep.equal({ + token: "abc", + notes: null, + }); + /* eslint-enable no-restricted-syntax */ + }); + + test("returns empty object when source is undefined", () => { + const result = utilUtils.removeUndefinedProperties(undefined); + + assert.deepStrictEqual(result, {}); + }); +}); + suite("ConnectionMatcher", () => { test("Should match connections correctly", () => { const testCases: {