1use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37 primitives::Blob,
38 types::{MessageType, SigningAlgorithmSpec},
39 Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::{collections::HashMap, sync::Arc};
44use tokio::sync::RwLock;
45
46use crate::{
47 models::{Address, AwsKmsSignerConfig},
48 services::{
49 client_cache::AsyncClientCache, signer::evm::utils::recover_evm_signature_from_der,
50 },
51 utils::{
52 self, aws_error::DisplayErrorContext, classify_sdk_error, derive_ethereum_address_from_der,
53 derive_solana_address_from_der, derive_stellar_address_from_der,
54 },
55};
56use tracing::{debug, warn};
57
58#[cfg(test)]
59use mockall::{automock, mock};
60
61#[derive(Clone, Debug, thiserror::Error, Serialize)]
62pub enum AwsKmsError {
63 #[error("AWS KMS response parse error: {0}")]
64 ParseError(String),
65 #[error("AWS KMS config error: {0}")]
66 ConfigError(String),
67 #[error("AWS KMS get error: {0}")]
68 GetError(String),
69 #[error("AWS KMS signing error: {0}")]
70 SignError(String),
71 #[error("AWS KMS public key error: {0}")]
72 RecoveryError(#[from] utils::Secp256k1Error),
73 #[error("AWS KMS conversion error: {0}")]
74 ConvertError(String),
75 #[error("AWS KMS Other error: {0}")]
76 Other(String),
77}
78
79pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
80
81#[async_trait]
82#[cfg_attr(test, automock)]
83pub trait AwsKmsEvmService: Send + Sync {
84 async fn get_evm_address(&self) -> AwsKmsResult<Address>;
86 async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;
96
97 async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>>;
107}
108
109#[async_trait]
110#[cfg_attr(test, automock)]
111pub trait AwsKmsK256: Send + Sync {
112 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
114 async fn sign_digest<'a, 'b>(
116 &'a self,
117 key_id: &'b str,
118 digest: [u8; 32],
119 ) -> AwsKmsResult<Vec<u8>>;
120}
121
122#[async_trait]
125#[cfg_attr(test, automock)]
126pub trait AwsKmsEd25519: Send + Sync {
127 async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
129 async fn sign_ed25519<'a, 'b>(
132 &'a self,
133 key_id: &'b str,
134 message: &'b [u8],
135 ) -> AwsKmsResult<Vec<u8>>;
136}
137
138#[async_trait]
140#[cfg_attr(test, automock)]
141pub trait AwsKmsSolanaService: Send + Sync {
142 async fn get_solana_address(&self) -> AwsKmsResult<Address>;
144 async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
146}
147
148#[async_trait]
150#[cfg_attr(test, automock)]
151pub trait AwsKmsStellarService: Send + Sync {
152 async fn get_stellar_address(&self) -> AwsKmsResult<Address>;
154 async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
156}
157
// Test-only mock (`MockAwsKmsClient`) that stands in for `AwsKmsClient`.
// It implements `Clone` plus both low-level KMS traits so it satisfies the
// bounds on `AwsKmsService<T>`.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }

    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    #[async_trait]
    impl AwsKmsEd25519 for AwsKmsClient {
        async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_ed25519<'a, 'b>(
            &'a self,
            key_id: &'b str,
            message: &'b [u8],
        ) -> AwsKmsResult<Vec<u8>>;
    }

    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}
