Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#nullable enable

using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Enums;
using Bit.Core.AdminConsole.Models.Data.Organizations.Policies;
using Bit.Core.AdminConsole.OrganizationFeatures.Policies.PolicyRequirements;
using Bit.Core.AdminConsole.Repositories;
Expand All @@ -20,7 +18,7 @@ public async Task<T> GetAsync<T>(Guid userId) where T : IPolicyRequirement
throw new NotImplementedException("No Requirement Factory found for " + typeof(T));
}

var policyDetails = await GetPolicyDetails(userId);
var policyDetails = await GetPolicyDetails(userId, factory.PolicyType);
var filteredPolicies = policyDetails
.Where(p => p.PolicyType == factory.PolicyType)
.Where(factory.Enforce);
Expand Down Expand Up @@ -48,8 +46,8 @@ public async Task<IEnumerable<Guid>> GetManyByOrganizationIdAsync<T>(Guid organi
return eligibleOrganizationUserIds;
}

private Task<IEnumerable<PolicyDetails>> GetPolicyDetails(Guid userId)
=> policyRepository.GetPolicyDetailsByUserId(userId);
private async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetails(Guid userId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByUserIdsAndPolicyType([userId], policyType);

private async Task<IEnumerable<OrganizationPolicyDetails>> GetOrganizationPolicyDetails(Guid organizationId, PolicyType policyType)
=> await policyRepository.GetPolicyDetailsByOrganizationIdAsync(organizationId, policyType);
Expand Down
11 changes: 0 additions & 11 deletions src/Core/AdminConsole/Repositories/IPolicyRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,6 @@ public interface IPolicyRepository : IRepository<Policy, Guid>
Task<Policy?> GetByOrganizationIdTypeAsync(Guid organizationId, PolicyType type);
Task<ICollection<Policy>> GetManyByOrganizationIdAsync(Guid organizationId);
Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId);
/// <summary>
/// Gets all PolicyDetails for a user for all policy types.
/// </summary>
/// <remarks>
/// Each PolicyDetail represents an OrganizationUser and a Policy which *may* be enforced
/// against them. It only returns PolicyDetails for policies that are enabled and where the organization's plan
/// supports policies. It also excludes "revoked invited" users who are not subject to policy enforcement.
/// This is consumed by <see cref="IPolicyRequirementQuery"/> to create requirements for specific policy types.
/// You probably do not want to call it directly.
/// </remarks>
Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId);

/// <summary>
/// Retrieves <see cref="OrganizationPolicyDetails"/> of the specified <paramref name="policyType"/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,6 @@ public async Task<ICollection<Policy>> GetManyByUserIdAsync(Guid userId)
}
}

public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId)
{
using (var connection = new SqlConnection(ConnectionString))
{
var results = await connection.QueryAsync<PolicyDetails>(
$"[{Schema}].[PolicyDetails_ReadByUserId]",
new { UserId = userId },
commandType: CommandType.StoredProcedure);

return results.ToList();
}
}

public async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetailsByUserIdsAndPolicyType(IEnumerable<Guid> userIds, PolicyType type)
{
await using var connection = new SqlConnection(ConnectionString);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,45 +56,6 @@ public PolicyRepository(IServiceScopeFactory serviceScopeFactory, IMapper mapper
}
}

public async Task<IEnumerable<PolicyDetails>> GetPolicyDetailsByUserId(Guid userId)
{
using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);

var providerOrganizations = from pu in dbContext.ProviderUsers
where pu.UserId == userId
join po in dbContext.ProviderOrganizations
on pu.ProviderId equals po.ProviderId
select po;

var query = from p in dbContext.Policies
join ou in dbContext.OrganizationUsers
on p.OrganizationId equals ou.OrganizationId
join o in dbContext.Organizations
on p.OrganizationId equals o.Id
where
p.Enabled &&
o.Enabled &&
o.UsePolicies &&
(
(ou.Status != OrganizationUserStatusType.Invited && ou.UserId == userId) ||
// Invited orgUsers do not have a UserId associated with them, so we have to match up their email
(ou.Status == OrganizationUserStatusType.Invited && ou.Email == dbContext.Users.Find(userId).Email)
)
select new PolicyDetails
{
OrganizationUserId = ou.Id,
OrganizationId = p.OrganizationId,
PolicyType = p.Type,
PolicyData = p.Data,
OrganizationUserType = ou.Type,
OrganizationUserStatus = ou.Status,
OrganizationUserPermissionsData = ou.Permissions,
IsProvider = providerOrganizations.Any(po => po.OrganizationId == p.OrganizationId)
};
return await query.ToListAsync();
}

