enlaight is a blend of the words enlight and AI and refers to a key property of the contained models: their built-in interpretability. This property makes it possible to create machine learning models that surrogate an existing machine learning model, so that the reasoning process of the original model can be explained to some extent. Thus, these models enlight AI.
The package provides prototype-based learning methods implemented in PyTorch-Lightning. The available models are:
- Generalized Learning Vector Quantization (GLVQ),
- Generalized Tangent Learning Vector Quantization (GTLVQ),
- (Stable) Classification-by-Components (CBC),
- Radial Basis Function (RBF) networks.
Prototype models are interpretable machine learning models for classification tasks. In a nutshell, a prototype model consists of a distance function and a set of prototypes defined in the data space, each with a fixed class label. By computing the distances between the prototypes and a given input, the closest prototype is determined. The class label of this prototype becomes the predicted label of the input, the so-called winner-takes-all rule. Given a suitable loss function and a training dataset, the positions of the prototypes in the data space can be learned from data so that the classification accuracy is maximized. The main advantages of prototype-based models are
- their built-in interpretability and
- their provably robust classification decisions.
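To illustrate the winner-takes-all rule described above, here is a minimal NumPy sketch; the function name, shapes, and data are illustrative assumptions, not the enlaight API:

```python
import numpy as np

def wta_classify(x, prototypes, prototype_labels):
    """Winner-takes-all: assign x the label of its closest prototype.

    x: (n_features,) input sample
    prototypes: (n_prototypes, n_features) prototype positions
    prototype_labels: (n_prototypes,) class label of each prototype
    """
    # Squared Euclidean distance from x to every prototype.
    distances = np.sum((prototypes - x) ** 2, axis=1)
    winner = np.argmin(distances)  # index of the closest prototype
    return prototype_labels[winner]

# Two prototypes per class in a 2-D data space.
prototypes = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0], [6.0, 5.0]])
labels = np.array([0, 0, 1, 1])
print(wta_classify(np.array([0.2, 0.1]), prototypes, labels))  # -> 0
print(wta_classify(np.array([5.5, 4.9]), prototypes, labels))  # -> 1
```

Because the prediction is traced back to a concrete prototype in the data space, the decision can be inspected directly, which is the source of the built-in interpretability.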
The interface of the models is flexible. For instance, there is no requirement for how the prototypes have to be provided. They can be the result of another module or can be fixed and non-trainable. Moreover, the prototypes class supports constraints, so prototypes can be constrained to lie in a certain data space.
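Such a constraint can be pictured as a projection applied to the prototypes after each update. The following is a hypothetical NumPy sketch of a box constraint, not the enlaight interface:

```python
import numpy as np

def project_to_box(prototypes, low=0.0, high=1.0):
    """Project prototypes back into a box-shaped data space.

    A constraint like this is useful when the data is known to live in a
    bounded region, e.g., images with pixel values in [0, 1].
    """
    return np.clip(prototypes, low, high)

# After a gradient step, some prototypes may have left the data space.
updated = np.array([[1.3, 0.5], [-0.1, 0.9]])
print(project_to_box(updated))  # -> [[1.0, 0.5], [0.0, 0.9]]
```

Keeping prototypes inside the valid data space ensures they remain interpretable as exemplary data points.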
The required distance operations are implemented such that they support fast and memory-efficient computations (by reformulating the distance operations with dot products). The following distance functions are supported as functions and as PyTorch Module classes (see :class:`enlaight.core.distance` for the full list):
- Cosine similarity (imported from PyTorch)
- Lp distance (imported from PyTorch)
- (Squared) Euclidean distance
- (Squared) Tangent distance
All distance operations support batching with respect to both arguments. The implementations support the computation of stable gradients.
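The dot-product reformulation mentioned above rests on the identity ||x - w||^2 = ||x||^2 - 2<x, w> + ||w||^2, which computes all pairwise distances between a batch of inputs and a batch of prototypes with a single matrix multiplication. A NumPy sketch of the idea (not the enlaight implementation):

```python
import numpy as np

def pairwise_sq_euclidean(x, w):
    """All pairwise squared Euclidean distances via the dot-product identity.

    x: (n_samples, n_features), w: (n_prototypes, n_features)
    Returns: (n_samples, n_prototypes)
    """
    x_sq = np.sum(x ** 2, axis=1, keepdims=True)  # (n_samples, 1)
    w_sq = np.sum(w ** 2, axis=1)                 # (n_prototypes,)
    cross = x @ w.T                               # one matmul for all pairs
    # Clip tiny negative values caused by floating-point cancellation.
    return np.maximum(x_sq - 2.0 * cross + w_sq, 0.0)

x = np.array([[0.0, 0.0], [1.0, 1.0]])
w = np.array([[1.0, 0.0], [0.0, 2.0]])
print(pairwise_sq_euclidean(x, w))  # -> [[1., 4.], [1., 2.]]
```

Compared to broadcasting the full (n_samples, n_prototypes, n_features) difference tensor, this formulation never materializes the intermediate tensor, which is the source of the memory savings.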
To install the package, execute the following command from the root of the package directory:
pip install .
Note that the package requires Python 3.9 or higher, which is checked during the installation. Moreover, if you install the package inside a conda environment, be aware of potential installation or package side effects due to conflicts between conda and pip. If you encounter errors, install all dependencies directly with conda.
To build the documentation HTML files, install the package with docs dependencies:
pip install .[docs]
Then, execute:
sphinx-build -b html docs docs/build
The compiled documentation is located in docs/build.
For contributions, install the package in dev-mode:
pip install .[dev]
or with all dependencies (including dev and docs):
pip install .[all]
If you are working in a conda environment and encounter any installation or dependency errors, please install all packages using conda.
If you prepare a code submission, always ensure that you provide docstrings and that the documentation can be generated.
The documentation is generated entirely from docstrings and this README file. So far, we avoid providing additional information in separate documentation files. If you encounter a pandoc error during documentation creation on Linux machines even though pandoc is installed via pip, install it via
apt-get install pandoc
If you encounter errors with ipykernel during doc compilation while using conda, uninstall the pip version and install it via conda.
Additionally, it is recommended to install pre-commit so that pre-commit checks are triggered automatically before making a commit, thus avoiding non-standardized commits:
pip install pre-commit
pre-commit install
Moreover, install
pre-commit install --hook-type commit-msg
to ensure that your commit messages follow conventional commits, which is recommended. Again, if you encounter errors while using conda, uninstall the pip version of pre-commit and install it via conda.
If you prepare a commit, run
pre-commit run --all-files
to test for errors with respect to the pre-commit hooks. In case you really want to make a non-standardized commit, use the --no-verify option of git commit to skip the checks.
The package was used to create part of the results of the corresponding AAAI 2025 paper. In particular, the models provided in this package were used for the shallow model experiments. For the deep models, please check the HuggingFace and the GitHub repository.
To reproduce the results, install the package in dev-mode:
pip install .[dev]
Then, execute
cd ./experiments
python model_comparison.py
to reproduce the results of the shallow model comparison. Please note that the script uses ray-tune for parallel scheduling of the jobs and assumes that a GPU is available. If multiple GPUs are available, ray-tune will execute the individual runs in parallel. Since the models are relatively small, it may be possible to compute multiple models in parallel on one GPU. For this, change the following line in the Python script:
tune.with_resources(objective, {"gpu": 1})  # 100% job allocation per GPU
to
tune.with_resources(objective, {"gpu": 1/2})  # 50% job allocation per GPU
This will allow ray to run two jobs per GPU.
After the training of all models is completed, you can use the script ./experiments/print_shallow_model_results.py to generate one consolidated dictionary with all the results and to render them in an easily human-readable format. You only need to specify the root path at the top of the file.
To reproduce the robustness analysis (robustness curves), execute:
cd ./experiments
python robustness_analysis.py
As before, you can specify the GPU usage of ray-tune in the file.
Moreover, use the ./experiments/robustness_plot.py script to generate the plots
from the paper. Again, specify the root path at the top of the file.