Skip to content

Instantly share code, notes, and snippets.

@Brunomachadob
Last active May 7, 2020 23:52
Show Gist options
  • Save Brunomachadob/ff675e893328c5adc439aeec078bee37 to your computer and use it in GitHub Desktop.
Save Brunomachadob/ff675e893328c5adc439aeec078bee37 to your computer and use it in GitHub Desktop.
Educational implementation of an Onion Router (the specification behind Tor)

This is an educational implementation of the Onion Router specification for encrypted/anonymous conversation between two parties.

Do not use this in production: it simplifies a lot of things (even though I know you will ignore this warning).

If you use IntelliJ, you can just paste the implementation into a new scratch and run it.

Let me know in the comments if you see anything that could be improved (without adding too much complexity to it).

import java.security.KeyPairGenerator
import java.security.MessageDigest
import java.security.PublicKey
import java.util.*
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import javax.crypto.SecretKey
import javax.crypto.spec.SecretKeySpec
import kotlin.random.Random
/**
 * In-memory registry that stands in for the network: it maps node ids to [NetworkNode] instances
 * and delivers messages by calling the target node directly.
 */
class Network {
    /** Registered nodes keyed by the id they were registered under. */
    private val dns = mutableMapOf<UUID, NetworkNode>()

    /**
     * Registers [node] under [id], replacing any node previously stored under the same id.
     *
     * @return this network, so calls can be chained
     */
    fun register(id: UUID, node: NetworkNode): Network = apply { dns[id] = node }

    /**
     * Picks [amount] distinct registered nodes in random order.
     *
     * The ids in the result are the ids the nodes were registered under, i.e. exactly the ids that
     * [forwardMessage] is able to resolve.
     *
     * @param amount how many nodes to pick; zero or a negative value yields an empty list
     * @return `(id, node)` pairs, without duplicates
     * @throws IllegalStateException when fewer than [amount] nodes are registered
     */
    fun getRandomNodes(amount: Int): List<Pair<UUID, NetworkNode>> {
        check(dns.size >= amount) { "There are not enough nodes available in the network" }
        return dns.toList().shuffled().take(amount.coerceAtLeast(0))
    }

    /**
     * Hands [message] to the node registered under [to] and returns whatever that node answers.
     *
     * @param hash passed through unchanged to [NetworkNode.handleMessage]
     * @throws IllegalStateException when no node is registered under [to]
     */
    fun forwardMessage(to: UUID, message: ByteArray, hash: ByteArray): ByteArray {
        val node = checkNotNull(dns[to]) { "Node '$to' does not exist." }
        return node.handleMessage(message, hash)
    }

    /**
     * Sends [message] to [to], using a randomly chosen registered node as the originating client.
     *
     * @throws NoSuchElementException when no node has been registered yet
     */
    fun sendMessage(to: String, message: String): String = dns.values.random().sendMessage(to, message)
}
/**
 * A participant of the educational onion network. Every node can act both as a client
 * ([sendMessage]) and as a relay/exit ([handleMessage]).
 *
 * On construction the node generates an RSA key pair and an AES key and registers itself in
 * [network] under a random [id].
 *
 * NOTE(review): the ciphers are created from the bare algorithm names "AES" and "RSA", so mode and
 * padding are whatever the JCA provider defaults to (presumably ECB/PKCS5Padding and
 * ECB/PKCS1Padding on the stock JDK provider — confirm). Fine for a demo, not for real traffic.
 *
 * NOTE(review): the per-node [Cipher] instances and the shared [sha512] digest are stateful
 * objects; this class presumably must not be used from several threads at once — confirm before
 * adding concurrency.
 */