185
186static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
188 Lazy::new(|| RwLock::new(HashMap::new()));
189
190static KMS_ED25519_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
192 Lazy::new(|| RwLock::new(HashMap::new()));
193
/// Lookup key for the shared KMS client cache: one SDK client per region.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct AwsKmsClientKey {
    /// Resolved AWS region name the client is bound to.
    region: String,
}
198
199static KMS_CLIENT_CACHE: Lazy<AsyncClientCache<AwsKmsClientKey, Client>> =
200 Lazy::new(AsyncClientCache::new);
201
202async fn get_or_create_kms_client(config: &AwsKmsSignerConfig) -> AwsKmsResult<Arc<Client>> {
205 let resolved_region = resolve_aws_region(config).await?;
206 let key = AwsKmsClientKey {
207 region: resolved_region.clone(),
208 };
209
210 KMS_CLIENT_CACHE
211 .get_or_try_init(key, || async {
212 debug!(
213 region = %resolved_region,
214 "Creating new AWS KMS client"
215 );
216 let auth_config = aws_config::defaults(BehaviorVersion::latest())
217 .region(Region::new(resolved_region))
218 .load()
219 .await;
220
221 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Client::new(&auth_config)))
224 .map_err(|panic| {
225 let msg = panic
226 .downcast_ref::<String>()
227 .map(|s| s.as_str())
228 .or_else(|| panic.downcast_ref::<&str>().copied())
229 .unwrap_or("unknown panic");
230 AwsKmsError::ConfigError(format!(
231 "Failed to initialize AWS KMS client (check TLS root certificates): {msg}"
232 ))
233 })
234 })
235 .await
236}
237
238async fn resolve_aws_region(config: &AwsKmsSignerConfig) -> AwsKmsResult<String> {
240 if let Some(region) = &config.region {
241 return Ok(region.clone());
242 }
243
244 let provider = RegionProviderChain::default_provider();
245 provider
246 .region()
247 .await
248 .map(|r| r.to_string())
249 .ok_or_else(|| {
250 AwsKmsError::ConfigError(
251 "AWS region not specified and could not be resolved from environment".to_string(),
252 )
253 })
254}
255
256#[derive(Debug, Clone)]
257pub struct AwsKmsClient {
258 inner: Arc<Client>,
259}
260
261#[async_trait]
262impl AwsKmsK256 for AwsKmsClient {
263 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
264 let cached = {
266 let cache_read = KMS_DER_PK_CACHE.read().await;
267 cache_read.get(key_id).cloned()
268 };
269 if let Some(cached) = cached {
270 return Ok(cached);
271 }
272
273 let get_output = self
275 .inner
276 .get_public_key()
277 .key_id(key_id)
278 .send()
279 .await
280 .map_err(|e| {
281 warn!(
282 error.kind = classify_sdk_error(&e),
283 error.detail = %DisplayErrorContext(&e),
284 kms_key_id = %key_id,
285 operation = "get_public_key_secp256k1",
286 "AWS KMS get_public_key failed"
287 );
288 AwsKmsError::GetError(format!(
289 "Failed to get secp256k1 public key for key '{key_id}': {}",
290 classify_sdk_error(&e)
291 ))
292 })?;
293
294 let der_pk_blob = get_output
295 .public_key
296 .ok_or(AwsKmsError::GetError(
297 "No public key blob found".to_string(),
298 ))?
299 .into_inner();
300
301 let mut cache_write = KMS_DER_PK_CACHE.write().await;
302 cache_write.insert(key_id.to_string(), der_pk_blob.clone());
303
304 Ok(der_pk_blob)
305 }
306
307 async fn sign_digest<'a, 'b>(
308 &'a self,
309 key_id: &'b str,
310 digest: [u8; 32],
311 ) -> AwsKmsResult<Vec<u8>> {
312 let sign_result = self
314 .inner
315 .sign()
316 .key_id(key_id)
317 .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
318 .message_type(MessageType::Digest)
319 .message(Blob::new(digest))
320 .send()
321 .await;
322
323 let der_signature = sign_result
325 .map_err(|e| {
326 warn!(
327 error.kind = classify_sdk_error(&e),
328 error.detail = %DisplayErrorContext(&e),
329 kms_key_id = %key_id,
330 operation = "sign_digest_secp256k1",
331 "AWS KMS sign failed"
332 );
333 AwsKmsError::SignError(format!(
334 "Failed to sign secp256k1 digest for key '{key_id}': {}",
335 classify_sdk_error(&e)
336 ))
337 })?
338 .signature
339 .ok_or(AwsKmsError::SignError(
340 "Signature not found in response".to_string(),
341 ))?
342 .into_inner();
343
344 Ok(der_signature)
345 }
346}
347
348#[async_trait]
349impl AwsKmsEd25519 for AwsKmsClient {
350 async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
351 let cached = {
353 let cache_read = KMS_ED25519_PK_CACHE.read().await;
354 cache_read.get(key_id).cloned()
355 };
356 if let Some(cached) = cached {
357 return Ok(cached);
358 }
359
360 let get_output = self
362 .inner
363 .get_public_key()
364 .key_id(key_id)
365 .send()
366 .await
367 .map_err(|e| {
368 warn!(
369 error.kind = classify_sdk_error(&e),
370 error.detail = %DisplayErrorContext(&e),
371 kms_key_id = %key_id,
372 operation = "get_public_key_ed25519",
373 "AWS KMS get_public_key failed"
374 );
375 AwsKmsError::GetError(format!(
376 "Failed to get Ed25519 public key for key '{key_id}': {}",
377 classify_sdk_error(&e)
378 ))
379 })?;
380
381 let der_pk_blob = get_output
382 .public_key
383 .ok_or(AwsKmsError::GetError(
384 "No public key blob found".to_string(),
385 ))?
386 .into_inner();
387
388 let mut cache_write = KMS_ED25519_PK_CACHE.write().await;
389 cache_write.insert(key_id.to_string(), der_pk_blob.clone());
390
391 Ok(der_pk_blob)
392 }
393
394 async fn sign_ed25519<'a, 'b>(
395 &'a self,
396 key_id: &'b str,
397 message: &'b [u8],
398 ) -> AwsKmsResult<Vec<u8>> {
399 debug!("Signing Ed25519 message with AWS KMS, key_id: {}", key_id);
400
401 let sign_result = self
404 .inner
405 .sign()
406 .key_id(key_id)
407 .signing_algorithm(SigningAlgorithmSpec::Ed25519Sha512)
408 .message_type(MessageType::Raw)
409 .message(Blob::new(message))
410 .send()
411 .await;
412
413 let signature = sign_result
415 .map_err(|e| {
416 warn!(
417 error.kind = classify_sdk_error(&e),
418 error.detail = %DisplayErrorContext(&e),
419 kms_key_id = %key_id,
420 operation = "sign_ed25519",
421 "AWS KMS sign failed"
422 );
423 AwsKmsError::SignError(format!(
424 "Failed to sign Ed25519 message for key '{key_id}': {}",
425 classify_sdk_error(&e)
426 ))
427 })?
428 .signature
429 .ok_or(AwsKmsError::SignError(
430 "Signature not found in response".to_string(),
431 ))?
432 .into_inner();
433
434 if signature.len() != 64 {
436 return Err(AwsKmsError::SignError(format!(
437 "Invalid Ed25519 signature length: expected 64 bytes, got {}",
438 signature.len()
439 )));
440 }
441
442 Ok(signature)
443 }
444}
445
446#[derive(Debug, Clone)]
447pub struct AwsKmsService<T: AwsKmsK256 + AwsKmsEd25519 + Clone = AwsKmsClient> {
448 pub kms_key_id: String,
449 client: T,
450}
451
452impl AwsKmsService<AwsKmsClient> {
453 pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
454 let shared_client = get_or_create_kms_client(&config).await?;
455
456 Ok(Self {
457 kms_key_id: config.key_id,
458 client: AwsKmsClient {
459 inner: shared_client,
460 },
461 })
462 }
463}
464
#[cfg(test)]
impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
    /// Test-only constructor that injects an arbitrary backend `client`.
    ///
    /// Only `config.key_id` is used; no region resolution or client caching
    /// takes place.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}