public async Task<IEnumerable<OrganizationPolicyDetails>> GetPolicyDetailsByOrganizationIdAsync(Guid organizationId, PolicyType policyType)
{
using var scope = ServiceScopeFactory.CreateScope();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,70 @@
AS
BEGIN
SET NOCOUNT ON
SELECT
OU.[Id] AS OrganizationUserId,
P.[OrganizationId],
P.[Type] AS PolicyType,
P.[Enabled] AS PolicyEnabled,
P.[Data] AS PolicyData,
OU.[Type] AS OrganizationUserType,
OU.[Status] AS OrganizationUserStatus,
OU.[Permissions] AS OrganizationUserPermissionsData,
CASE WHEN EXISTS (
SELECT 1
FROM [dbo].[ProviderUserView] PU
INNER JOIN [dbo].[ProviderOrganizationView] PO ON PO.[ProviderId] = PU.[ProviderId]
WHERE PU.[UserId] = OU.[UserId] AND PO.[OrganizationId] = P.[OrganizationId]
) THEN 1 ELSE 0 END AS IsProvider
FROM [dbo].[PolicyView] P
INNER JOIN [dbo].[OrganizationUserView] OU
ON P.[OrganizationId] = OU.[OrganizationId]
WHERE P.[Type] = @PolicyType AND

DECLARE @UserEmail NVARCHAR(320)
SELECT @UserEmail = Email
FROM
[dbo].[UserView]
WHERE
Id = @UserId

;WITH OrgUsers AS
(
(OU.[Status] != 0 AND OU.[UserId] = @UserId) -- OrgUsers who have accepted their invite and are linked to a UserId
OR EXISTS (
SELECT 1
FROM [dbo].[UserView] U
WHERE U.[Id] = @UserId AND OU.[Email] = U.[Email] AND OU.[Status] = 0 -- 'Invited' OrgUsers are not linked to a UserId yet, so we have to look up their email
)
-- All users except invited (Status <> 0): direct UserId match
SELECT
OU.[Id],
OU.[OrganizationId],
OU.[Type],
OU.[Status],
OU.[Permissions]
FROM
[dbo].[OrganizationUserView] OU
WHERE
OU.[Status] <> 0
AND OU.[UserId] = @UserId

UNION ALL

-- Invited users: email match
SELECT
OU.[Id],
OU.[OrganizationId],
OU.[Type],
OU.[Status],
OU.[Permissions]
FROM
[dbo].[OrganizationUserView] OU
WHERE
OU.[Status] = 0
AND OU.[Email] = @UserEmail
AND @UserEmail IS NOT NULL
),
Providers AS
(
SELECT
OrganizationId
FROM
[dbo].[UserProviderAccessView]
WHERE
UserId = @UserId
)
END
SELECT
OU.[Id] AS [OrganizationUserId],
P.[OrganizationId],
P.[Type] AS [PolicyType],
P.[Enabled] AS [PolicyEnabled],
P.[Data] AS [PolicyData],
OU.[Type] AS [OrganizationUserType],
OU.[Status] AS [OrganizationUserStatus],
OU.[Permissions] AS [OrganizationUserPermissionsData],
CASE WHEN PR.[OrganizationId] IS NULL THEN 0 ELSE 1 END AS [IsProvider]
FROM
[dbo].[PolicyView] P
INNER JOIN
OrgUsers OU ON P.[OrganizationId] = OU.[OrganizationId]
LEFT JOIN
Providers PR ON PR.[OrganizationId] = OU.[OrganizationId]
WHERE
P.[Type] = @PolicyType
END
9 changes: 9 additions & 0 deletions src/Sql/dbo/Views/UserProviderAccessView.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE VIEW [dbo].[UserProviderAccessView]
AS
SELECT DISTINCT
PU.[UserId],
PO.[OrganizationId]
FROM
[dbo].[ProviderUserView] PU
INNER JOIN
[dbo].[ProviderOrganizationView] PO ON PO.[ProviderId] = PU.[ProviderId]
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ public class PolicyRequirementQueryTests
[Theory, BitAutoData]
public async Task GetAsync_IgnoresOtherPolicyTypes(Guid userId)
{
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.RequireSso };
var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.RequireSso, UserId = userId };
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserId(userId).Returns([otherPolicy, thisPolicy]);
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
.Returns([otherPolicy, thisPolicy]);

var factory = new TestPolicyRequirementFactory(_ => true);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
Expand All @@ -33,9 +35,11 @@ public async Task GetAsync_CallsEnforceCallback(Guid userId)
{
// Arrange policies
var policyRepository = Substitute.For<IPolicyRepository>();
var thisPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
var otherPolicy = new PolicyDetails { PolicyType = PolicyType.SingleOrg };
policyRepository.GetPolicyDetailsByUserId(userId).Returns([thisPolicy, otherPolicy]);
var thisPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
var otherPolicy = new OrganizationPolicyDetails { PolicyType = PolicyType.SingleOrg, UserId = userId };
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
.Returns([thisPolicy, otherPolicy]);

// Arrange a substitute Enforce function so that we can inspect the received calls
var callback = Substitute.For<Func<PolicyDetails, bool>>();
Expand Down Expand Up @@ -70,7 +74,9 @@ public async Task GetAsync_ThrowsIfNoFactoryRegistered(Guid userId)
public async Task GetAsync_HandlesNoPolicies(Guid userId)
{
var policyRepository = Substitute.For<IPolicyRepository>();
policyRepository.GetPolicyDetailsByUserId(userId).Returns([]);
policyRepository.GetPolicyDetailsByUserIdsAndPolicyType(
Arg.Is<IEnumerable<Guid>>(ids => ids.Contains(userId)), PolicyType.SingleOrg)
.Returns([]);

var factory = new TestPolicyRequirementFactory(x => x.IsProvider);
var sut = new PolicyRequirementQuery(policyRepository, [factory]);
Expand Down
Loading
Loading