openzeppelin_relayer/services/aws_kms/
mod.rs

1//! # AWS KMS Service Module
2//!
3//! This module provides integration with AWS KMS for secure key management
4//! and cryptographic operations such as public key retrieval and message signing.
5//!
6//! Supports EVM (secp256k1/ECDSA), Solana (Ed25519), and Stellar (Ed25519) networks.
7//!
8//! ## Features
9//!
//! - Authentication through the AWS SDK default credential provider chain
11//! - Public key retrieval from KMS
12//! - Message signing via KMS for multiple key types
13//!
14//! ## Architecture
15//!
16//! ```text
17//! AwsKmsService (implements AwsKmsEvmService, AwsKmsSolanaService, AwsKmsStellarService)
18//!   ├── Authentication (via AwsKmsClient)
19//!   ├── Public Key Retrieval (via AwsKmsClient)
20//!   └── Message Signing (via AwsKmsClient)
21//! ```
//! which is built on top of
23//! ```text
24//! AwsKmsClient (implements AwsKmsK256, AwsKmsEd25519)
25//!   ├── Authentication (via shared credentials)
26//!   ├── Public Key Retrieval in DER Encoding
27//!   └── Message Signing (ECDSA for secp256k1, Ed25519 for EdDSA)
28//! ```
//! In unit tests, `AwsKmsK256` and `AwsKmsEd25519` are mocked with `mockall`, and the mock
//! is injected into `AwsKmsService` in place of `AwsKmsClient`.
31//!
32
33use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37    primitives::Blob,
38    types::{MessageType, SigningAlgorithmSpec},
39    Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::{collections::HashMap, sync::Arc};
44use tokio::sync::RwLock;
45
46use crate::{
47    models::{Address, AwsKmsSignerConfig},
48    services::{
49        client_cache::AsyncClientCache, signer::evm::utils::recover_evm_signature_from_der,
50    },
51    utils::{
52        self, aws_error::DisplayErrorContext, classify_sdk_error, derive_ethereum_address_from_der,
53        derive_solana_address_from_der, derive_stellar_address_from_der,
54    },
55};
56use tracing::{debug, warn};
57
58#[cfg(test)]
59use mockall::{automock, mock};
60
/// Errors returned by the AWS KMS service layer.
///
/// This type derives `Serialize`, so variant names are part of its serialized form;
/// keep existing variant names stable.
#[derive(Clone, Debug, thiserror::Error, Serialize)]
pub enum AwsKmsError {
    /// A KMS response, or data derived from it (DER public key, signature recovery,
    /// address derivation), could not be parsed.
    #[error("AWS KMS response parse error: {0}")]
    ParseError(String),
    /// Client configuration failed, e.g. the region could not be resolved or the SDK
    /// client could not be constructed.
    #[error("AWS KMS config error: {0}")]
    ConfigError(String),
    /// Fetching a public key from KMS failed or the response contained no key.
    #[error("AWS KMS get error: {0}")]
    GetError(String),
    /// A KMS sign request failed or returned a missing/invalid signature.
    #[error("AWS KMS signing error: {0}")]
    SignError(String),
    /// Wraps `utils::Secp256k1Error`; `#[from]` generates the `From` conversion so `?`
    /// can propagate it directly.
    #[error("AWS KMS public key error: {0}")]
    RecoveryError(#[from] utils::Secp256k1Error),
    /// A value conversion failed.
    #[error("AWS KMS conversion error: {0}")]
    ConvertError(String),
    /// Catch-all for failures that fit no other variant.
    #[error("AWS KMS Other error: {0}")]
    Other(String),
}

/// Result alias used throughout the AWS KMS service.
pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
80
/// EVM (secp256k1) operations backed by a single AWS KMS key.
///
/// Implemented by `AwsKmsService`; in tests a mock is generated via `automock`.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEvmService: Send + Sync {
    /// Returns the EVM address derived from the configured public key.
    async fn get_evm_address(&self) -> AwsKmsResult<Address>;
    /// Signs a payload using the EVM signing scheme (hashes before signing).
    ///
    /// This method applies keccak256 hashing before signing.
    ///
    /// **Use for:**
    /// - Raw transaction data (TxLegacy, TxEip1559)
    /// - EIP-191 personal messages
    ///
    /// **Note:** For EIP-712 typed data, use `sign_hash_evm()` to avoid double-hashing.
    async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;

