Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class ClinicalDataResourceService implements ClinicalDataResource {
if (patientMap) {
intermediateResults = innerResultFactory.
createIntermediateResults(session,
patients, flattenedVariables)
patientMap.values(), flattenedVariables)
} else {
log.info("No patients passed to retrieveData() with" +
"variables $variables; will skip main queries")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ import org.hibernate.engine.SessionImplementor
import org.transmartproject.core.dataquery.clinical.ClinicalVariable
import org.transmartproject.core.exceptions.InvalidArgumentsException
import org.transmartproject.db.dataquery.clinical.variables.TerminalConceptVariable
import org.transmartproject.db.i2b2data.ConceptDimension
import org.transmartproject.db.i2b2data.ObservationFact
import org.transmartproject.db.i2b2data.PatientDimension

import static org.transmartproject.db.util.GormWorkarounds.createCriteriaBuilder
import static org.transmartproject.db.util.GormWorkarounds.getHibernateInCriterion
import org.transmartproject.db.util.DatabaseMultisetStorage

class TerminalConceptVariablesDataQuery {

Expand All @@ -56,29 +52,33 @@ class TerminalConceptVariablesDataQuery {
throw new IllegalStateException('init() not called successfully yet')
}

def criteriaBuilder = createCriteriaBuilder(ObservationFact, 'obs', session)
criteriaBuilder.with {
projections {
property 'patient.id'
property 'conceptCode'
property 'valueType'
property 'textValue'
property 'numberValue'
}
order 'patient.id'
order 'conceptCode'
}

if (patients instanceof PatientQuery) {
criteriaBuilder.add(getHibernateInCriterion('patient.id',
patients.forIds()))
} else {
criteriaBuilder.in('patient', Lists.newArrayList(patients))
}

criteriaBuilder.in('conceptCode', clinicalVariables*.code)

criteriaBuilder.scroll ScrollMode.FORWARD_ONLY
def patientIds = patients.collect { it.id }

def dms = new DatabaseMultisetStorage(session)
def patientsBagId = dms.saveIntegerData(patientIds)
def intsTable = dms.getIntegerDataTableName()
def cvarsBagId = dms.saveStringData(clinicalVariables*.code)
def strsTable = dms.getStringDataTableName()

def query = session.createSQLQuery """\
select patient_num as patient,
concept_cd as conceptCode,
valtype_cd as valueType,
tval_char as textValue,
nval_num as numberValue
from observation_fact
where patient_num in (select /*+dynamic_sampling(10)*/ id from $intsTable where mid=:pid)
and concept_cd in (select /*+dynamic_sampling(10)*/ id from $strsTable where mid=:cid)
order by patient, conceptCode
""".stripIndent()

query.setInteger('pid', patientsBagId)
query.setInteger('cid', cvarsBagId)
query.cacheable = false
query.readOnly = true
query.fetchSize = 10000

query.scroll ScrollMode.FORWARD_ONLY
}

private void fillInTerminalConceptVariables() {
Expand Down Expand Up @@ -113,25 +113,29 @@ class TerminalConceptVariablesDataQuery {
}

// find the concepts
def res = ConceptDimension.withCriteria {
projections {
property 'conceptPath'
property 'conceptCode'
}

or {
if (conceptPaths.keySet()) {
'in' 'conceptPath', conceptPaths.keySet()
}
if (conceptCodes.keySet()) {
'in' 'conceptCode', conceptCodes.keySet()
}
}
}

for (concept in res) {
String conceptPath = concept[0],
conceptCode = concept[1]
def dms = new DatabaseMultisetStorage(session)
def pathsBagId = dms.saveStringData(conceptPaths.keySet())
def codesBagId = dms.saveStringData(conceptCodes.keySet())
def strsTable = dms.getStringDataTableName()

def stmt = session.connection().prepareStatement("""\
select concept_path,
concept_cd
from concept_dimension
where concept_path in (select /*+dynamic_sampling(10)*/ id from $strsTable where mid=?)
union
select concept_path,
concept_cd
from concept_dimension
where concept_cd in (select /*+dynamic_sampling(10)*/ id from $strsTable where mid=?)
""".stripIndent())
stmt.setInt(1, pathsBagId)
stmt.setInt(2, codesBagId)

def res = stmt.executeQuery()
while (res.next()) {
String conceptPath = res.getString(1),
conceptCode = res.getString(2)

if (conceptPaths[conceptPath]) {
TerminalConceptVariable variable = conceptPaths[conceptPath]
Expand Down
160 changes: 160 additions & 0 deletions src/groovy/org/transmartproject/db/util/DatabaseMultisetStorage.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package org.transmartproject.db.util

import org.codehaus.groovy.grails.commons.ApplicationHolder
import org.hibernate.engine.SessionImplementor

import java.sql.Connection

/**
* Helper class to save multisets of ids in temporary tables in the database
* that can be ised in 'IN' SQL query conditions, e.g.:
* select * from table1 where col1 in (select id from tmp_table where mid=?)
*/
class DatabaseMultisetStorage {

def databasePortabilityService
SessionImplementor session
Integer batchSize
protected Connection connection
protected Boolean isPostgres

enum elementsType {
INTEGER,
STRING
}

DatabaseMultisetStorage(SessionImplementor sess, Integer bsize) {
session = sess
batchSize = bsize
connection = session.connection()
if (!databasePortabilityService) {
databasePortabilityService = ApplicationHolder.getApplication().getMainContext().getBean('databasePortabilityService')
}
isPostgres = databasePortabilityService.databaseType == org.transmartproject.db.support.DatabasePortabilityService.DatabaseType.POSTGRESQL

def isReadOnly = ensureReadWriteTransaction()
[[getIntegerDataTableName(), 'bigint'], [getStringDataTableName(), 'text']].each {
if (isPostgres) {
def sql = 'create temporary table if not exists ' + it[0] + ' (id ' + it[1] + ', mid int not null) on commit preserve rows'
def stmt = connection.prepareStatement(sql)
stmt.execute()
}
// NOTICE: this commits transaction on Oracle
def stmt = connection.prepareStatement('truncate table ' + it[0])
stmt.execute()
}
restoreTransactionState(isReadOnly)
}

DatabaseMultisetStorage(SessionImplementor sess) {
this(sess, 1000)
}

protected Boolean ensureReadWriteTransaction() {
Boolean isReadOnly = connection.isReadOnly()
if (isReadOnly) {
connection.rollback()
connection.setReadOnly(false)
}
return isReadOnly
}

protected void restoreTransactionState(Boolean isReadOnly) {
if (!isReadOnly) {
return
}
connection.commit()
connection.setReadOnly(true)
}

/**
* Table name where integer data is stored
*
* @return table name
*/
static String getIntegerDataTableName() {
return 'session_multisets_of_integers'
}

/**
* Table name where string data is stored
*
* @return table name
*/
static String getStringDataTableName() {
return 'session_multisets_of_strings'
}

/**
* Store a collection of Integers into temporary table
*
* @param data collection of integer values
* @param session hibernate session
*
* @return stored multiset id
*/
Integer saveIntegerData(Iterable<Long> data) {
return saveData(data, elementsType.INTEGER)
}

/**
* Store a collection of Strings into temporary table
*
* @param data collection of string values
* @param session hibernate session
*
* @return stored multiset id
*/
Integer saveStringData(Iterable<String> data) {
return saveData(data, elementsType.STRING)
}

protected Integer saveData(data, dataType) {
if (!data.iterator().hasNext()) {
return 0
}

def multisetId
def isReadOnly = ensureReadWriteTransaction()
try {
def isStringData = dataType == elementsType.STRING
def tableName = isStringData ? getStringDataTableName() : getIntegerDataTableName()

def stmt = connection.prepareStatement('select coalesce(max(mid), 0)+1 from ' + tableName)
def res = stmt.executeQuery()
res.next()
multisetId = res.getInt(1)

def counter = 0
stmt = connection.prepareStatement('insert into ' + tableName + '(id,mid) values(?,?)')
data.each {
if (isStringData) {
stmt.setString(1, it)
} else {
stmt.setLong(1, it)
}
stmt.setInt(2, multisetId)
stmt.addBatch()
counter++
if (counter >= batchSize) {
stmt.executeBatch()
counter = 0
}
}
if (counter > 0) {
stmt.executeBatch()
}

// For Oracle use dynamic_sampling(N) hint in queries
if (isPostgres) {
stmt = connection.prepareStatement('analyze ' + tableName)
stmt.execute()
}
} finally {
restoreTransactionState(isReadOnly)
}

return multisetId
}
}