Kotlin, Spring Boot - JWT Token

Creating a Custom JWT Token Utility in Spring Boot.

[Cover image by: Ricardo Loaiza]

Introduction:

JSON Web Tokens (JWTs) have become a popular way to pass signed claims, such as a user's identity, between parties. In this article, we will explore the implementation of a custom JWT token utility in a Spring Boot application using Kotlin. This utility will handle token generation, verification, expiration checks, and refreshing.
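A compact JWT is three Base64URL-encoded segments joined by dots: a JSON header, a JSON payload of claims, and a signature. The payload is only encoded, not encrypted, so anyone holding the token can read it; the signature is what lets the server detect tampering. The sketch below, which uses only the JDK, splits a token and decodes the first two segments for inspection. The names decodeJwtParts and JwtParts are hypothetical helpers for illustration, and no signature check happens here.

```kotlin
import java.nio.charset.StandardCharsets
import java.util.Base64

/** Decoded view of a compact JWT (hypothetical type for illustration). */
data class JwtParts(val headerJson: String, val payloadJson: String, val signature: String)

/**
 * Splits a compact JWT ("header.payload.signature") and Base64URL-decodes
 * the header and payload. The signature is returned untouched and NOT verified.
 */
fun decodeJwtParts(token: String): JwtParts {
    val parts = token.split(".")
    require(parts.size == 3) { "A compact JWT must have exactly three segments" }
    // The URL-safe decoder accepts segments with or without '=' padding
    val decoder = Base64.getUrlDecoder()
    fun decode(segment: String) = String(decoder.decode(segment), StandardCharsets.UTF_8)
    return JwtParts(decode(parts[0]), decode(parts[1]), parts[2])
}
```

For example, decoding the header of a token produced by JwtHelper should yield something like {"alg":"HS512","typ":"JWT"}.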

Setting Up the JWT Helper Class:

Let's start by creating a utility named JwtHelper. It is declared as a Kotlin object, so it acts as a singleton that encapsulates the functionality required for JWT handling. It includes enums for the JWT and signature algorithms, as well as methods for initializing the helper, generating payloads, creating tokens, verifying tokens, refreshing tokens, and checking token expiration.

import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.ObjectMapper
import java.nio.charset.StandardCharsets.UTF_8
import java.security.MessageDigest
import java.util.Base64
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec

/**
 * Thrown when JwtHelper is used before a non-empty secret key has been set via init().
 */
class InvalidJWTSecretKey : IllegalStateException("JWT secret key is empty; call JwtHelper.init() first")

object JwtHelper {
    /**
     * Enum class representing JWT algorithms with their corresponding values.
     */
    enum class JwtAlgorithm(val value: String) {
        ALGORITHM_HS256("HS256"), ALGORITHM_HS384("HS384"), ALGORITHM_HS512("HS512")
    }

    /**
     * Enum class representing HMAC signature algorithms with their corresponding values.
     */
    private enum class SignatureAlgorithm(val value: String) {
        ALGORITHM_HS256("HmacSHA256"), ALGORITHM_HS384("HmacSHA384"), ALGORITHM_HS512("HmacSHA512")
    }

    // Shared Jackson mapper (thread-safe once configured) and the type used for decoded payloads
    private val mapper = ObjectMapper()
    private val payloadType = object : TypeReference<MutableMap<String, Any>>() {}

    // Secret key for JWT generation and verification
    private var SECRET_KEY: String = ""
    // Default lifetime (365 days) in seconds, because iat/exp are NumericDate values (RFC 7519) - Customizable
    private const val EXPIRATION_TIME_SECONDS = 365L * 24 * 60 * 60

    // Default JWT algorithm and signature algorithm
    private var jwtAlgorithm: String = JwtAlgorithm.ALGORITHM_HS512.value
    private var signatureAlgorithm: String = SignatureAlgorithm.ALGORITHM_HS512.value

    /**
     * Initializes the JWT Helper with the secret key and algorithm.
     *
     * @param key Secret key for JWT token generation and verification.
     * @param algorithm JWT algorithm to be used (default is ALGORITHM_HS512).
     */
    fun init(key: String, algorithm: JwtAlgorithm = JwtAlgorithm.ALGORITHM_HS512) {
        this.SECRET_KEY = key
        checkForSecretKey()
        this.jwtAlgorithm = algorithm.value
        this.signatureAlgorithm = when (algorithm) {
            JwtAlgorithm.ALGORITHM_HS256 -> SignatureAlgorithm.ALGORITHM_HS256.value
            JwtAlgorithm.ALGORITHM_HS384 -> SignatureAlgorithm.ALGORITHM_HS384.value
            JwtAlgorithm.ALGORITHM_HS512 -> SignatureAlgorithm.ALGORITHM_HS512.value
        }
    }

    /**
     * Generates a JWT payload with customizable parameters.
     * All timestamps are epoch seconds (NumericDate, RFC 7519).
     *
     * @param subject Subject of the JWT token.
     * @param issuer Issuer of the token (default is "yLnk").
     * @param audience Audience for the token (default is "yLnk").
     * @param issuedAt Time at which the token was issued (default is 0, current time used if set to 0).
     * @param expireAt Time at which the token will expire (default is 0, 365 days from now if set to 0).
     * @param additionalData Additional data, stored in the payload under the "additional" claim.
     * @return Map representing the JWT payload.
     */
    fun generatePayload(
        subject: String,
        issuer: String = "yLnk",
        audience: String = "yLnk",
        issuedAt: Long = 0,
        expireAt: Long = 0,
        additionalData: Map<String, Any>? = null
    ): Map<String, Any> {
        val now = nowSeconds()
        val payload = HashMap<String, Any>()
        payload["iss"] = issuer
        payload["iat"] = if (issuedAt == 0L) now else issuedAt
        // An explicit expireAt is used as-is; only the default adds the 365-day lifetime
        payload["exp"] = if (expireAt == 0L) now + EXPIRATION_TIME_SECONDS else expireAt
        payload["aud"] = audience
        payload["sub"] = subject
        additionalData?.let { payload["additional"] = it }
        return payload
    }

    /**
     * Generates a JWT token using the provided payload.
     *
     * @param payload JWT payload to be encoded into the token.
     * @return Generated JWT token.
     */
    fun generateJwtToken(payload: Map<String, Any>): String {
        checkForSecretKey()
        val header = encodeBase64URL("{\"alg\":\"${jwtAlgorithm}\",\"typ\":\"JWT\"}".toByteArray(UTF_8))
        val encodedPayload = encodeBase64URL(serializeToJson(payload).toByteArray(UTF_8))
        val signature = generateSignature("$header.$encodedPayload")
        return "$header.$encodedPayload.$signature"
    }

    /**
     * Verifies the signature of a JWT token. The exp claim is not checked here (see isTokenExpired),
     * and the "alg" in the token header is ignored: the configured algorithm is always used.
     *
     * @param token JWT token to be verified.
     * @return True if the signature is valid, false otherwise.
     */
    fun verifyJwtToken(token: String): Boolean {
        checkForSecretKey()
        val parts = token.split(".")
        if (parts.size != 3) {
            return false
        }
        val (header, payload, signature) = parts
        val calculatedSignature = generateSignature("$header.$payload")
        // Constant-time comparison, so timing does not reveal how much of the signature matched
        return MessageDigest.isEqual(calculatedSignature.toByteArray(UTF_8), signature.toByteArray(UTF_8))
    }

    /**
     * Refreshes a JWT token by updating its issued-at and expiration times.
     *
     * @param token JWT token to be refreshed.
     * @return Refreshed JWT token.
     * @throws IllegalArgumentException if the token format or signature is invalid.
     */
    fun refreshJwtToken(token: String): String {
        checkForSecretKey()
        // Re-signing an unverified token would turn a forged payload into a valid one
        require(verifyJwtToken(token)) { "Invalid JWT format or signature" }
        val payloadMap: MutableMap<String, Any> = mapper.readValue(decodeBase64URL(token.split(".")[1]), payloadType)

        val currentTime = nowSeconds()
        payloadMap["iat"] = currentTime
        payloadMap["exp"] = currentTime + EXPIRATION_TIME_SECONDS
        // Rebuilds the header too, so it always names the currently configured algorithm
        return generateJwtToken(payloadMap)
    }

    /**
     * Checks if a JWT token has expired.
     *
     * @param token JWT token to be checked.
     * @return True if the token has expired (or has no readable exp claim), false otherwise.
     */
    fun isTokenExpired(token: String): Boolean {
        checkForSecretKey()
        val payloadMap = extractPayload(token) ?: return true
        // Jackson yields Int or Long depending on magnitude, so read "exp" as a Number
        val expiration = (payloadMap["exp"] as? Number)?.toLong() ?: return true
        // RFC 7519: the token must not be accepted on or after its exp time
        return nowSeconds() >= expiration
    }

    /**
     * Decodes a base64 URL-encoded string.
     *
     * @param input Base64 URL-encoded string to decode.
     * @return Decoded string.
     */
    private fun decodeBase64URL(input: String): String {
        val decodedBytes = Base64.getUrlDecoder().decode(input)
        return String(decodedBytes, UTF_8)
    }

    /**
     * Extracts and decodes the payload from a JWT token.
     * The signature is NOT checked; call verifyJwtToken before trusting the claims.
     *
     * @param token JWT token from which to extract the payload.
     * @return Decoded payload as a Map, or null if the token is malformed.
     */
    fun extractPayload(token: String): Map<String, Any>? {
        checkForSecretKey()
        val parts = token.split(".")
        if (parts.size != 3) {
            return null
        }
        // Malformed Base64URL or JSON yields null instead of an exception
        return runCatching { mapper.readValue(decodeBase64URL(parts[1]), payloadType) }.getOrNull()
    }

    /**
     * Generates the signature for a given data using the specified signature algorithm.
     *
     * @param data Data for which the signature is generated.
     * @return Base64 URL-encoded signature.
     */
    private fun generateSignature(data: String): String {
        val secretKeySpec = SecretKeySpec(SECRET_KEY.toByteArray(UTF_8), signatureAlgorithm)
        val mac = Mac.getInstance(signatureAlgorithm)
        mac.init(secretKeySpec)
        return encodeBase64URL(mac.doFinal(data.toByteArray(UTF_8)))
    }

    /**
     * Encodes a byte array into a base64 URL-encoded string without padding.
     *
     * @param input Byte array to encode.
     * @return Base64 URL-encoded string without padding.
     */
    private fun encodeBase64URL(input: ByteArray): String =
        Base64.getUrlEncoder().withoutPadding().encodeToString(input)

    /**
     * Serializes a Map into JSON with Jackson, so numbers stay numbers,
     * nested maps stay objects, and special characters are escaped.
     *
     * @param data Map of key-value pairs to be serialized.
     * @return JSON-formatted string representing the serialized data.
     */
    private fun serializeToJson(data: Map<String, Any>): String = mapper.writeValueAsString(data)

    /**
     * Checks if the secret key is empty and throws an exception if it is.
     *
     * @throws InvalidJWTSecretKey if the secret key is empty.
     */
    private fun checkForSecretKey() {
        if (SECRET_KEY.isEmpty()) throw InvalidJWTSecretKey()
    }

    // Current time as a NumericDate (whole seconds since the Unix epoch)
    private fun nowSeconds(): Long = System.currentTimeMillis() / 1000
}
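The two enums in JwtHelper exist because the JWT header names algorithms with JWA identifiers (HS256, HS384, HS512), while javax.crypto.Mac expects JCA names (HmacSHA256, HmacSHA384, HmacSHA512). The short sketch below reproduces that mapping outside the class and measures the signature each algorithm produces: 32, 48, and 64 bytes, matching the output sizes of SHA-256, SHA-384, and SHA-512. The map jwaToJca and the function macLengthBytes are illustrative names only.

```kotlin
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec

// JWA "alg" header values mapped to the JCA names that javax.crypto.Mac understands
val jwaToJca = mapOf(
    "HS256" to "HmacSHA256",
    "HS384" to "HmacSHA384",
    "HS512" to "HmacSHA512"
)

/** Computes an HMAC with the given JWA algorithm and returns the tag length in bytes. */
fun macLengthBytes(jwaAlg: String): Int {
    val jcaName = requireNotNull(jwaToJca[jwaAlg]) { "Unsupported alg: $jwaAlg" }
    val mac = Mac.getInstance(jcaName)
    // Any non-empty key works here; only the output size is being measured
    mac.init(SecretKeySpec("demo-key".toByteArray(), jcaName))
    return mac.doFinal("header.payload".toByteArray()).size
}
```

An unknown value such as "none" is rejected rather than silently mapped, which is the same effect the exhaustive when expression in init has.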

Initializing the JWT Helper:

Before using the JWT utility, initialize it by calling the init method with a secret key and the desired JWT algorithm. In a real application, load the secret from configuration (for example, an environment variable) instead of hard-coding it.

JwtHelper.init("your_secret_key", JwtHelper.JwtAlgorithm.ALGORITHM_HS512)
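RFC 7518 (section 3.2) requires an HMAC key at least as large as the hash output, so HS512 calls for a secret of 64 bytes or more; a placeholder like "your_secret_key" is far too short. One way to produce a suitable secret is java.security.SecureRandom. The sketch below returns the random bytes Base64-encoded so the value can live in an environment variable or application property; generateHmacSecret is a hypothetical helper name. Since init uses the UTF-8 bytes of whatever string it receives, passing the encoded string directly still carries the full 512 bits of randomness.

```kotlin
import java.security.SecureRandom
import java.util.Base64

/**
 * Generates a random HMAC secret of [byteLength] bytes and returns it Base64-encoded.
 * 64 bytes matches the 512-bit output of HS512.
 */
fun generateHmacSecret(byteLength: Int = 64): String {
    require(byteLength > 0) { "byteLength must be positive" }
    val bytes = ByteArray(byteLength)
    // SecureRandom is a cryptographically strong source, unlike kotlin.random.Random
    SecureRandom().nextBytes(bytes)
    return Base64.getEncoder().encodeToString(bytes)
}
```

For example, generate the value once, store it outside source control, and pass it to JwtHelper.init at application startup.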

Generating Payloads:

The generatePayload method creates a JWT payload with customizable parameters: subject, issuer, audience, issued-at time, expiration time, and additional data (stored under an "additional" claim).

val payload = JwtHelper.generatePayload("user123", issuer = "your_issuer", audience = "your_audience")
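RFC 7519 defines iat and exp as NumericDate values: seconds, not milliseconds, since the Unix epoch. Other JWT libraries read them that way, so a millisecond timestamp would be interpreted as a date tens of thousands of years in the future. A minimal sketch of building the registered claims with java.time, assuming the article's 365-day default lifetime; standardClaims is a hypothetical helper, not part of JwtHelper.

```kotlin
import java.time.Duration
import java.time.Instant

/**
 * Builds the registered claims used by the article's payload.
 * iat and exp are NumericDate values: whole seconds since 1970-01-01T00:00:00Z.
 */
fun standardClaims(
    subject: String,
    issuer: String,
    audience: String,
    now: Instant = Instant.now(),
    lifetime: Duration = Duration.ofDays(365)
): Map<String, Any> {
    val iat = now.epochSecond
    return mapOf(
        "iss" to issuer,
        "aud" to audience,
        "sub" to subject,
        "iat" to iat,
        // 365 days = 31,536,000 seconds
        "exp" to iat + lifetime.seconds
    )
}
```

Passing now explicitly, as the tests of such a helper would, keeps the output deterministic.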

Generating JWT Tokens:

To generate a JWT token, pass the payload to the generateJwtToken method.

val jwtToken = JwtHelper.generateJwtToken(payload)
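Under the hood, a compact JWT is produced by computing an HMAC over base64url(header) + "." + base64url(payload) and appending the Base64URL-encoded MAC as the third segment. JWTs omit Base64 "=" padding, which Base64.getUrlEncoder().withoutPadding() handles directly. The self-contained sketch below shows only that signing step with the JDK; it is hard-coded to HS256, assumes the header and payload are already serialized JSON strings, and signHs256 is an illustrative name.

```kotlin
import java.nio.charset.StandardCharsets.UTF_8
import java.util.Base64
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec

// JWT segments use the URL-safe Base64 alphabet without '=' padding
val base64Url: Base64.Encoder = Base64.getUrlEncoder().withoutPadding()

/** Creates a compact HS256 token from pre-serialized header and payload JSON. */
fun signHs256(headerJson: String, payloadJson: String, secret: String): String {
    // The signing input is the first two encoded segments joined by a dot
    val signingInput = base64Url.encodeToString(headerJson.toByteArray(UTF_8)) + "." +
        base64Url.encodeToString(payloadJson.toByteArray(UTF_8))
    val mac = Mac.getInstance("HmacSHA256")
    mac.init(SecretKeySpec(secret.toByteArray(UTF_8), "HmacSHA256"))
    val signature = base64Url.encodeToString(mac.doFinal(signingInput.toByteArray(UTF_8)))
    return "$signingInput.$signature"
}
```

Because HMAC is deterministic, signing the same header, payload, and key twice yields the same token, which is what makes recomputing the signature a valid verification strategy.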

Verifying JWT Tokens:

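The verifyJwtToken method checks that a token's signature matches the configured secret and algorithm. It does not look at the exp claim, so pair it with isTokenExpired when authenticating a request.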

val isTokenValid = JwtHelper.verifyJwtToken(jwtToken)
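Verification recomputes the MAC over the first two segments and compares it with the third. The sketch below assumes HS256 and uses MessageDigest.isEqual for a constant-time comparison, so response timing does not reveal how many leading bytes matched. It never reads the alg value from the token header and always uses the server-chosen algorithm, which rules out "alg": "none" downgrades. Like verifyJwtToken, it checks only the signature, not exp; verifyHs256 is an illustrative name.

```kotlin
import java.nio.charset.StandardCharsets.UTF_8
import java.security.MessageDigest
import java.util.Base64
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec

/** Returns true only if [token] carries a valid HS256 signature for [secret]. */
fun verifyHs256(token: String, secret: String): Boolean {
    val parts = token.split(".")
    if (parts.size != 3) return false
    val mac = Mac.getInstance("HmacSHA256")
    mac.init(SecretKeySpec(secret.toByteArray(UTF_8), "HmacSHA256"))
    val expected = mac.doFinal("${parts[0]}.${parts[1]}".toByteArray(UTF_8))
    // Malformed Base64URL in the signature means "invalid", not a crash
    val presented = runCatching { Base64.getUrlDecoder().decode(parts[2]) }.getOrNull() ?: return false
    // Constant-time comparison of the raw MAC bytes
    return MessageDigest.isEqual(expected, presented)
}
```

Changing even one character of the payload segment, for example swapping "sub":"user" for "sub":"admin", makes the recomputed MAC differ and the check fail.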

Refreshing JWT Tokens:

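To extend a token's lifetime, use the refreshJwtToken method. It issues a new token with updated iat and exp claims, so only pass it tokens that have already passed verification.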

val refreshedToken = JwtHelper.refreshJwtToken(jwtToken)
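Because refreshing re-signs the claims it is given, it should only ever see tokens that already passed signature verification, and you may also want to limit how stale a token can be before a fresh login is required. A sketch of such a policy on the decoded claims follows; the grace period and the name refreshClaimsOrNull are assumptions of this example, not features of JwtHelper, and all times are epoch seconds.

```kotlin
/**
 * Refresh policy sketch (hypothetical): returns updated claims, or null when the token
 * has no usable exp or expired more than [graceSeconds] ago.
 */
fun refreshClaimsOrNull(
    claims: Map<String, Any>,
    nowSeconds: Long,
    lifetimeSeconds: Long,
    graceSeconds: Long = 0
): Map<String, Any>? {
    // Jackson may hand back exp as Int or Long, so read it as a Number
    val exp = (claims["exp"] as? Number)?.toLong() ?: return null
    // Too old to refresh: force the client to authenticate again
    if (nowSeconds > exp + graceSeconds) return null
    // Keep sub, aud, and custom claims; replace only the time-based ones
    return claims + mapOf("iat" to nowSeconds, "exp" to nowSeconds + lifetimeSeconds)
}
```

The returned map would then be signed again, just as refreshJwtToken does with its updated payload.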

Checking Token Expiration:

The isTokenExpired method checks if a token has expired.

val isExpired = JwtHelper.isTokenExpired(jwtToken)
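Two details matter when reading exp back out of a decoded payload. When Jackson reads JSON into a plain Map, it produces an Int for numbers that fit and a Long otherwise, so a cast like as? Long silently fails for epoch-second values (which currently fit in an Int); reading the claim as Number avoids that. A small leeway also tolerates clock skew between servers. The sketch below shows both; isExpired and the 60-second default leeway are assumptions for illustration.

```kotlin
/**
 * Returns true if the exp claim (epoch seconds) is missing, not a number,
 * or reached once [leewaySeconds] of clock skew is allowed for.
 */
fun isExpired(payload: Map<String, Any?>, nowSeconds: Long, leewaySeconds: Long = 60): Boolean {
    // Number covers Int, Long, and Double, whichever the JSON parser produced
    val exp = (payload["exp"] as? Number)?.toLong() ?: return true
    // RFC 7519: reject on or after exp; the leeway shifts that boundary slightly later
    return nowSeconds >= exp + leewaySeconds
}
```

Treating a missing or malformed exp as expired is the safer default: a token that cannot prove its lifetime is rejected.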

Unit Tests:

Now, let's do some tests.

// JUnit 5 (Jupiter), as shipped with spring-boot-starter-test
import org.junit.jupiter.api.Test

/**
 * Test class for JwtHelper functionality.
 */
class JwtHelperTests {

    /**
     * Test method covering various functionalities of the JwtHelper class.
     */
    @Test
    fun test() {
        // Initialize JwtHelper with a secret key
        JwtHelper.init("abc")

        // Generate a JWT payload with subject and additional data
        val payload = JwtHelper.generatePayload(
            subject = "rommansabbir",
            additionalData = mapOf("userId" to "123", "role" to "admin")
        )
        println("Payload: $payload")

        // Generate a JWT token based on the generated payload
        val token = JwtHelper.generateJwtToken(payload)
        println("Generated Token: $token")
        assert(token.isNotEmpty())

        // Verify the validity of the generated token
        val isValid = JwtHelper.verifyJwtToken(token)
        assert(isValid)
        if (isValid) {
            println("Token is valid")
        } else {
            println("Token is invalid")
        }

        // Refresh the generated token
        val newToken = JwtHelper.refreshJwtToken(token)
        println("Refreshed Token: $newToken")
        assert(newToken.isNotEmpty())

        // Extract and assert the payload from the original token
        val extractedPayload = JwtHelper.extractPayload(token)
        assert(extractedPayload?.isNotEmpty() == true)
    }
}

The JwtHelperTests class contains a test method named test that covers various functionalities of the JwtHelper class. Let's go through each section of the test:

  1. Initialization and Payload Generation:

    • The init method is called to initialize the JwtHelper with a secret key ("abc").

    • The generatePayload method is then used to create a JWT payload with the subject "rommansabbir" and additional data containing "userId" and "role" information.

JwtHelper.init("abc")
val payload = JwtHelper.generatePayload(
    subject = "rommansabbir",
    additionalData = mapOf("userId" to "123", "role" to "admin")
)
println("Payload: $payload")

  2. Token Generation:

    • The generateJwtToken method is used to generate a JWT token based on the generated payload.

    • The generated token is printed to the console and asserted to be non-empty.

val token = JwtHelper.generateJwtToken(payload)
println("Generated Token: $token")
assert(token.isNotEmpty())

  3. Token Verification:

    • The verifyJwtToken method is called to verify the validity of the generated token.

    • The verification result is asserted to be true and printed to the console.

val isValid = JwtHelper.verifyJwtToken(token)
assert(isValid)
if (isValid) {
    println("Token is valid")
} else {
    println("Token is invalid")
}

  4. Token Refresh:

    • The refreshJwtToken method is used to refresh the generated token.

    • The refreshed token is printed to the console and asserted to be non-empty.

val newToken = JwtHelper.refreshJwtToken(token)
println("Refreshed Token: $newToken")
assert(newToken.isNotEmpty())

  5. Payload Extraction:

    • The extractPayload method is called to extract the payload from the original token.

    • The extracted payload is then asserted to ensure it is not null and not empty.

val extractedPayload = JwtHelper.extractPayload(token)
assert(extractedPayload?.isNotEmpty() == true)

This test covers initialization, payload generation, token generation, token verification, token refresh, and payload extraction; isTokenExpired is not exercised here. The println statements make intermediate values visible, while the assertions verify the behavior. Note that Kotlin's assert only runs when JVM assertions are enabled (-ea), which Gradle and Maven Surefire turn on for tests by default.

Conclusion:

In this article, we have explored the implementation of a custom JWT token utility in a Spring Boot application using Kotlin. The JwtHelper object provides methods for initializing the helper, generating payloads, creating JWT tokens, verifying tokens, refreshing tokens, and checking token expiration. This utility can be a useful addition to your Spring Boot projects for issuing and checking signed tokens exchanged between components.


That's it for today. Happy Coding...

ย