Mitigations
Diffusion Augmentation
class relai.vision.classification.mitigate.diffusion_augmentation.DiffusionAugmentation(num_samples_per_class: int, prompt_template: Template, options: dict[str, dict[str, list[str]]], output_dir: str | Path, fabric: Fabric = None)
Bases: RELAIAlgorithm
Augment a dataset with synthetic images generated using a text-to-image diffusion model. This is useful when the dataset is biased, that is, when some (class, concept) combinations are underrepresented. Generating synthetic samples for these underrepresented combinations and training a model on the augmented dataset can potentially mitigate the effects of bias in the original dataset.
- Parameters:
- num_samples_per_class (int) – The number of synthetic samples to generate per class.
- prompt_template (string.Template) – The prompt template, containing a placeholder called “class” ($class) where the class name will be inserted.
- options (dict[str, dict[str, list[str]]]) – A dictionary of options for the prompt. Each key is a class name and maps to an inner dictionary. The keys of the inner dictionary are the names of the placeholders in the prompt template, and the values are lists of possible values for each placeholder.
- output_dir (Union[str, Path]) – The directory in which to save the synthetic samples. Outputs are saved in subdirectories named after the classes, i.e. the top-level keys of “options”.
- fabric (L.Fabric, optional) – The fabric to use for the algorithm. If not provided, a new fabric will be created.
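To make the relationship between prompt_template and options concrete, the following is a minimal sketch of how prompts could be expanded using only the standard library. The helper name expand_prompts is hypothetical and not part of relai. Taking the Cartesian product when a class has several placeholders is an assumption; the library's actual behaviour for multiple placeholders is not documented here.

```python
from itertools import product
from string import Template


def expand_prompts(prompt_template, options):
    """Hypothetical helper: map each class name to the prompts it would produce.

    Assumes every combination of placeholder values is used (Cartesian product).
    """
    prompts = {}
    for class_name, placeholders in options.items():
        names = list(placeholders)
        value_lists = [placeholders[name] for name in names]
        prompts[class_name] = [
            # "class" is filled with the class name; the other placeholders
            # are filled with one value each from the options dictionary.
            prompt_template.substitute({"class": class_name, **dict(zip(names, values))})
            for values in product(*value_lists)
        ]
    return prompts


prompts = expand_prompts(
    Template("a $class in $location"),
    {
        "dog": {"location": ["the park", "the garden"]},
        "cat": {"location": ["the park"]},
    },
)
for class_name, class_prompts in prompts.items():
    print(class_name, class_prompts)
```

Note that Template.substitute raises KeyError if the template contains a placeholder that a class's inner dictionary does not supply, so each class should provide values for every placeholder other than $class.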
Example:
python
from relai.vision.classification.mitigate import DiffusionAugmentation
from string import Template
num_samples_per_class = 100
prompt_template = Template("a $class in $location")
# Generate images with the prompts: "a dog in the park", "a dog in the garden", and "a cat in the park"
options = {
    "dog": {
        "location": ["the park", "the garden"]
    },
    "cat": {
        "location": ["the park"]
    }
}
output_dir = "./synthetic_samples"
diffusion_augmentation = DiffusionAugmentation(num_samples_per_class, prompt_template, options, output_dir)
diffusion_augmentation.run()  # saves images in "./synthetic_samples/dog" and "./synthetic_samples/cat"
run()
Generate synthetic samples.
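Once the samples have been generated, they typically need to be paired with their labels before being merged into the original training set. The sketch below walks an output directory laid out as described above (one subdirectory per class) and returns (file path, class name) pairs. The helper collect_synthetic_samples is hypothetical and not part of relai. Because the file names and image format that run() writes are not specified here, the sketch makes no assumption about file extensions and simply lists every regular file.

```python
from pathlib import Path


def collect_synthetic_samples(output_dir):
    """Hypothetical helper: return (file_path, class_name) pairs.

    Assumes output_dir contains one subdirectory per class, as described
    for the output_dir parameter.
    """
    samples = []
    for class_dir in sorted(Path(output_dir).iterdir()):
        if not class_dir.is_dir():
            continue  # ignore stray files at the top level
        for file_path in sorted(class_dir.iterdir()):
            if file_path.is_file():
                # The subdirectory name is used as the class label.
                samples.append((file_path, class_dir.name))
    return samples
```

For example, collect_synthetic_samples("./synthetic_samples") would yield pairs labelled "cat" and "dog" for the run shown above, which can then be concatenated with the original dataset's (path, label) list.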