Train a machine learning model on a collection
Here, we iterate over the artifacts within a collection to train a machine learning model at scale.
import lamindb as ln
💡 connected lamindb: testuser1/test-scrna
ln.settings.transform.stem_uid = "Qr1kIHvK506r"
ln.settings.transform.version = "1"
ln.track()
💡 notebook imports: lamindb==0.72.0 torch==2.3.0
💡 saved: Transform(version='1', uid='Qr1kIHvK506r5zKv', name='Train a machine learning model on a collection', key='scrna5', type='notebook', updated_at=2024-05-20 08:34:35 UTC, created_by_id=1)
💡 saved: Run(uid='f8CRJlKYc3z2dfPvj5y3', transform_id=5, created_by_id=1)
Query our collection:
collection = ln.Collection.filter(
    name="My versioned scRNA-seq collection", version="2"
).one()
collection.describe()
Collection(version='2', updated_at=2024-05-20 08:34:14 UTC, uid='h97PCjzwqbbQTdGPTxh4', name='My versioned scRNA-seq collection', hash='HNR3VFV60_yqRnUka11E', visibility=1)
Provenance:
📎 created_by: User(uid='DzTjkKse', handle='testuser1', name='Test User1')
📎 transform: Transform(version='1', uid='ManDYgmftZ8C5zKv', name='Standardize and append a batch of data', key='scrna2', type='notebook')
📎 run: Run(uid='ka9LM9UnxfRAbPbJ2vRI', started_at=2024-05-20 08:33:50 UTC, is_consecutive=True)
📎 input_of (core.Run): ['2024-05-20 08:34:23 UTC']
Features:
obs: FeatureSet(uid='nMzktuMdWrMztdGcNVQ9', n=4, registry='Feature')
🔗 donor (4, cat[ULabel]):
🔗 tissue (cat[bionty.Tissue])
🔗 cell_type (cat[bionty.CellType])
🔗 assay (cat[bionty.ExperimentalFactor])
var: FeatureSet(uid='m2n31KWfgmw4UL1IGSdm', n=36508, dtype='float', registry='bionty.Gene')
'MIR1302-2HG', 'FAM138A', 'OR4F5', 'None', 'OR4F29', 'OR4F16', 'LINC01409', 'FAM87B', 'LINC01128', 'LINC00115', 'FAM41C', 'LINC02593', 'SAMD11', 'NOC2L', 'KLHL17', 'PLEKHN1', 'PERM1', 'HES4'
Create a map-style dataset
Let us create a map-style dataset using mapped(): a MappedCollection. This is what, for example, the PyTorch DataLoader expects as input.
Under the hood, it performs a virtual join (inner or outer) of the features of the underlying AnnData objects and thus allows working with very large collections.
You can either perform a virtual inner join:
with collection.mapped(obs_keys=["cell_type"], join="inner") as dataset:
    print(len(dataset.var_joint))
749
Or a virtual outer join:
dataset = collection.mapped(obs_keys=["cell_type"], join="outer")
len(dataset.var_joint)
36508
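Which genes survive depends on the join mode: inner keeps only the genes present in every underlying AnnData object, outer keeps their union. A quick sanity check against the two numbers above (a sketch; dataset here is still the outer-join one):
with collection.mapped(join="inner") as inner_dataset:
    assert len(inner_dataset.var_joint) < len(dataset.var_joint)  # 749 < 36508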
This is compatible with a PyTorch DataLoader because it implements __getitem__ over a list of backed AnnData objects.
The 5th cell in the collection can be accessed like this:
dataset[5]
{'X': array([ 0.   ,  0.   ,  0.   , ...,  0.   ,  0.   , -0.456], dtype=float32),
 '_store_idx': 0,
 'cell_type': 27}
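Each item is a dict: 'X' is the cell's expression vector over the joined var space, '_store_idx' is the index of the backing AnnData object the cell comes from, and 'cell_type' is the integer-encoded label requested via obs_keys. As a quick check (the expected shape assumes the outer-join dataset from above):
dataset[5]["X"].shape  # (36508,): one value per gene in var_joint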
The labels are encoded into integers:
dataset.encoders
{'cell_type': {'B cell, CD19-positive': 0,
'animal cell': 1,
'CD8-positive, alpha-beta memory T cell, CD45RO-positive': 2,
'conventional dendritic cell': 3,
'mast cell': 4,
'CD4-positive, alpha-beta T cell': 5,
'megakaryocyte': 6,
'naive thymus-derived CD4-positive, alpha-beta T cell': 7,
'macrophage': 8,
'CD16-negative, CD56-bright natural killer cell, human': 9,
'lymphocyte': 10,
'CD4-positive helper T cell': 11,
'naive thymus-derived CD8-positive, alpha-beta T cell': 12,
'CD38-positive naive B cell': 13,
'germinal center B cell': 14,
'T follicular helper cell': 15,
'memory B cell': 16,
'CD8-positive, CD25-positive, alpha-beta regulatory T cell': 17,
'alveolar macrophage': 18,
'gamma-delta T cell': 19,
'plasmacytoid dendritic cell': 20,
'effector memory CD8-positive, alpha-beta T cell, terminally differentiated': 21,
'alpha-beta T cell': 22,
'group 3 innate lymphoid cell': 23,
'mucosal invariant T cell': 24,
'regulatory T cell': 25,
'CD14-positive, CD16-negative classical monocyte': 26,
'cytotoxic T cell': 27,
'progenitor cell': 28,
'naive B cell': 29,
'plasma cell': 30,
'non-classical monocyte': 31,
'classical monocyte': 32,
'CD8-positive, alpha-beta memory T cell': 33,
'dendritic cell, human': 34,
'effector memory CD4-positive, alpha-beta T cell': 35,
'dendritic cell': 36,
'plasmablast': 37,
'CD16-positive, CD56-dim natural killer cell, human': 38,
'effector memory CD4-positive, alpha-beta T cell, terminally differentiated': 39}}
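To map encoded labels (or model predictions) back to cell type names, invert the encoder; a small sketch using the dict shown above:
decoder = {code: label for label, code in dataset.encoders["cell_type"].items()}
decoder[dataset[5]["cell_type"]]  # 'cytotoxic T cell'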
Create a PyTorch DataLoader
Let us use a weighted sampler:
from torch.utils.data import DataLoader, WeightedRandomSampler
# the label_key used for weights doesn't have to be among the obs_keys passed on init
sampler = WeightedRandomSampler(
    weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
We can now iterate through the data loader:
for batch in dataloader:
    pass
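In an actual training step, each batch provides inputs and labels directly. A minimal sketch, assuming a plain linear classifier (the model choice is illustrative, not part of the original example):
import torch

model = torch.nn.Linear(len(dataset.var_joint), len(dataset.encoders["cell_type"]))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

for batch in dataloader:
    x = batch["X"]          # (batch_size, n_vars) float32 tensor
    y = batch["cell_type"]  # integer-encoded labels, as shown above
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()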
Close the connections in MappedCollection:
dataset.close()
In practice, use a context manager, which closes the connections automatically:
with collection.mapped(obs_keys=["cell_type"]) as dataset:
    sampler = WeightedRandomSampler(
        weights=dataset.get_label_weights("cell_type"), num_samples=len(dataset)
    )
    dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)
    for batch in dataloader:
        pass