ZKMNIST - Noir


About

Ethereum's smart contracts enable trustless execution in a decentralized environment. However, limited on-chain computation and the public nature of blockchain transactions make it hard to build applications that require intensive computation or handle private data, such as machine learning applications. Zero-knowledge (ZK) proofs aim to lift these constraints by allowing off-chain computations to be verified on-chain. They also allow these computations to be verified without revealing the underlying data, which enables privacy-preserving smart contracts.

Similarly, these proofs can be used to facilitate the execution of machine learning models in a way that keeps the input data or the model weights confidential. This might be useful when input data is sensitive and private, such as personal biometric information. Likewise, model parameters can be confidential, like those used in biometric authentication systems.

At the same time, downstream entities that rely on the model's output, such as on-chain smart contracts, may need to be certain that the input was correctly processed by the ML model to yield the claimed output. Together, this ensures both privacy and reliability when these models are applied in sensitive areas.

This web app is an example of such zero-knowledge machine learning. The app uses a neural network trained on the MNIST dataset (a collection of handwritten digits) that classifies a selected digit and then generates a proof of the classification while keeping the input (the pixels) private. The proof can then be verified on-chain or off-chain.

The ZK circuit is written in Noir, Aztec Network's domain-specific language for zero-knowledge proofs. The web application was built from Aztec's Noir Starter. It uses Next.js as the frontend framework, Hardhat for deployment and testing, and TensorFlow/Keras for model training and prediction.

Inspired by 0xPARC's ZK Machine Learning. Authored by Alexander Janiak.


Model & Data

The MNIST dataset is a collection of 70,000 handwritten digits from 0 to 9, commonly used for training and testing machine learning models. It was created by combining samples written by American Census Bureau employees and by high school students, to cover a diverse range of handwriting styles. Each image is a 28×28-pixel grayscale image labeled with the digit it represents. MNIST was designed as a benchmark for evaluating how accurately algorithms recognize and classify handwritten digits. It has become a standard for evaluating machine learning techniques, which makes it a natural dataset for testing Noir's capabilities for zero-knowledge machine learning.

Example of MNIST Digits

Unlike 0xPARC's CNN approach, the chosen model architecture is a simple dense network:

| Layer (type) | Output Shape | Param # |
|---|---|---|
| flatten (Flatten) | (None, 784) | 0 |
| dense (Dense) | (None, 300) | 235,500 |
| dense_1 (Dense) | (None, 100) | 30,100 |
| dense_2 (Dense) | (None, 30) | 3,030 |
| dense_3 (Dense) | (None, 10) | 310 |

Total params: 268,940 (1.03 MB). Trainable params: 268,940 (1.03 MB). Non-trainable params: 0 (0.00 Byte).
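The parameter counts in the table follow directly from the layer widths: a dense layer with $n_{in}$ inputs and $n_{out}$ outputs has $n_{in} \cdot n_{out}$ weights plus $n_{out}$ biases. A small Python check, assuming 4-byte (float32) parameters for the size estimate:

```python
# Layer widths from the table: flattened 28x28 input, then the four Dense layers.
LAYER_SIZES = [784, 300, 100, 30, 10]


def dense_params(n_in: int, n_out: int) -> int:
    """Number of weights (n_in * n_out) plus biases (n_out) in one dense layer."""
    return n_in * n_out + n_out


per_layer = [dense_params(n_in, n_out) for n_in, n_out in zip(LAYER_SIZES, LAYER_SIZES[1:])]
total = sum(per_layer)
size_mb = total * 4 / 2**20  # assumes 4 bytes (float32) per parameter

print(per_layer)            # [235500, 30100, 3030, 310]
print(total)                # 268940
print(f"{size_mb:.2f} MB")  # 1.03 MB
```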

The model uses the ReLU activation function between dense layers and a softmax function after the last layer. It was trained with stochastic gradient descent (SGD) using sparse categorical cross-entropy as the loss function. After training, the model achieved 97.52% accuracy on the test set.
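The inference path described above (ReLU between dense layers, softmax after the last one) can be sketched in plain Python. This is a minimal illustration, not the project's Keras code: the weights are hypothetical random values rather than the trained model, training (SGD with cross-entropy) is not shown, and the 30-unit hidden activation is returned separately because it is the input to the last layer.

```python
import math
import random

LAYER_SIZES = [784, 300, 100, 30, 10]


def dense(weights, biases, x):
    """Affine map: one output per row of `weights` (shape n_out x n_in)."""
    return [sum(w * x_j for w, x_j in zip(row, x)) + b for row, b in zip(weights, biases)]


def relu(v):
    return [max(0.0, t) for t in v]


def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(t - m) for t in v]
    s = sum(exps)
    return [e / s for e in exps]


def random_layer(rng, n_in, n_out):
    """Hypothetical stand-in weights: small uniform values and zero biases."""
    W = [[rng.uniform(-0.05, 0.05) for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out


def forward(params, pixels):
    """Return (hidden, logits, probs) for a flattened 28x28 image.

    `hidden` is the 30-unit ReLU output that feeds the last layer.
    """
    h = pixels
    for W, b in params[:-1]:
        h = relu(dense(W, b, h))
    W_last, b_last = params[-1]
    logits = dense(W_last, b_last, h)
    return h, logits, softmax(logits)


rng = random.Random(0)
params = [random_layer(rng, n_in, n_out) for n_in, n_out in zip(LAYER_SIZES, LAYER_SIZES[1:])]
pixels = [rng.random() for _ in range(784)]  # hypothetical grayscale values in [0, 1)
hidden, logits, probs = forward(params, pixels)
print(len(hidden), len(logits), round(sum(probs), 6))  # 30 10 1.0
```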


Approach

The approach is similar to 0xPARC's implementation, with a few alterations to account for the different model architecture and Noir's quirks. As in 0xPARC's demo, and to keep the circuit simple, only the last layer of the model was implemented as a zk-SNARK circuit; the earlier layers are computed outside the circuit. Noir's simple Rust-like syntax made writing the circuit straightforward.

The forward pass of the network's last layer, as implemented in the circuit (excluding the softmax function $\sigma$), is:

$$\hat{\mathbf{y}} = \mathbf{L}\mathbf{x} + \mathbf{b}$$

Here, $\mathbf{L}$ is the $10 \times 30$ weight matrix of the last layer, $\mathbf{x}$ is the $30 \times 1$ input from the previous layer, and $\mathbf{b}$ is the layer's $10 \times 1$ bias vector. $\hat{\mathbf{y}}$ is the model's output before softmax (its logits), where $\sigma(\hat{\mathbf{y}})_i$ can be interpreted as the probability that the $i$th class is the actual class.
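The last-layer computation $\hat{\mathbf{y}} = \mathbf{L}\mathbf{x} + \mathbf{b}$ with the shapes given above can be written as a plain matrix-vector product. A minimal sketch; the function name `last_layer` and the toy values are hypothetical:

```python
N_CLASSES, N_HIDDEN = 10, 30  # shapes from the text: L is 10x30, x is 30x1, b is 10x1


def last_layer(L, x, b):
    """Return y_hat = L x + b as a list of N_CLASSES logits."""
    if len(L) != N_CLASSES or len(b) != N_CLASSES or any(len(row) != N_HIDDEN for row in L):
        raise ValueError("L must be 10x30 and b must have 10 entries")
    if len(x) != N_HIDDEN:
        raise ValueError("x must have 30 entries")
    return [sum(l_ij * x_j for l_ij, x_j in zip(row, x)) + b_i for row, b_i in zip(L, b)]


# Toy, hypothetical values: row i puts weight i on x_0 and zero elsewhere.
L = [[i if j == 0 else 0 for j in range(N_HIDDEN)] for i in range(N_CLASSES)]
x = [2] + [0] * (N_HIDDEN - 1)
b = [1] * N_CLASSES
print(last_layer(L, x, b))  # [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
```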

The model's class prediction is then computed as:

$$\hat{p} = \operatorname{arg\,max}(\hat{\mathbf{y}})$$
Note: softmax preserves the order of its inputs ($\sigma(\hat{\mathbf{y}})_i > \sigma(\hat{\mathbf{y}})_j$ exactly when $\hat{y}_i > \hat{y}_j$), so it does not change the output of the argmax and is not needed in the circuit.
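A quick numeric illustration of why softmax can be dropped: applying it to a set of made-up logits leaves both the argmax and the full ranking unchanged. The logit values are hypothetical.

```python
import math


def softmax(v):
    """Numerically stable softmax: subtracting max(v) does not change the result."""
    m = max(v)
    exps = [math.exp(t - m) for t in v]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(v):
    """Index of the largest entry (the first one if there is a tie)."""
    return max(range(len(v)), key=v.__getitem__)


y_hat = [1.3, -0.7, 4.2, 0.0, 2.9, -3.1, 0.5, 4.1, -1.2, 2.0]  # made-up logits
probs = softmax(y_hat)
print(argmax(y_hat), argmax(probs))  # 2 2
```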

Preprocessing

At the time the circuit was developed, Noir libraries such as Signed Int and Fraction did not exist. To work around Noir's lack of native signed integers and fractional (floating-point) numbers, the inputs, weights, and biases were therefore scaled and truncated.

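To handle negative values, a sufficiently large positive integer $c$ is added element-wise to the weights and biases so that every shifted value is non-negative. Because the previous layer applies ReLU, $\mathbf{x}$ is guaranteed to be non-negative and needs no offset. To handle fractional values, the inputs and weights are scaled element-wise by a positive scalar $a$, and the biases by $a^2$ to match the scale of the product $(a\mathbf{L})(a\mathbf{x}) = a^2\mathbf{L}\mathbf{x}$. Each element is then floored, which is equivalent to truncation for values $\geq 0$. $\mathbf{J}$ is the $10 \times 30$ matrix of ones and $\vec{1}$ is the $10 \times 1$ vector of ones.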

$$\mathbf{z} = \lfloor a\mathbf{x} \rfloor \qquad \mathbf{W} = \lfloor a\mathbf{L} + c\mathbf{J} \rfloor \qquad \mathbf{v} = \lfloor a^2\mathbf{b} + c\vec{1} \rfloor$$
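The definitions of $\mathbf{z}$, $\mathbf{W}$, and $\mathbf{v}$ amount to a simple integer quantization step. A minimal Python sketch, assuming hypothetical values $a = 1000$ and $c = 10^9$ (the document does not state the values actually used); `quantize` is a hypothetical helper name:

```python
import math

A = 1000   # hypothetical scale factor a
C = 10**9  # hypothetical offset c; must exceed every |a * L_ij| and |a^2 * b_i|


def quantize(L, b, x, a=A, c=C):
    """Return integer (W, z, v) with z = floor(a x), W = floor(a L + c J), v = floor(a^2 b + c 1).

    Since c is an integer, floor(t + c) == floor(t) + c, so the offset is added after flooring.
    """
    if any(x_j < 0 for x_j in x):
        raise ValueError("x comes from a ReLU layer and should be non-negative")
    z = [math.floor(a * x_j) for x_j in x]
    W = [[math.floor(a * l_ij) + c for l_ij in row] for row in L]
    v = [math.floor(a * a * b_i) + c for b_i in b]
    if any(w < 0 for row in W for w in row) or any(t < 0 for t in v):
        raise ValueError("c is too small to make every weight and bias non-negative")
    return W, z, v


# Toy, hypothetical values chosen to be exact in binary floating point.
W, z, v = quantize(L=[[0.5, -0.25, 0.125]], b=[-0.0625], x=[0.25, 0.0, 1.5])
print(W)  # [[1000000500, 999999750, 1000000125]]
print(z)  # [250, 0, 1500]
print(v)  # [999937500]
```

Flooring $a\mathbf{L}$ first and then adding the integer $c$ gives the same result as $\lfloor a\mathbf{L} + c\mathbf{J} \rfloor$ while avoiding floating-point rounding of a very large sum.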

Ignoring truncation, and noting that every entry of $\mathbf{J}\mathbf{x}$ equals $s = \sum_{j=1}^{30} x_j$:

$$
\begin{aligned}
\mathbf{W}\mathbf{z} + \mathbf{v} &= (a\mathbf{L} + c\mathbf{J})(a\mathbf{x}) + (a^2\mathbf{b} + c\vec{1}) \\
&= a^2\mathbf{L}\mathbf{x} + a^2\mathbf{b} + ca\mathbf{J}\mathbf{x} + c\vec{1} \\
&= a^2\hat{\mathbf{y}} + ca\,s\,\vec{1} + c\vec{1} \\
&= a^2\hat{\mathbf{y}} + \mathbf{d}, \qquad \mathbf{d} = c\,(as + 1)\,\vec{1} \in \mathbb{R}^{10}
\end{aligned}
$$

Every entry of $\mathbf{d}$ is the same constant $c(as + 1)$, and $a^2 > 0$, so scaling by $a^2$ and adding $\mathbf{d}$ preserves the order of the entries of $\hat{\mathbf{y}}$. Therefore

$$\operatorname{arg\,max}(a^2\hat{\mathbf{y}} + \mathbf{d}) = \operatorname{arg\,max}(\hat{\mathbf{y}}) = \hat{p}$$
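The derivation can be checked numerically. This sketch uses exact rational arithmetic (Python's `fractions.Fraction`) to confirm that, without truncation, $\mathbf{W}\mathbf{z} + \mathbf{v} - a^2\hat{\mathbf{y}}$ equals $c(as + 1)$ in every row, and then repeats the computation with floors on integers only. The values of $a$, $c$, $\mathbf{L}$, $\mathbf{b}$, and $\mathbf{x}$ are hypothetical, and the bias of class 7 is set artificially high to give it a wide lead. Truncation shifts each output entry by a bounded amount, and part of that shift differs between rows, so the argmax is only guaranteed to survive when the top logit leads the runner-up by more than that perturbation divided by $a^2$; a larger $a$ makes the perturbation smaller.

```python
import math
import random
from fractions import Fraction


def matvec(M, x):
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


def argmax(v):
    return max(range(len(v)), key=v.__getitem__)


rng = random.Random(0)
a, c = 1000, 10**9  # hypothetical scale and offset

# Hypothetical last-layer values as exact rationals: L, b in [-0.5, 0.5], x in [0, 2].
L = [[Fraction(rng.randint(-500_000, 500_000), 10**6) for _ in range(30)] for _ in range(10)]
b = [Fraction(rng.randint(-500_000, 500_000), 10**6) for _ in range(10)]
b[7] = Fraction(100)  # give class 7 a clear lead so truncation cannot flip the argmax
x = [Fraction(rng.randint(0, 2_000_000), 10**6) for _ in range(30)]
y_hat = [lx + b_i for lx, b_i in zip(matvec(L, x), b)]

# 1) Without truncation: Wz + v - a^2 y_hat equals d = c(a*sum(x) + 1) in every row.
W = [[a * l_ij + c for l_ij in row] for row in L]
z = [a * x_j for x_j in x]
v = [a * a * b_i + c for b_i in b]
exact = [wz + v_i for wz, v_i in zip(matvec(W, z), v)]
d = c * (a * sum(x) + 1)
assert all(e - a * a * y == d for e, y in zip(exact, y_hat))

# 2) With truncation, using only integers as in the circuit.
W_q = [[math.floor(a * l_ij) + c for l_ij in row] for row in L]
z_q = [math.floor(a * x_j) for x_j in x]
v_q = [math.floor(a * a * b_i) + c for b_i in b]
truncated = [wz + v_i for wz, v_i in zip(matvec(W_q, z_q), v_q)]

print(argmax(y_hat), argmax(exact), argmax(truncated))  # 7 7 7
```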

Commitment

After a digit is selected, the user generates a public commitment $c_x$ equal to the Pedersen hash of the input $\mathbf{x}$. The circuit then enforces the constraint $\text{hash}(\mathbf{x}) = c_x$, which ensures that the model's prediction corresponds to the committed input.
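The commit-then-check flow can be sketched in Python. This only illustrates the logic: SHA-256 from `hashlib` stands in for the Pedersen hash used by the circuit (Pedersen hashing is not in Python's standard library), the 32-byte big-endian encoding of each entry is an assumption, and `commit` and `check_commitment` are hypothetical names. In the actual system the check runs inside the Noir circuit, with $\mathbf{x}$ kept private and $c_x$ public.

```python
import hashlib


def commit(x_ints):
    """Stand-in commitment: SHA-256 over fixed-width encodings (NOT a Pedersen hash)."""
    data = b"".join(t.to_bytes(32, "big") for t in x_ints)  # entries must be in [0, 2**256)
    return hashlib.sha256(data).hexdigest()


def check_commitment(x_private, c_x_public):
    """Mirror of the circuit constraint hash(x) == c_x."""
    return commit(x_private) == c_x_public


x = [250, 0, 1500]  # hypothetical integer-encoded inputs
c_x = commit(x)     # published before proving
print(check_commitment(x, c_x))               # True
print(check_commitment([251, 0, 1500], c_x))  # False
```

Using a fixed width per entry keeps the encoding unambiguous, so two different input vectors of the same length cannot serialize to the same bytes.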