    /// Signs a pre-computed hash using the EVM signing scheme (no hashing).
    ///
    /// This method signs the hash directly without applying keccak256.
    ///
    /// **Use for:**
    /// - EIP-712 typed data (already hashed)
    /// - Pre-computed message digests
    ///
    /// **Note:** For raw data, use `sign_payload_evm()` instead.
    async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>>;
}
108
/// Low-level secp256k1 (ECDSA) operations against AWS KMS.
///
/// Implemented by `AwsKmsClient`; a mock is generated for tests and the same
/// signatures are mirrored in the `mock!` definition of `MockAwsKmsClient`.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsK256: Send + Sync {
    /// Fetches the DER-encoded public key from AWS KMS.
    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs a digest using EcdsaSha256 spec. Returns DER-encoded signature
    ///
    /// `digest` is an already-computed 32-byte hash; hashing (e.g. keccak256 for EVM)
    /// is the caller's responsibility.
    async fn sign_digest<'a, 'b>(
        &'a self,
        key_id: &'b str,
        digest: [u8; 32],
    ) -> AwsKmsResult<Vec<u8>>;
}
121
/// Trait for Ed25519 (EdDSA) operations with AWS KMS.
/// Used for Solana and Stellar signing.
///
/// Implemented by `AwsKmsClient`; a mock is generated for tests.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEd25519: Send + Sync {
    /// Fetches the DER-encoded Ed25519 public key from AWS KMS.
    async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs a message using Ed25519. Returns 64-byte signature.
    /// Uses ED25519_SHA_512 algorithm with RAW message type.
    ///
    /// `message` is the raw message, not a digest.
    async fn sign_ed25519<'a, 'b>(
        &'a self,
        key_id: &'b str,
        message: &'b [u8],
    ) -> AwsKmsResult<Vec<u8>>;
}
137
/// Trait for Solana-specific AWS KMS operations
///
/// Implemented by `AwsKmsService` on top of the key's Ed25519 operations.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsSolanaService: Send + Sync {
    /// Returns the Solana address derived from the configured Ed25519 public key.
    async fn get_solana_address(&self) -> AwsKmsResult<Address>;
    /// Signs a message using Ed25519 for Solana.
    async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
}
147
/// Trait for Stellar-specific AWS KMS operations
///
/// Implemented by `AwsKmsService` on top of the key's Ed25519 operations.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsStellarService: Send + Sync {
    /// Returns the Stellar address derived from the configured Ed25519 public key.
    async fn get_stellar_address(&self) -> AwsKmsResult<Address>;
    /// Signs a message using Ed25519 for Stellar.
    async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
}
157
// Test double for `AwsKmsClient`: implements `Clone`, `AwsKmsK256` and `AwsKmsEd25519`
// so it can be injected into `AwsKmsService` via `new_for_testing`.
// The method signatures below must stay in sync with the trait definitions.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }
    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }

    #[async_trait]
    impl AwsKmsEd25519 for AwsKmsClient {
        async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_ed25519<'a, 'b>(
            &'a self,
            key_id: &'b str,
            message: &'b [u8],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}
185
186// Global cache for secp256k1 public keys - HashMap keyed by kms_key_id
187static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
188    Lazy::new(|| RwLock::new(HashMap::new()));
189
190// Global cache for Ed25519 public keys - HashMap keyed by kms_key_id
191static KMS_ED25519_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
192    Lazy::new(|| RwLock::new(HashMap::new()));
193
/// Cache key for shared SDK clients: one client per resolved AWS region.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct AwsKmsClientKey {
    /// Region name as returned by `resolve_aws_region`.
    region: String,
}

/// Process-wide cache of AWS KMS SDK clients keyed by region; populated lazily by
/// `get_or_create_kms_client`.
static KMS_CLIENT_CACHE: Lazy<AsyncClientCache<AwsKmsClientKey, Client>> =
    Lazy::new(AsyncClientCache::new);
201
202/// Get or create a shared AWS KMS SDK client for the given signer config.
203/// Keyed by resolved region — one client serves all KMS keys in that region.
204async fn get_or_create_kms_client(config: &AwsKmsSignerConfig) -> AwsKmsResult<Arc<Client>> {
205    let resolved_region = resolve_aws_region(config).await?;
206    let key = AwsKmsClientKey {
207        region: resolved_region.clone(),
208    };
209
210    KMS_CLIENT_CACHE
211        .get_or_try_init(key, || async {
212            debug!(
213                region = %resolved_region,
214                "Creating new AWS KMS client"
215            );
216            let auth_config = aws_config::defaults(BehaviorVersion::latest())
217                .region(Region::new(resolved_region))
218                .load()
219                .await;
220
221            // Client::new() can panic in environments without TLS root certificates
222            // (e.g., stripped containers). Catch the panic and return a typed error.
223            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Client::new(&auth_config)))
224                .map_err(|panic| {
225                    let msg = panic
226                        .downcast_ref::<String>()
227                        .map(|s| s.as_str())
228                        .or_else(|| panic.downcast_ref::<&str>().copied())
229                        .unwrap_or("unknown panic");
230                    AwsKmsError::ConfigError(format!(
231                        "Failed to initialize AWS KMS client (check TLS root certificates): {msg}"
232                    ))
233                })
234        })
235        .await
236}
237
238/// Resolve the AWS region from config or the default provider chain.
239async fn resolve_aws_region(config: &AwsKmsSignerConfig) -> AwsKmsResult<String> {
240    if let Some(region) = &config.region {
241        return Ok(region.clone());
242    }
243
244    let provider = RegionProviderChain::default_provider();
245    provider
246        .region()
247        .await
248        .map(|r| r.to_string())
249        .ok_or_else(|| {
250            AwsKmsError::ConfigError(
251                "AWS region not specified and could not be resolved from environment".to_string(),
252            )
253        })
254}
255
/// Thin wrapper around a shared AWS KMS SDK client.
///
/// `inner` is an `Arc`, so cloning an `AwsKmsClient` shares the same SDK client.
#[derive(Debug, Clone)]
pub struct AwsKmsClient {
    inner: Arc<Client>,
}
260
261#[async_trait]
262impl AwsKmsK256 for AwsKmsClient {
263    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
264        // Try cache first with minimal lock time
265        let cached = {
266            let cache_read = KMS_DER_PK_CACHE.read().await;
267            cache_read.get(key_id).cloned()
268        };
269        if let Some(cached) = cached {
270            return Ok(cached);
271        }
272
273        // Fetch from AWS KMS
274        let get_output = self
275            .inner
276            .get_public_key()
277            .key_id(key_id)
278            .send()
279            .await
280            .map_err(|e| {
281                warn!(
282                    error.kind = classify_sdk_error(&e),
283                    error.detail = %DisplayErrorContext(&e),
284                    kms_key_id = %key_id,
285                    operation = "get_public_key_secp256k1",
286                    "AWS KMS get_public_key failed"
287                );
288                AwsKmsError::GetError(format!(
289                    "Failed to get secp256k1 public key for key '{key_id}': {}",
290                    classify_sdk_error(&e)
291                ))
292            })?;
293
294        let der_pk_blob = get_output
295            .public_key
296            .ok_or(AwsKmsError::GetError(
297                "No public key blob found".to_string(),
298            ))?
299            .into_inner();
300
301        let mut cache_write = KMS_DER_PK_CACHE.write().await;
302        cache_write.insert(key_id.to_string(), der_pk_blob.clone());
303
304        Ok(der_pk_blob)
305    }
306
307    async fn sign_digest<'a, 'b>(
308        &'a self,
309        key_id: &'b str,
310        digest: [u8; 32],
311    ) -> AwsKmsResult<Vec<u8>> {
312        // Sign the digest with the AWS KMS
313        let sign_result = self
314            .inner
315            .sign()
316            .key_id(key_id)
317            .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
318            .message_type(MessageType::Digest)
319            .message(Blob::new(digest))
320            .send()
321            .await;
322
323        // Process the result, extract DER signature
324        let der_signature = sign_result
325            .map_err(|e| {
326                warn!(
327                    error.kind = classify_sdk_error(&e),
328                    error.detail = %DisplayErrorContext(&e),
329                    kms_key_id = %key_id,
330                    operation = "sign_digest_secp256k1",
331                    "AWS KMS sign failed"
332                );
333                AwsKmsError::SignError(format!(
334                    "Failed to sign secp256k1 digest for key '{key_id}': {}",
335                    classify_sdk_error(&e)
336                ))
337            })?
338            .signature
339            .ok_or(AwsKmsError::SignError(
340                "Signature not found in response".to_string(),
341            ))?
342            .into_inner();
343
344        Ok(der_signature)
345    }
346}
347
348#[async_trait]
349impl AwsKmsEd25519 for AwsKmsClient {
350    async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
351        // Try cache first with minimal lock time
352        let cached = {
353            let cache_read = KMS_ED25519_PK_CACHE.read().await;
354            cache_read.get(key_id).cloned()
355        };
356        if let Some(cached) = cached {
357            return Ok(cached);
358        }
359
360        // Fetch from AWS KMS
361        let get_output = self
362            .inner
363            .get_public_key()
364            .key_id(key_id)
365            .send()
366            .await
367            .map_err(|e| {
368                warn!(
369                    error.kind = classify_sdk_error(&e),
370                    error.detail = %DisplayErrorContext(&e),
371                    kms_key_id = %key_id,
372                    operation = "get_public_key_ed25519",
373                    "AWS KMS get_public_key failed"
374                );
375                AwsKmsError::GetError(format!(
376                    "Failed to get Ed25519 public key for key '{key_id}': {}",
377                    classify_sdk_error(&e)
378                ))
379            })?;
380
381        let der_pk_blob = get_output
382            .public_key
383            .ok_or(AwsKmsError::GetError(
384                "No public key blob found".to_string(),
385            ))?
386            .into_inner();
387
388        let mut cache_write = KMS_ED25519_PK_CACHE.write().await;
389        cache_write.insert(key_id.to_string(), der_pk_blob.clone());
390
391        Ok(der_pk_blob)
392    }
393
394    async fn sign_ed25519<'a, 'b>(
395        &'a self,
396        key_id: &'b str,
397        message: &'b [u8],
398    ) -> AwsKmsResult<Vec<u8>> {
399        debug!("Signing Ed25519 message with AWS KMS, key_id: {}", key_id);
400
401        // Sign the message with Ed25519 using ED25519_SHA_512 algorithm
402        // Note: ED25519_SHA_512 requires MessageType::Raw - we pass the raw message
403        let sign_result = self
404            .inner
405            .sign()
406            .key_id(key_id)
407            .signing_algorithm(SigningAlgorithmSpec::Ed25519Sha512)
408            .message_type(MessageType::Raw)
409            .message(Blob::new(message))
410            .send()
411            .await;
412
413        // Process the result, extract signature
414        let signature = sign_result
415            .map_err(|e| {
416                warn!(
417                    error.kind = classify_sdk_error(&e),
418                    error.detail = %DisplayErrorContext(&e),
419                    kms_key_id = %key_id,
420                    operation = "sign_ed25519",
421                    "AWS KMS sign failed"
422                );
423                AwsKmsError::SignError(format!(
424                    "Failed to sign Ed25519 message for key '{key_id}': {}",
425                    classify_sdk_error(&e)
426                ))
427            })?
428            .signature
429            .ok_or(AwsKmsError::SignError(
430                "Signature not found in response".to_string(),
431            ))?
432            .into_inner();
433
434        // Ed25519 signatures should be 64 bytes
435        if signature.len() != 64 {
436            return Err(AwsKmsError::SignError(format!(
437                "Invalid Ed25519 signature length: expected 64 bytes, got {}",
438                signature.len()
439            )));
440        }
441
442        Ok(signature)
443    }
444}
445
/// AWS KMS-backed signer bound to a single KMS key.
///
/// Generic over the low-level client so tests can inject a mock; production code
/// uses the default `AwsKmsClient`.
#[derive(Debug, Clone)]
pub struct AwsKmsService<T: AwsKmsK256 + AwsKmsEd25519 + Clone = AwsKmsClient> {
    /// KMS key identifier passed as `key_id` to every KMS request made by this service.
    pub kms_key_id: String,
    /// Low-level client used for public-key retrieval and signing.
    client: T,
}
451
452impl AwsKmsService<AwsKmsClient> {
453    pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
454        let shared_client = get_or_create_kms_client(&config).await?;
455
456        Ok(Self {
457            kms_key_id: config.key_id,
458            client: AwsKmsClient {
459                inner: shared_client,
460            },
461        })
462    }
463}
464
#[cfg(test)]
impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
    /// Builds a service around an injected client (typically a mock) without creating
    /// an SDK client. Only `config.key_id` is used.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}