class NetworkNode(
    private val network: Network
) {
    companion object {
        /** Number of relay nodes a client puts between itself and the destination. */
        const val NODE_CIRCUIT_SIZE = 2

        /** Length of the textual form of a [UUID]. */
        const val UUID_SIZE = 36

        /** Length of the forwarding header: the next hop's UUID followed by one '@' separator. */
        const val DEST_BYTES_SIZE = 1 + UUID_SIZE // @ + UUID
        const val CIPHER_ALGORITHM = "AES"
        const val KEY_PAIR_ALGORITHM = "RSA"
        val KEY_PAIR_GEN: KeyPairGenerator = KeyPairGenerator
            .getInstance(KEY_PAIR_ALGORITHM)
            .apply { initialize(2048) }
        val sha512: MessageDigest = MessageDigest.getInstance("SHA-512")
    }

    val id: UUID = UUID.randomUUID()
    private val keyPair = KEY_PAIR_GEN.generateKeyPair()
    private val symmetricKey: SecretKey = KeyGenerator.getInstance(CIPHER_ALGORITHM).generateKey()
    private val aesDecryptor = Cipher.getInstance(CIPHER_ALGORITHM)
        .apply { init(Cipher.DECRYPT_MODE, symmetricKey) }
    private val aesEncryptor = Cipher.getInstance(CIPHER_ALGORITHM)
        .apply { init(Cipher.ENCRYPT_MODE, symmetricKey) }

    init {
        network.register(id, this)
    }

    /**
     * Returns this node's symmetric key encrypted with [serverPublicKey], so that only the holder
     * of the matching private key can recover it.
     */
    private fun getSymmetricKey(serverPublicKey: PublicKey): ByteArray =
        Cipher.getInstance(KEY_PAIR_ALGORITHM)
            .also { it.init(Cipher.ENCRYPT_MODE, serverPublicKey) }
            .doFinal(symmetricKey.encoded)

    /**
     * Peels this node's encryption layer off [message] and either answers it or relays it.
     *
     * If the SHA-512 digest of the decrypted bytes equals [hash], this node is the end of the
     * chain and answers with a fixed "200 OK" (a placeholder for calling the real destination).
     * Otherwise the decrypted bytes are read as `<next hop UUID>@<inner message>` and the inner
     * message is forwarded to the next hop.
     *
     * @param message bytes encrypted with this node's symmetric key
     * @param hash digest of the plain payload, used to recognise the end of the chain
     * @return the response, encrypted with this node's symmetric key
     * @throws IllegalArgumentException when the decrypted bytes are not the final payload and do
     *   not start with a complete forwarding header
     * @throws IllegalStateException when the next hop is not registered in the network
     */
    fun handleMessage(message: ByteArray, hash: ByteArray): ByteArray {
        val decrypted = aesDecryptor.doFinal(message)
        val response = if (sha512.digest(decrypted).contentEquals(hash)) {
            "200 OK".toByteArray()
        } else {
            require(decrypted.size >= DEST_BYTES_SIZE) {
                "Message is neither the final payload nor a forwarding envelope " +
                    "(got ${decrypted.size} bytes, need at least $DEST_BYTES_SIZE)"
            }
            val nextHop = UUID.fromString(String(decrypted, 0, UUID_SIZE))
            // Skip the UUID and the '@' separator; everything after it belongs to the next hop.
            val innerMessage = decrypted.copyOfRange(DEST_BYTES_SIZE, decrypted.size)
            network.forwardMessage(nextHop, innerMessage, hash)
        }
        return aesEncryptor.doFinal(response)
    }

    /**
     * Sends [message] to [to] through a random circuit of [NODE_CIRCUIT_SIZE] nodes and returns
     * the decrypted response.
     *
     * @param to destination name; it only ends up inside the payload as `to@message`
     * @param message request to deliver
     * @throws IllegalStateException when the network has fewer than [NODE_CIRCUIT_SIZE] nodes
     */
    fun sendMessage(to: String, message: String): String {
        val circuit = buildCircuit()

        // The payload and its digest; the digest lets the last node recognise the plain payload.
        val payload = "$to@$message".toByteArray()
        val hash = sha512.digest(payload)

        val onion = wrapInLayers(payload, circuit)
        val response = network.forwardMessage(circuit.first().first, onion, hash)
        return String(unwrapLayers(response, circuit))
    }

    /**
     * Picks the random circuit and obtains every hop's symmetric key: the hop encrypts its key
     * with our public key and we decrypt it with our private key.
     */
    private fun buildCircuit(): List<Pair<UUID, SecretKey>> {
        val keyUnwrapper = Cipher.getInstance(KEY_PAIR_ALGORITHM)
            .apply { init(Cipher.DECRYPT_MODE, keyPair.private) }
        return network.getRandomNodes(NODE_CIRCUIT_SIZE).map { (hopId, hop) ->
            val keyBytes = keyUnwrapper.doFinal(hop.getSymmetricKey(keyPair.public))
            val hopKey: SecretKey = SecretKeySpec(keyBytes, CIPHER_ALGORITHM)
            hopId to hopKey
        }
    }

    /**
     * Encrypts [payload] once per hop, innermost layer first:
     * `key1("2@" + key2("3@" + key3(payload)))`.
     *
     * Each layer except the innermost is prefixed with the id of the hop that follows it, so the
     * hop that peels the layer knows where to forward the rest.
     */
    private fun wrapInLayers(payload: ByteArray, circuit: List<Pair<UUID, SecretKey>>): ByteArray {
        val cipher = Cipher.getInstance(CIPHER_ALGORITHM)
        return circuit.foldRightIndexed(payload) { index, (_, hopKey), inner ->
            // The last hop has no successor and therefore gets no forwarding header.
            val header = circuit.getOrNull(index + 1)?.first?.let { nextHop -> "${nextHop}@" }.orEmpty()
            cipher.init(Cipher.ENCRYPT_MODE, hopKey)
            cipher.doFinal(header.toByteArray() + inner)
        }
    }

    /**
     * Removes the response layers in circuit order: the first hop encrypted last, so its layer is
     * the outermost one: `key1(key2(key3(response)))`.
     */
    private fun unwrapLayers(response: ByteArray, circuit: List<Pair<UUID, SecretKey>>): ByteArray {
        val cipher = Cipher.getInstance(CIPHER_ALGORITHM)
        return circuit.fold(response) { layered, (_, hopKey) ->
            cipher.init(Cipher.DECRYPT_MODE, hopKey)
            cipher.doFinal(layered)
        }
    }
}
// Demo: build a network of five nodes, send one request through a random circuit
// and print the decrypted response.
val network = Network()
val nodes = List(5) { NetworkNode(network) }
val response = network.sendMessage(to = "google", message = "GET /something")
println(response)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment