|
| 1 | +/** |
| 2 | + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 3 | + * SPDX-License-Identifier: Apache-2.0. |
| 4 | + */ |
| 5 | + |
| 6 | +package customkeyopspubsub; |
| 7 | + |
| 8 | +import software.amazon.awssdk.crt.CRT; |
| 9 | +import software.amazon.awssdk.crt.CrtResource; |
| 10 | +import software.amazon.awssdk.crt.CrtRuntimeException; |
| 11 | +import software.amazon.awssdk.crt.io.*; |
| 12 | +import software.amazon.awssdk.crt.mqtt.*; |
| 13 | +import software.amazon.awssdk.iot.AwsIotMqttConnectionBuilder; |
| 14 | + |
| 15 | +import software.amazon.awssdk.crt.Log; |
| 16 | +import software.amazon.awssdk.crt.Log.LogLevel; |
| 17 | + |
| 18 | +import java.io.BufferedReader; |
| 19 | +import java.io.ByteArrayOutputStream; |
| 20 | +import java.io.FileReader; |
| 21 | +import java.nio.charset.StandardCharsets; |
| 22 | +import java.security.KeyFactory; |
| 23 | +import java.security.PrivateKey; |
| 24 | +import java.security.Signature; |
| 25 | +import java.security.interfaces.RSAPrivateKey; |
| 26 | +import java.security.spec.PKCS8EncodedKeySpec; |
| 27 | +import java.util.Base64; |
| 28 | +import java.util.UUID; |
| 29 | +import java.util.concurrent.CompletableFuture; |
| 30 | +import java.util.concurrent.CountDownLatch; |
| 31 | +import java.util.concurrent.ExecutionException; |
| 32 | + |
| 33 | +import utils.commandlineutils.CommandLineUtils; |
| 34 | + |
| 35 | +public class CustomKeyOpsPubSub { |
| 36 | + |
| 37 | + // When run normally, we want to exit nicely even if something goes wrong |
| 38 | + // When run from CI, we want to let an exception escape which in turn causes the |
| 39 | + // exec:java task to return a non-zero exit code |
| 40 | + static String ciPropValue = System.getProperty("aws.crt.ci"); |
| 41 | + static boolean isCI = ciPropValue != null && Boolean.valueOf(ciPropValue); |
| 42 | + |
| 43 | + static CommandLineUtils cmdUtils; |
| 44 | + |
| 45 | + static String topic = "test/topic"; |
| 46 | + static String message = "Hello World!"; |
| 47 | + static int messagesToPublish = 10; |
| 48 | + static String certPath; |
| 49 | + static String keyPath; |
| 50 | + |
| 51 | + /* |
| 52 | + * When called during a CI run, throw an exception that will escape and fail the exec:java task |
| 53 | + * When called otherwise, print what went wrong (if anything) and just continue (return from main) |
| 54 | + */ |
| 55 | + static void onApplicationFailure(Throwable cause) { |
| 56 | + if (isCI) { |
| 57 | + throw new RuntimeException("CustomKeyOpsPubSub execution failure", cause); |
| 58 | + } else if (cause != null) { |
| 59 | + System.out.println("Exception encountered: " + cause.toString()); |
| 60 | + } |
| 61 | + } |
| 62 | + |
| 63 | + static class MyKeyOperationHandler implements TlsKeyOperationHandler { |
| 64 | + RSAPrivateKey key; |
| 65 | + |
| 66 | + MyKeyOperationHandler(String keyPath) { |
| 67 | + key = loadPrivateKey(keyPath); |
| 68 | + } |
| 69 | + |
| 70 | + public void performOperation(TlsKeyOperation operation) { |
| 71 | + try { |
| 72 | + System.out.println("MyKeyOperationHandler.performOperation" + operation.getType().name()); |
| 73 | + |
| 74 | + if (operation.getType() != TlsKeyOperation.Type.SIGN) { |
| 75 | + throw new RuntimeException("Simple sample only handles SIGN operations"); |
| 76 | + } |
| 77 | + |
| 78 | + if (operation.getSignatureAlgorithm() != TlsSignatureAlgorithm.RSA) { |
| 79 | + throw new RuntimeException("Simple sample only handles RSA keys"); |
| 80 | + } |
| 81 | + |
| 82 | + if (operation.getDigestAlgorithm() != TlsHashAlgorithm.SHA256) { |
| 83 | + throw new RuntimeException("Simple sample only handles SHA256 digests"); |
| 84 | + } |
| 85 | + |
| 86 | + // A SIGN operation's inputData is the 32bytes of the SHA-256 digest. |
| 87 | + // Before doing the RSA signature, we need to construct a PKCS1 v1.5 DigestInfo. |
| 88 | + // See https://datatracker.ietf.org/doc/html/rfc3447#section-9.2 |
| 89 | + byte[] digest = operation.getInput(); |
| 90 | + |
| 91 | + // These are the appropriate bytes for the SHA-256 AlgorithmIdentifier: |
| 92 | + // https://tools.ietf.org/html/rfc3447#page-43 |
| 93 | + byte[] sha256DigestAlgorithm = { 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, (byte)0x86, 0x48, 0x01, |
| 94 | + 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20 }; |
| 95 | + |
| 96 | + ByteArrayOutputStream digestInfoStream = new ByteArrayOutputStream(); |
| 97 | + digestInfoStream.write(sha256DigestAlgorithm); |
| 98 | + digestInfoStream.write(digest); |
| 99 | + byte[] digestInfo = digestInfoStream.toByteArray(); |
| 100 | + |
| 101 | + // Sign the DigestInfo |
| 102 | + Signature rsaSign = Signature.getInstance("NONEwithRSA"); |
| 103 | + rsaSign.initSign(key); |
| 104 | + rsaSign.update(digestInfo); |
| 105 | + byte[] signatureBytes = rsaSign.sign(); |
| 106 | + |
| 107 | + operation.complete(signatureBytes); |
| 108 | + |
| 109 | + } catch (Exception ex) { |
| 110 | + System.out.println("Error during key operation:" + ex); |
| 111 | + operation.completeExceptionally(ex); |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + RSAPrivateKey loadPrivateKey(String filepath) { |
| 116 | + /* Adapted from: https://stackoverflow.com/a/27621696 |
| 117 | + * You probably need to convert your private key file from PKCS#1 |
| 118 | + * to PKCS#8 to get it working with this sample: |
| 119 | + * |
| 120 | + * $ openssl pkcs8 -topk8 -in my-private.pem.key -out my-private-pk8.pem.key -nocrypt |
| 121 | + * |
| 122 | + * IoT Core vends keys as PKCS#1 by default, |
| 123 | + * but Java only seems to have this PKCS8EncodedKeySpec class */ |
| 124 | + try { |
| 125 | + /* Read the BASE64-encoded contents of the private key file */ |
| 126 | + StringBuilder pemBase64 = new StringBuilder(); |
| 127 | + try (BufferedReader reader = new BufferedReader(new FileReader(filepath))) { |
| 128 | + String line; |
| 129 | + while ((line = reader.readLine()) != null) { |
| 130 | + // Strip off PEM header and footer |
| 131 | + if (line.startsWith("---")) { |
| 132 | + if (line.contains("RSA")) { |
| 133 | + throw new RuntimeException("private key must be converted from PKCS#1 to PKCS#8"); |
| 134 | + } |
| 135 | + continue; |
| 136 | + } |
| 137 | + pemBase64.append(line); |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + String pemBase64String = pemBase64.toString(); |
| 142 | + byte[] der = Base64.getDecoder().decode(pemBase64String); |
| 143 | + |
| 144 | + /* Create PrivateKey instance */ |
| 145 | + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(der); |
| 146 | + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); |
| 147 | + PrivateKey privateKey = keyFactory.generatePrivate(keySpec); |
| 148 | + return (RSAPrivateKey)privateKey; |
| 149 | + |
| 150 | + } catch (Exception ex) { |
| 151 | + throw new RuntimeException(ex); |
| 152 | + } |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + public static void main(String[] args) { |
| 157 | + |
| 158 | + cmdUtils = new CommandLineUtils(); |
| 159 | + cmdUtils.registerProgramName("CustomKeyOpsPubSub"); |
| 160 | + cmdUtils.addCommonMQTTCommands(); |
| 161 | + cmdUtils.addCommonTopicMessageCommands(); |
| 162 | + cmdUtils.registerCommand("key", "<path>", "Path to your PKCS#8 key in PEM format."); |
| 163 | + cmdUtils.registerCommand("cert", "<path>", "Path to your client certificate in PEM format."); |
| 164 | + cmdUtils.registerCommand("client_id", "<int>", "Client id to use (optional, default='test-*')."); |
| 165 | + cmdUtils.registerCommand("port", "<int>", "Port to connect to on the endpoint (optional, default='8883')."); |
| 166 | + cmdUtils.registerCommand("count", "<int>", "Number of messages to publish (optional, default='10')."); |
| 167 | + cmdUtils.sendArguments(args); |
| 168 | + |
| 169 | + keyPath = cmdUtils.getCommandRequired("key", ""); |
| 170 | + certPath = cmdUtils.getCommandRequired("cert", ""); |
| 171 | + |
| 172 | + topic = cmdUtils.getCommandOrDefault("topic", topic); |
| 173 | + message = cmdUtils.getCommandOrDefault("message", message); |
| 174 | + messagesToPublish = Integer.parseInt(cmdUtils.getCommandOrDefault("count", String.valueOf(messagesToPublish))); |
| 175 | + |
| 176 | + MqttClientConnectionEvents callbacks = new MqttClientConnectionEvents() { |
| 177 | + @Override |
| 178 | + public void onConnectionInterrupted(int errorCode) { |
| 179 | + if (errorCode != 0) { |
| 180 | + System.out.println("Connection interrupted: " + errorCode + ": " + CRT.awsErrorString(errorCode)); |
| 181 | + } |
| 182 | + } |
| 183 | + |
| 184 | + @Override |
| 185 | + public void onConnectionResumed(boolean sessionPresent) { |
| 186 | + System.out.println("Connection resumed: " + (sessionPresent ? "existing session" : "clean session")); |
| 187 | + } |
| 188 | + }; |
| 189 | + |
| 190 | + MyKeyOperationHandler myKeyOperationHandler = new MyKeyOperationHandler(keyPath); |
| 191 | + TlsContextCustomKeyOperationOptions keyOperationOptions = new TlsContextCustomKeyOperationOptions(myKeyOperationHandler) |
| 192 | + .withCertificateFilePath(certPath); |
| 193 | + |
| 194 | + try { |
| 195 | + MqttClientConnection connection = cmdUtils.buildCustomKeyOperationConnection(callbacks, keyOperationOptions); |
| 196 | + if (connection == null) |
| 197 | + { |
| 198 | + onApplicationFailure(new RuntimeException("MQTT connection creation failed!")); |
| 199 | + } |
| 200 | + |
| 201 | + CompletableFuture<Boolean> connected = connection.connect(); |
| 202 | + try { |
| 203 | + boolean sessionPresent = connected.get(); |
| 204 | + System.out.println("Connected to " + (!sessionPresent ? "new" : "existing") + " session!"); |
| 205 | + } catch (Exception ex) { |
| 206 | + throw new RuntimeException("Exception occurred during connect", ex); |
| 207 | + } |
| 208 | + |
| 209 | + CountDownLatch countDownLatch = new CountDownLatch(messagesToPublish); |
| 210 | + |
| 211 | + CompletableFuture<Integer> subscribed = connection.subscribe(topic, QualityOfService.AT_LEAST_ONCE, (message) -> { |
| 212 | + String payload = new String(message.getPayload(), StandardCharsets.UTF_8); |
| 213 | + System.out.println("MESSAGE: " + payload); |
| 214 | + countDownLatch.countDown(); |
| 215 | + }); |
| 216 | + |
| 217 | + subscribed.get(); |
| 218 | + |
| 219 | + int count = 0; |
| 220 | + while (count++ < messagesToPublish) { |
| 221 | + CompletableFuture<Integer> published = connection.publish(new MqttMessage(topic, message.getBytes(), QualityOfService.AT_LEAST_ONCE, false)); |
| 222 | + published.get(); |
| 223 | + Thread.sleep(1000); |
| 224 | + } |
| 225 | + |
| 226 | + countDownLatch.await(); |
| 227 | + |
| 228 | + CompletableFuture<Void> disconnected = connection.disconnect(); |
| 229 | + disconnected.get(); |
| 230 | + |
| 231 | + connection.close(); |
| 232 | + |
| 233 | + } catch (CrtRuntimeException | InterruptedException | ExecutionException ex) { |
| 234 | + onApplicationFailure(ex); |
| 235 | + } |
| 236 | + |
| 237 | + CrtResource.waitForNoResources(); |
| 238 | + System.out.println("Complete!"); |
| 239 | + } |
| 240 | +} |
0 commit comments