474
475impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
476    /// Common signing logic for EVM signatures.
477    ///
478    /// This internal helper eliminates duplication between `sign_payload_evm` and `sign_hash_evm`.
479    ///
480    /// # Parameters
481    /// * `digest` - The 32-byte hash to sign
482    /// * `original_bytes` - The original message bytes for recovery verification (if applicable)
483    /// * `use_prehash_recovery` - If true, recovers using hash directly; if false, uses original bytes
484    async fn sign_and_recover_evm(
485        &self,
486        digest: [u8; 32],
487        original_bytes: &[u8],
488        use_prehash_recovery: bool,
489    ) -> AwsKmsResult<Vec<u8>> {
490        // Sign the digest with AWS KMS
491        let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
492
493        // Get public key
494        let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
495
496        // Use shared signature recovery logic
497        recover_evm_signature_from_der(
498            &der_signature,
499            &der_pk,
500            digest,
501            original_bytes,
502            use_prehash_recovery,
503        )
504        .map_err(|e| AwsKmsError::ParseError(e.to_string()))
505    }
506
507    /// Signs a payload using the EVM signing scheme (hashes before signing).
508    ///
509    /// This method applies keccak256 hashing before signing.
510    ///
511    /// **Use for:**
512    /// - Raw transaction data (TxLegacy, TxEip1559)
513    /// - EIP-191 personal messages
514    ///
515    /// **Note:** For EIP-712 typed data, use `sign_hash_evm()` to avoid double-hashing.
516    pub async fn sign_payload_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
517        let digest = keccak256(bytes).0;
518        self.sign_and_recover_evm(digest, bytes, false).await
519    }
520
521    /// Signs a pre-computed hash using the EVM signing scheme (no hashing).
522    ///
523    /// This method signs the hash directly without applying keccak256.
524    ///
525    /// **Use for:**
526    /// - EIP-712 typed data (already hashed)
527    /// - Pre-computed message digests
528    ///
529    /// **Note:** For raw data, use `sign_payload_evm()` instead.
530    pub async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
531        self.sign_and_recover_evm(*hash, hash, true).await
532    }
533}
534
535#[async_trait]
536impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsEvmService for AwsKmsService<T> {
537    async fn get_evm_address(&self) -> AwsKmsResult<Address> {
538        let der = self.client.get_der_public_key(&self.kms_key_id).await?;
539        let eth_address = derive_ethereum_address_from_der(&der)
540            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
541        Ok(Address::Evm(eth_address))
542    }
543
544    async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
545        let digest = keccak256(message).0;
546        self.sign_and_recover_evm(digest, message, false).await
547    }
548
549    async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
550        // Delegates to the implementation method on AwsKmsService
551        self.sign_and_recover_evm(*hash, hash, true).await
552    }
553}
554
555#[async_trait]
556impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsSolanaService for AwsKmsService<T> {
557    async fn get_solana_address(&self) -> AwsKmsResult<Address> {
558        let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
559        let solana_address = derive_solana_address_from_der(&der)
560            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
561        Ok(Address::Solana(solana_address))
562    }
563
564    async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
565        self.client.sign_ed25519(&self.kms_key_id, message).await
566    }
567}
568
569#[async_trait]
570impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsStellarService for AwsKmsService<T> {
571    async fn get_stellar_address(&self) -> AwsKmsResult<Address> {
572        let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
573        let stellar_address = derive_stellar_address_from_der(&der)
574            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
575        Ok(Address::Stellar(stellar_address))
576    }
577
578    async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
579        self.client.sign_ed25519(&self.kms_key_id, message).await
580    }
581}
582
583#[cfg(test)]
584pub mod tests {
585    use super::*;
586
587    use alloy::primitives::utils::eip191_message;
588    use k256::{
589        ecdsa::SigningKey,
590        elliptic_curve::rand_core::OsRng,
591        pkcs8::{der::Encode, EncodePublicKey},
592    };
593    use mockall::predicate::{eq, ne};
594
595    /// Test Ed25519 key pair for mocking AWS KMS Ed25519 operations
596    pub struct TestEd25519Keys {
597        pub public_key_der: Vec<u8>,
598        pub public_key_raw: [u8; 32],
599    }
600
601    impl Default for TestEd25519Keys {
602        fn default() -> Self {
603            Self::new()
604        }
605    }
606
607    impl TestEd25519Keys {
608        pub fn new() -> Self {
609            // Well-known test Ed25519 public key (32 bytes)
610            let public_key_raw: [u8; 32] = [
611                0x9d, 0x45, 0x7e, 0x45, 0xe4, 0x16, 0xc4, 0xc6, 0x77, 0x67, 0x6a, 0x42, 0xff, 0x96,
612                0x8e, 0x3c, 0xf8, 0xdc, 0x73, 0xc8, 0xf3, 0x3a, 0x8d, 0x19, 0x81, 0x29, 0x7b, 0xfa,
613                0x3e, 0x00, 0x30, 0xba,
614            ];
615
616            // Ed25519 SPKI format: 12-byte header + 32-byte key
617            let mut public_key_der = vec![
618                0x30, 0x2a, // SEQUENCE, 42 bytes
619                0x30, 0x05, // SEQUENCE, 5 bytes
620                0x06, 0x03, 0x2b, 0x65, 0x70, // OID 1.3.101.112 (Ed25519)
621                0x03, 0x21, // BIT STRING, 33 bytes
622                0x00, // zero unused bits
623            ];
624            public_key_der.extend_from_slice(&public_key_raw);
625
626            Self {
627                public_key_der,
628                public_key_raw,
629            }
630        }
631    }
632
633    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
634        let mut client = MockAwsKmsClient::new();
635        let signing_key = SigningKey::random(&mut OsRng);
636        let s = signing_key
637            .verifying_key()
638            .to_public_key_der()
639            .unwrap()
640            .to_der()
641            .unwrap();
642
643        client
644            .expect_get_der_public_key()
645            .with(eq("test-key-id"))
646            .return_const(Ok(s));
647        client
648            .expect_get_der_public_key()
649            .with(ne("test-key-id"))
650            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));
651
652        client
653            .expect_sign_digest()
654            .withf(|key_id, _| key_id.ne("test-key-id"))
655            .return_const(Err(AwsKmsError::SignError(
656                "Key does not exist".to_string(),
657            )));
658
659        let key = signing_key.clone();
660        client
661            .expect_sign_digest()
662            .withf(|key_id, _| key_id.eq("test-key-id"))
663            .returning(move |_, digest| {
664                let (signature, _) = signing_key
665                    .sign_prehash_recoverable(&digest)
666                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
667                let der_signature = signature.to_der().as_bytes().to_vec();
668                Ok(der_signature)
669            });
670
671        // Setup Ed25519 mock expectations
672        let test_ed25519_keys = TestEd25519Keys::new();
673        client
674            .expect_get_ed25519_public_key()
675            .with(eq("test-key-id"))
676            .return_const(Ok(test_ed25519_keys.public_key_der.clone()));
677        client
678            .expect_get_ed25519_public_key()
679            .with(ne("test-key-id"))
680            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));
681
682        // Mock Ed25519 signing - return a fixed 64-byte signature
683        client
684            .expect_sign_ed25519()
685            .withf(|key_id, _| key_id.eq("test-key-id"))
686            .returning(|_, _| Ok(vec![0u8; 64]));
687        client
688            .expect_sign_ed25519()
689            .withf(|key_id, _| key_id.ne("test-key-id"))
690            .return_const(Err(AwsKmsError::SignError(
691                "Key does not exist".to_string(),
692            )));
693
694        client.expect_clone().return_once(MockAwsKmsClient::new);
695
696        (client, key)
697    }
698
699    #[tokio::test]
700    async fn test_get_public_key() {
701        let (mock_client, key) = setup_mock_kms_client();
702        let kms = AwsKmsService::new_for_testing(
703            mock_client,
704            AwsKmsSignerConfig {
705                region: Some("us-east-1".to_string()),
706                key_id: "test-key-id".to_string(),
707            },
708        );
709
710        let result = kms.get_evm_address().await;
711        assert!(result.is_ok());
712        if let Ok(Address::Evm(evm_address)) = result {
713            let expected_address = derive_ethereum_address_from_der(
714                key.verifying_key().to_public_key_der().unwrap().as_bytes(),
715            )
716            .unwrap();
717            assert_eq!(expected_address, evm_address);
718        }
719    }
720
721    #[tokio::test]
722    async fn test_get_public_key_fail() {
723        let (mock_client, _) = setup_mock_kms_client();
724        let kms = AwsKmsService::new_for_testing(
725            mock_client,
726            AwsKmsSignerConfig {
727                region: Some("us-east-1".to_string()),
728                key_id: "invalid-key-id".to_string(),
729            },
730        );
731
732        let result = kms.get_evm_address().await;
733        assert!(result.is_err());
734        if let Err(err) = result {
735            assert!(matches!(err, AwsKmsError::GetError(_)))
736        }
737    }
738
739    #[tokio::test]
740    async fn test_sign_digest() {
741        let (mock_client, _) = setup_mock_kms_client();
742        let kms = AwsKmsService::new_for_testing(
743            mock_client,
744            AwsKmsSignerConfig {
745                region: Some("us-east-1".to_string()),
746                key_id: "test-key-id".to_string(),
747            },
748        );
749
750        let message_eip = eip191_message(b"Hello World!");
751        let result = kms.sign_payload_evm(&message_eip).await;
752
753        // We just assert for Ok, since the pubkey recovery indicates the validity of signature
754        assert!(result.is_ok());
755    }
756
757    #[tokio::test]
758    async fn test_sign_digest_fail() {
759        let (mock_client, _) = setup_mock_kms_client();
760        let kms = AwsKmsService::new_for_testing(
761            mock_client,
762            AwsKmsSignerConfig {
763                region: Some("us-east-1".to_string()),
764                key_id: "invalid-key-id".to_string(),
765            },
766        );
767
768        let message_eip = eip191_message(b"Hello World!");
769        let result = kms.sign_payload_evm(&message_eip).await;
770        assert!(result.is_err());
771        if let Err(err) = result {
772            assert!(matches!(err, AwsKmsError::SignError(_)))
773        }
774    }
775
776    #[tokio::test]
777    async fn test_get_solana_address() {
778        let (mock_client, _) = setup_mock_kms_client();
779        let kms = AwsKmsService::new_for_testing(
780            mock_client,
781            AwsKmsSignerConfig {
782                region: Some("us-east-1".to_string()),
783                key_id: "test-key-id".to_string(),
784            },
785        );
786
787        let result = kms.get_solana_address().await;
788        assert!(result.is_ok());
789        if let Ok(Address::Solana(solana_address)) = result {
790            // Verify it's a valid base58-encoded address
791            assert!(!solana_address.is_empty());
792            assert!(solana_address.len() >= 32 && solana_address.len() <= 44);
793            // Verify it matches the expected address from our test key
794            let test_keys = TestEd25519Keys::new();
795            let expected_address = bs58::encode(test_keys.public_key_raw).into_string();
796            assert_eq!(solana_address, expected_address);
797        } else {
798            panic!("Expected Solana address");
799        }
800    }
801
802    #[tokio::test]
803    async fn test_get_solana_address_fail() {
804        let (mock_client, _) = setup_mock_kms_client();
805        let kms = AwsKmsService::new_for_testing(
806            mock_client,
807            AwsKmsSignerConfig {
808                region: Some("us-east-1".to_string()),
809                key_id: "invalid-key-id".to_string(),
810            },
811        );
812
813        let result = kms.get_solana_address().await;
814        assert!(result.is_err());
815        if let Err(err) = result {
816            assert!(matches!(err, AwsKmsError::GetError(_)))
817        }
818    }
819
820    #[tokio::test]
821    async fn test_sign_solana() {
822        let (mock_client, _) = setup_mock_kms_client();
823        let kms = AwsKmsService::new_for_testing(
824            mock_client,
825            AwsKmsSignerConfig {
826                region: Some("us-east-1".to_string()),
827                key_id: "test-key-id".to_string(),
828            },
829        );
830
831        let message = b"Test Solana message";
832        let result = kms.sign_solana(message).await;
833        assert!(result.is_ok());
834        let signature = result.unwrap();
835        assert_eq!(signature.len(), 64); // Ed25519 signatures are 64 bytes
836    }
837
838    #[tokio::test]
839    async fn test_sign_solana_fail() {
840        let (mock_client, _) = setup_mock_kms_client();
841        let kms = AwsKmsService::new_for_testing(
842            mock_client,
843            AwsKmsSignerConfig {
844                region: Some("us-east-1".to_string()),
845                key_id: "invalid-key-id".to_string(),
846            },
847        );
848
849        let message = b"Test Solana message";
850        let result = kms.sign_solana(message).await;
851        assert!(result.is_err());
852        if let Err(err) = result {
853            assert!(matches!(err, AwsKmsError::SignError(_)))
854        }
855    }
856
857    #[tokio::test]
858    async fn test_get_stellar_address() {
859        let (mock_client, _) = setup_mock_kms_client();
860        let kms = AwsKmsService::new_for_testing(
861            mock_client,
862            AwsKmsSignerConfig {
863                region: Some("us-east-1".to_string()),
864                key_id: "test-key-id".to_string(),
865            },
866        );
867
868        let result = kms.get_stellar_address().await;
869        assert!(result.is_ok());
870        if let Ok(Address::Stellar(stellar_address)) = result {
871            // Stellar addresses start with 'G' for public accounts
872            assert!(stellar_address.starts_with('G'));
873            // Stellar addresses are 56 characters long
874            assert_eq!(stellar_address.len(), 56);
875        } else {
876            panic!("Expected Stellar address");
877        }
878    }
879
880    #[tokio::test]
881    async fn test_get_stellar_address_fail() {
882        let (mock_client, _) = setup_mock_kms_client();
883        let kms = AwsKmsService::new_for_testing(
884            mock_client,
885            AwsKmsSignerConfig {
886                region: Some("us-east-1".to_string()),
887                key_id: "invalid-key-id".to_string(),
888            },
889        );
890
891        let result = kms.get_stellar_address().await;
892        assert!(result.is_err());
893        if let Err(err) = result {
894            assert!(matches!(err, AwsKmsError::GetError(_)))
895        }
896    }
897
898    #[tokio::test]
899    async fn test_sign_stellar() {
900        let (mock_client, _) = setup_mock_kms_client();
901        let kms = AwsKmsService::new_for_testing(
902            mock_client,
903            AwsKmsSignerConfig {
904                region: Some("us-east-1".to_string()),
905                key_id: "test-key-id".to_string(),
906            },
907        );
908
909        let message = b"Test Stellar message";
910        let result = kms.sign_stellar(message).await;
911        assert!(result.is_ok());
912        let signature = result.unwrap();
913        assert_eq!(signature.len(), 64); // Ed25519 signatures are 64 bytes
914    }
915
916    #[tokio::test]
917    async fn test_sign_stellar_fail() {
918        let (mock_client, _) = setup_mock_kms_client();
919        let kms = AwsKmsService::new_for_testing(
920            mock_client,
921            AwsKmsSignerConfig {
922                region: Some("us-east-1".to_string()),
923                key_id: "invalid-key-id".to_string(),
924            },
925        );
926
927        let message = b"Test Stellar message";
928        let result = kms.sign_stellar(message).await;
929        assert!(result.is_err());
930        if let Err(err) = result {
931            assert!(matches!(err, AwsKmsError::SignError(_)))
932        }
933    }
934
935    // Note: Ed25519 DER parsing tests are in utils/ed25519.rs
936
937    #[tokio::test]
938    async fn test_kms_client_cache_same_region_shares_client() {
939        let config1 = AwsKmsSignerConfig {
940            region: Some("us-west-2".to_string()),
941            key_id: "key-aaa".to_string(),
942        };
943        let config2 = AwsKmsSignerConfig {
944            region: Some("us-west-2".to_string()),
945            key_id: "key-bbb".to_string(),
946        };
947
948        let result1 = get_or_create_kms_client(&config1).await;
949        let result2 = get_or_create_kms_client(&config2).await;
950
951        match (result1, result2) {
952            (Ok(client1), Ok(client2)) => {
953                assert!(Arc::ptr_eq(&client1, &client2));
954            }
955            (Err(AwsKmsError::ConfigError(msg)), _) | (_, Err(AwsKmsError::ConfigError(msg))) => {
956                // In environments without TLS roots, the panic is caught as ConfigError
957                assert!(
958                    msg.contains("TLS root certificates"),
959                    "Expected TLS-related config error, got: {msg}"
960                );
961            }
962            (Err(e), _) | (_, Err(e)) => {
963                panic!("Unexpected error: {e:?}");
964            }
965        }
966    }
967
968    #[tokio::test]
969    async fn test_kms_client_returns_config_error_when_region_missing() {
970        let config = AwsKmsSignerConfig {
971            region: None,
972            key_id: "test-key".to_string(),
973        };
974
975        // Covers the missing-region branch in resolve_aws_region().
976        // Does not exercise Client::new() panic handling (that requires TLS root absence).
977        let result = get_or_create_kms_client(&config).await;
978        match result {
979            Err(AwsKmsError::ConfigError(_)) => {}
980            Ok(_) => panic!(
981                "Expected missing-region error; AWS_REGION/AWS_DEFAULT_REGION may be set in env"
982            ),
983            Err(e) => panic!("Expected ConfigError, got: {e:?}"),
984        }
985    }
986}