474
475impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
476 async fn sign_and_recover_evm(
485 &self,
486 digest: [u8; 32],
487 original_bytes: &[u8],
488 use_prehash_recovery: bool,
489 ) -> AwsKmsResult<Vec<u8>> {
490 let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
492
493 let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
495
496 recover_evm_signature_from_der(
498 &der_signature,
499 &der_pk,
500 digest,
501 original_bytes,
502 use_prehash_recovery,
503 )
504 .map_err(|e| AwsKmsError::ParseError(e.to_string()))
505 }
506
507 pub async fn sign_payload_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
517 let digest = keccak256(bytes).0;
518 self.sign_and_recover_evm(digest, bytes, false).await
519 }
520
521 pub async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
531 self.sign_and_recover_evm(*hash, hash, true).await
532 }
533}
534
535#[async_trait]
536impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsEvmService for AwsKmsService<T> {
537 async fn get_evm_address(&self) -> AwsKmsResult<Address> {
538 let der = self.client.get_der_public_key(&self.kms_key_id).await?;
539 let eth_address = derive_ethereum_address_from_der(&der)
540 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
541 Ok(Address::Evm(eth_address))
542 }
543
544 async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
545 let digest = keccak256(message).0;
546 self.sign_and_recover_evm(digest, message, false).await
547 }
548
549 async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
550 self.sign_and_recover_evm(*hash, hash, true).await
552 }
553}
554
555#[async_trait]
556impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsSolanaService for AwsKmsService<T> {
557 async fn get_solana_address(&self) -> AwsKmsResult<Address> {
558 let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
559 let solana_address = derive_solana_address_from_der(&der)
560 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
561 Ok(Address::Solana(solana_address))
562 }
563
564 async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
565 self.client.sign_ed25519(&self.kms_key_id, message).await
566 }
567}
568
569#[async_trait]
570impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsStellarService for AwsKmsService<T> {
571 async fn get_stellar_address(&self) -> AwsKmsResult<Address> {
572 let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
573 let stellar_address = derive_stellar_address_from_der(&der)
574 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
575 Ok(Address::Stellar(stellar_address))
576 }
577
578 async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
579 self.client.sign_ed25519(&self.kms_key_id, message).await
580 }
581}
582
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::SigningKey,
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Key id the mock client answers successfully.
    const VALID_KEY_ID: &str = "test-key-id";
    /// Any other key id makes the mock client return an error.
    const UNKNOWN_KEY_ID: &str = "invalid-key-id";

    /// Fixed Ed25519 public key used by the mock client, in both encodings.
    pub struct TestEd25519Keys {
        /// The 12-byte prefix built in `new` followed by the raw key.
        pub public_key_der: Vec<u8>,
        /// The raw 32-byte public key.
        pub public_key_raw: [u8; 32],
    }

    impl Default for TestEd25519Keys {
        fn default() -> Self {
            Self::new()
        }
    }

    impl TestEd25519Keys {
        /// Builds the fixture from a hard-coded public key.
        pub fn new() -> Self {
            let public_key_raw: [u8; 32] = [
                0x9d, 0x45, 0x7e, 0x45, 0xe4, 0x16, 0xc4, 0xc6, 0x77, 0x67, 0x6a, 0x42, 0xff, 0x96,
                0x8e, 0x3c, 0xf8, 0xdc, 0x73, 0xc8, 0xf3, 0x3a, 0x8d, 0x19, 0x81, 0x29, 0x7b, 0xfa,
                0x3e, 0x00, 0x30, 0xba,
            ];

            // Fixed DER header; the raw key bytes are appended after it.
            let mut public_key_der = vec![
                0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00,
            ];
            public_key_der.extend_from_slice(&public_key_raw);

            Self {
                public_key_der,
                public_key_raw,
            }
        }
    }

    /// Creates a mock KMS client that behaves like a real one for
    /// `"test-key-id"` and fails for every other key id.
    ///
    /// Returns the mock together with the random secp256k1 signing key whose
    /// public key the mock serves, so tests can compute expected values.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let s = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq(VALID_KEY_ID))
            .return_const(Ok(s));
        client
            .expect_get_der_public_key()
            .with(ne(VALID_KEY_ID))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne(VALID_KEY_ID))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        // Keep a copy for the caller; the original moves into the closure.
        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq(VALID_KEY_ID))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        let test_ed25519_keys = TestEd25519Keys::new();
        client
            .expect_get_ed25519_public_key()
            .with(eq(VALID_KEY_ID))
            .return_const(Ok(test_ed25519_keys.public_key_der.clone()));
        client
            .expect_get_ed25519_public_key()
            .with(ne(VALID_KEY_ID))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        // The mock does not produce a real signature, only 64 zero bytes.
        client
            .expect_sign_ed25519()
            .withf(|key_id, _| key_id.eq(VALID_KEY_ID))
            .returning(|_, _| Ok(vec![0u8; 64]));
        client
            .expect_sign_ed25519()
            .withf(|key_id, _| key_id.ne(VALID_KEY_ID))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Builds a service over the mock client for `key_id` (region fixed to
    /// `us-east-1`) and returns it with the mock's secp256k1 signing key.
    fn mock_service(key_id: &str) -> (AwsKmsService<MockAwsKmsClient>, SigningKey) {
        let (mock_client, signing_key) = setup_mock_kms_client();
        let service = AwsKmsService::new_for_testing(
            mock_client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        );
        (service, signing_key)
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (kms, key) = mock_service(VALID_KEY_ID);
        let expected_address = derive_ethereum_address_from_der(
            key.verifying_key().to_public_key_der().unwrap().as_bytes(),
        )
        .unwrap();

        // Match every outcome so an unexpected address variant fails loudly.
        match kms.get_evm_address().await {
            Ok(Address::Evm(evm_address)) => assert_eq!(expected_address, evm_address),
            Ok(_) => panic!("Expected EVM address"),
            Err(err) => panic!("Unexpected error: {err:?}"),
        }
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (kms, _) = mock_service(UNKNOWN_KEY_ID);

        let err = kms.get_evm_address().await.err();
        assert!(
            matches!(err, Some(AwsKmsError::GetError(_))),
            "Expected GetError, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (kms, _) = mock_service(VALID_KEY_ID);

        let message_eip = eip191_message(b"Hello World!");
        if let Err(err) = kms.sign_payload_evm(&message_eip).await {
            panic!("Expected signing to succeed, got: {err:?}");
        }
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (kms, _) = mock_service(UNKNOWN_KEY_ID);

        let message_eip = eip191_message(b"Hello World!");
        let err = kms.sign_payload_evm(&message_eip).await.err();
        assert!(
            matches!(err, Some(AwsKmsError::SignError(_))),
            "Expected SignError, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_get_solana_address() {
        let (kms, _) = mock_service(VALID_KEY_ID);

        match kms.get_solana_address().await {
            Ok(Address::Solana(solana_address)) => {
                assert!(!solana_address.is_empty());
                assert!(solana_address.len() >= 32 && solana_address.len() <= 44);
                let test_keys = TestEd25519Keys::new();
                let expected_address = bs58::encode(test_keys.public_key_raw).into_string();
                assert_eq!(solana_address, expected_address);
            }
            Ok(_) => panic!("Expected Solana address"),
            Err(err) => panic!("Unexpected error: {err:?}"),
        }
    }

    #[tokio::test]
    async fn test_get_solana_address_fail() {
        let (kms, _) = mock_service(UNKNOWN_KEY_ID);

        let err = kms.get_solana_address().await.err();
        assert!(
            matches!(err, Some(AwsKmsError::GetError(_))),
            "Expected GetError, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_sign_solana() {
        let (kms, _) = mock_service(VALID_KEY_ID);

        let message = b"Test Solana message";
        let signature = kms
            .sign_solana(message)
            .await
            .expect("signing with the known key should succeed");
        assert_eq!(signature.len(), 64);
    }

    #[tokio::test]
    async fn test_sign_solana_fail() {
        let (kms, _) = mock_service(UNKNOWN_KEY_ID);

        let message = b"Test Solana message";
        let err = kms.sign_solana(message).await.err();
        assert!(
            matches!(err, Some(AwsKmsError::SignError(_))),
            "Expected SignError, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_get_stellar_address() {
        let (kms, _) = mock_service(VALID_KEY_ID);

        match kms.get_stellar_address().await {
            Ok(Address::Stellar(stellar_address)) => {
                assert!(stellar_address.starts_with('G'));
                assert_eq!(stellar_address.len(), 56);
            }
            Ok(_) => panic!("Expected Stellar address"),
            Err(err) => panic!("Unexpected error: {err:?}"),
        }
    }

    #[tokio::test]
    async fn test_get_stellar_address_fail() {
        let (kms, _) = mock_service(UNKNOWN_KEY_ID);

        let err = kms.get_stellar_address().await.err();
        assert!(
            matches!(err, Some(AwsKmsError::GetError(_))),
            "Expected GetError, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_sign_stellar() {
        let (kms, _) = mock_service(VALID_KEY_ID);

        let message = b"Test Stellar message";
        let signature = kms
            .sign_stellar(message)
            .await
            .expect("signing with the known key should succeed");
        assert_eq!(signature.len(), 64);
    }

    #[tokio::test]
    async fn test_sign_stellar_fail() {
        let (kms, _) = mock_service(UNKNOWN_KEY_ID);

        let message = b"Test Stellar message";
        let err = kms.sign_stellar(message).await.err();
        assert!(
            matches!(err, Some(AwsKmsError::SignError(_))),
            "Expected SignError, got: {err:?}"
        );
    }

    /// Two configs in the same region must resolve to the same cached client,
    /// regardless of their key ids.
    #[tokio::test]
    async fn test_kms_client_cache_same_region_shares_client() {
        let config1 = AwsKmsSignerConfig {
            region: Some("us-west-2".to_string()),
            key_id: "key-aaa".to_string(),
        };
        let config2 = AwsKmsSignerConfig {
            region: Some("us-west-2".to_string()),
            key_id: "key-bbb".to_string(),
        };

        let result1 = get_or_create_kms_client(&config1).await;
        let result2 = get_or_create_kms_client(&config2).await;

        match (result1, result2) {
            (Ok(client1), Ok(client2)) => {
                assert!(Arc::ptr_eq(&client1, &client2));
            }
            // Client construction can fail in some environments; that failure
            // must surface as the dedicated config error.
            (Err(AwsKmsError::ConfigError(msg)), _) | (_, Err(AwsKmsError::ConfigError(msg))) => {
                assert!(
                    msg.contains("TLS root certificates"),
                    "Expected TLS-related config error, got: {msg}"
                );
            }
            (Err(e), _) | (_, Err(e)) => {
                panic!("Unexpected error: {e:?}");
            }
        }
    }

    /// NOTE(review): this test depends on the ambient environment; it fails
    /// when the default provider chain can resolve a region.
    #[tokio::test]
    async fn test_kms_client_returns_config_error_when_region_missing() {
        let config = AwsKmsSignerConfig {
            region: None,
            key_id: "test-key".to_string(),
        };

        let result = get_or_create_kms_client(&config).await;
        match result {
            Err(AwsKmsError::ConfigError(_)) => {}
            Ok(_) => panic!(
                "Expected missing-region error; AWS_REGION/AWS_DEFAULT_REGION may be set in env"
            ),
            Err(e) => panic!("Expected ConfigError, got: {e:?}"),
        }
    }
}