diff --git a/dist/efs-utils.conf b/dist/efs-utils.conf index 36e0f66e..a62aa315 100644 --- a/dist/efs-utils.conf +++ b/dist/efs-utils.conf @@ -42,6 +42,13 @@ optimize_readahead = true # By default, we enable the feature to fallback to mount with mount target ip address when dns name cannot be resolved fall_back_to_mount_target_ip_address_enabled = true +# When enabled, efs-utils calls DescribeMountTargets to determine the mount target address family (IPv4/IPv6) +# and constrains DNS resolution and efs-proxy connections to that family. Disable to skip the API call +# and let the OS resolver decide (equivalent to the pre-3.x behavior). +# Requires IAM actions: elasticfilesystem:DescribeMountTargets, ec2:DescribeAvailabilityZones. +# If those actions are unavailable, efs-utils falls back to AF_UNSPEC automatically. +dynamic_address_family_enabled = true + # By default, we use IMDSv2 to get the instance metadata, set this to true if you want to disable IMDSv2 usage disable_fetch_ec2_metadata_token = false diff --git a/src/efs_utils_common/mount_utils.py b/src/efs_utils_common/mount_utils.py index 2356e46a..bc091ac1 100644 --- a/src/efs_utils_common/mount_utils.py +++ b/src/efs_utils_common/mount_utils.py @@ -266,6 +266,7 @@ def mount_with_proxy( mountpoint, options, fallback_ip_address=None, + address_family=None, ): """ This function is responsible for launching a efs-proxy process and attaching a NFS mount to that process @@ -291,6 +292,7 @@ def mount_with_proxy( options, fallback_ip_address=fallback_ip_address, efs_proxy_enabled=efs_proxy_enabled, + address_family=address_family, ) as tunnel_proc: mount_completed = threading.Event() t = threading.Thread( diff --git a/src/efs_utils_common/network_utils.py b/src/efs_utils_common/network_utils.py index 40f08e35..4ff02780 100644 --- a/src/efs_utils_common/network_utils.py +++ b/src/efs_utils_common/network_utils.py @@ -78,9 +78,9 @@ def get_ipv6_addresses(hostname): return [] -def dns_name_can_be_resolved(dns_name): 
def dns_name_can_be_resolved(dns_name, family=socket.AF_UNSPEC):
    """
    Report whether dns_name resolves to at least one address of the given family.

    :param dns_name: hostname (or numeric address) to look up
    :param family: address family handed to socket.getaddrinfo - socket.AF_INET, socket.AF_INET6 or
        socket.AF_UNSPEC (default, accepts either). None is treated as socket.AF_UNSPEC.
    :return: True if the lookup returns one or more entries, False if the name cannot be resolved
        for that family or is not a well-formed hostname
    """
    if family is None:
        # Other helpers in this code base use None for "no preference"; getaddrinfo itself
        # rejects None with a TypeError, so normalise it here.
        family = socket.AF_UNSPEC

    try:
        addr_info = socket.getaddrinfo(dns_name, None, family)
    except socket.gaierror:
        return False
    except UnicodeError:
        # getaddrinfo IDNA-encodes str hostnames before doing the lookup. A name with an empty
        # label or a label longer than 63 characters fails in that step with UnicodeError rather
        # than gaierror; such a name can never resolve, so report it as unresolvable.
        return False

    return len(addr_info) > 0
def get_mount_target_address_family(config, options, fs_id):
    """
    Return the socket address family that matches the mount target's IP address type.

    :param config: efs-utils configuration; the boolean item "dynamic_address_family_enabled"
        (default True) switches this lookup on or off
    :param options: mount options, forwarded to the botocore client factory and the AZ lookup
    :param fs_id: id of the file system whose mount target is inspected
    :return: socket.AF_INET when the mount target reports an IPv4 address (also for dual-stack),
        socket.AF_INET6 when it reports only an IPv6 address, and socket.AF_UNSPEC when the feature
        is disabled, a client cannot be created, the lookup raises, or the mount target reports
        no address at all
    """
    if not get_boolean_config_item_value(
        config, CONFIG_SECTION, "dynamic_address_family_enabled", default_value=True
    ):
        return socket.AF_UNSPEC

    # Deliberately best-effort: any failure while talking to the APIs must not block the mount,
    # it only means we let the resolver pick the family.
    try:
        efs_client = get_botocore_client(config, "efs", options)
        if efs_client is None:
            return socket.AF_UNSPEC

        az_name = get_target_az(config, options)
        ec2_client = get_botocore_client(config, "ec2", options)
        if ec2_client is None:
            return socket.AF_UNSPEC

        mount_target = get_mount_target_in_az(efs_client, ec2_client, fs_id, az_name)

        # Truthiness (not key membership) so that an empty or null field counts as "absent".
        if mount_target.get("IpAddress"):
            # IPv4-only and dual-stack mount targets both use IPv4.
            return socket.AF_INET
        if mount_target.get("Ipv6Address"):
            return socket.AF_INET6

        # Neither address is present: there is no evidence for either family, so do not pin
        # resolution to IPv4 on a guess.
        logging.info("Mount target reports no IP address, defaulting to AF_UNSPEC")
        return socket.AF_UNSPEC
    except Exception:
        logging.info("Failed to determine mount target address family, defaulting to AF_UNSPEC")
        logging.debug("get_mount_target_address_family exception detail", exc_info=True)
        return socket.AF_UNSPEC
a/src/proxy/src/connections.rs b/src/proxy/src/connections.rs index d6db5dd9..9b7e14ae 100644 --- a/src/proxy/src/connections.rs +++ b/src/proxy/src/connections.rs @@ -11,12 +11,13 @@ use futures::future::{self, BoxFuture}; use log::{debug, info, warn}; use s2n_tls_tokio::TlsStream; use std::sync::Arc; -use std::{collections::HashMap, time::Duration}; +use std::{collections::HashMap, net::SocketAddr, time::Duration}; use tokio::task::JoinHandle; use tokio::time::{timeout_at, Instant}; use tokio::{ io::AsyncWriteExt, io::{AsyncRead, AsyncWrite}, + net::lookup_host, net::TcpStream, sync::mpsc, }; @@ -334,14 +335,22 @@ pub fn get_bind_response_string(bind_response: &BindResponse) -> String { #[derive(Clone)] pub struct PlainTextPartitionFinder { pub mount_target_addr: String, + pub address_family: String, } #[async_trait] impl PartitionFinder for PlainTextPartitionFinder { async fn create_connect_future(&self) -> BoxFuture<'static, Result> { let mount_target_address = self.mount_target_addr.clone(); + let address_family = self.address_family.clone(); Box::pin(async move { - match TcpStream::connect(mount_target_address).await { + let stream = if address_family == "unspec" { + TcpStream::connect(&mount_target_address).await + } else { + let addr = resolve_addr(&mount_target_address, &address_family).await?; + TcpStream::connect(addr).await + }; + match stream { Ok(tcp_stream) => Ok(configure_stream(tcp_stream)), Err(e) => Err(ConnectError::IoError(e)), } @@ -351,11 +360,15 @@ impl PartitionFinder for PlainTextPartitionFinder { pub struct TlsPartitionFinder { tls_config: Arc>, + pub address_family: String, } impl TlsPartitionFinder { - pub fn new(tls_config: Arc>) -> Self { - TlsPartitionFinder { tls_config } + pub fn new(tls_config: Arc>, address_family: String) -> Self { + TlsPartitionFinder { + tls_config, + address_family, + } } } @@ -364,11 +377,40 @@ impl PartitionFinder> for TlsPartitionFinder { async fn create_connect_future( &self, ) -> BoxFuture<'static, 
Result, ConnectError>> { - let tls_config_copy = self.tls_config.lock().await.clone(); - Box::pin(establish_tls_stream(tls_config_copy)) + let mut tls_config_copy = self.tls_config.lock().await.clone(); + let address_family = self.address_family.clone(); + Box::pin(async move { + if address_family != "unspec" { + let addr = resolve_addr(&tls_config_copy.remote_addr, &address_family).await?; + tls_config_copy.remote_addr = addr.to_string(); + } + establish_tls_stream(tls_config_copy).await + }) } } +/// Resolve `addr` (host:port) to a `SocketAddr` matching the requested address family. +/// `address_family`: `"ipv4"` selects the first IPv4 result, `"ipv6"` selects the first IPv6 result. +/// Returns an error if no address of the requested family is found. +/// Callers must not pass `"unspec"` — use `TcpStream::connect(hostname)` directly instead. +async fn resolve_addr(addr: &str, address_family: &str) -> Result { + let addrs: Vec = lookup_host(addr) + .await + .map_err(ConnectError::IoError)? 
+ .collect(); + let result = match address_family { + "ipv4" => addrs.iter().find(|a| a.is_ipv4()).copied(), + "ipv6" => addrs.iter().find(|a| a.is_ipv6()).copied(), + _ => None, + }; + result.ok_or_else(|| { + ConnectError::IoError(std::io::Error::other(format!( + "no {} address found for {}", + address_family, addr + ))) + }) +} + #[cfg(test)] mod tests { use super::*; @@ -399,6 +441,7 @@ mod tests { let error = tokio::spawn(async move { let partition_finder = PlainTextPartitionFinder { mount_target_addr: format!("127.0.0.1:{}", port.clone()), + address_family: "unspec".to_string(), }; partition_finder .establish_connection(create_deadline(test_single_connection_timeout), PROXY_ID) @@ -421,6 +464,7 @@ mod tests { let partition_finder = PlainTextPartitionFinder { mount_target_addr: format!("127.0.0.1:{}", port.clone()), + address_family: "unspec".to_string(), }; partition_finder .inner_establish_multiplex_connection( @@ -453,6 +497,7 @@ mod tests { let task = tokio::spawn(async move { let partition_finder = PlainTextPartitionFinder { mount_target_addr: format!("127.0.0.1:{}", port.clone()), + address_family: "unspec".to_string(), }; partition_finder .inner_establish_multiplex_connection( @@ -565,6 +610,7 @@ mod tests { }); let tls_partition_finder = TlsPartitionFinder { tls_config: tls_config_ptr.clone(), + address_family: "unspec".to_string(), }; let _ = kill(nix::unistd::Pid::this(), Signal::SIGHUP); rx.await.unwrap(); @@ -573,4 +619,30 @@ mod tests { tls_partition_finder.tls_config.lock().await.client_cert ); } + + #[tokio::test] + async fn test_resolve_addr_ipv4_success() { + // "127.0.0.1:2049" is a numeric address — lookup_host returns it without DNS. + let result = resolve_addr("127.0.0.1:2049", "ipv4").await; + assert!(result.is_ok()); + assert!(result.unwrap().is_ipv4()); + } + + #[tokio::test] + async fn test_resolve_addr_ipv6_success() { + // "[::1]:2049" is a numeric IPv6 address — lookup_host returns it without DNS. 
def test_get_mount_target_address_family_disabled_returns_unspec(mocker):
    """With the feature switched off, no botocore client is requested at all."""
    section = efs_utils_common.constants.CONFIG_SECTION
    config = _get_mock_config()
    config.set(section, "dynamic_address_family_enabled", "false")
    client_factory = mocker.patch("mount_efs.dns_resolver.get_botocore_client")

    family = dns_resolver.get_mount_target_address_family(config, {}, FS_ID)

    assert family == socket.AF_UNSPEC
    client_factory.assert_not_called()


def test_get_mount_target_address_family_ipv4(mocker):
    """A mount target exposing only an IPv4 address maps to AF_INET."""
    config = _get_mock_config()
    clients = [MOCK_EFS_AGENT, MOCK_EC2_AGENT]
    mocker.patch("mount_efs.dns_resolver.get_botocore_client", side_effect=clients)
    mocker.patch("mount_efs.dns_resolver.get_target_az", return_value=DEFAULT_AZ)
    mount_target = {"IpAddress": "1.2.3.4"}
    mocker.patch("mount_efs.dns_resolver.get_mount_target_in_az", return_value=mount_target)

    family = dns_resolver.get_mount_target_address_family(config, {}, FS_ID)

    assert family == socket.AF_INET


def test_get_mount_target_address_family_ipv6_only(mocker):
    """A mount target exposing only an IPv6 address maps to AF_INET6."""
    config = _get_mock_config()
    clients = [MOCK_EFS_AGENT, MOCK_EC2_AGENT]
    mocker.patch("mount_efs.dns_resolver.get_botocore_client", side_effect=clients)
    mocker.patch("mount_efs.dns_resolver.get_target_az", return_value=DEFAULT_AZ)
    mount_target = {"Ipv6Address": "2001:db8::1"}
    mocker.patch("mount_efs.dns_resolver.get_mount_target_in_az", return_value=mount_target)

    family = dns_resolver.get_mount_target_address_family(config, {}, FS_ID)

    assert family == socket.AF_INET6


def test_get_mount_target_address_family_dual_stack_returns_ipv4(mocker):
    """When both addresses are present, IPv4 wins."""
    config = _get_mock_config()
    clients = [MOCK_EFS_AGENT, MOCK_EC2_AGENT]
    mocker.patch("mount_efs.dns_resolver.get_botocore_client", side_effect=clients)
    mocker.patch("mount_efs.dns_resolver.get_target_az", return_value=DEFAULT_AZ)
    mount_target = {"IpAddress": "1.2.3.4", "Ipv6Address": "2001:db8::1"}
    mocker.patch("mount_efs.dns_resolver.get_mount_target_in_az", return_value=mount_target)

    family = dns_resolver.get_mount_target_address_family(config, {}, FS_ID)

    assert family == socket.AF_INET


def test_get_mount_target_address_family_no_efs_client_returns_unspec(mocker):
    """No EFS client available -> AF_UNSPEC."""
    config = _get_mock_config()
    mocker.patch("mount_efs.dns_resolver.get_botocore_client", side_effect=[None])

    family = dns_resolver.get_mount_target_address_family(config, {}, FS_ID)

    assert family == socket.AF_UNSPEC


def test_get_mount_target_address_family_no_ec2_client_returns_unspec(mocker):
    """EFS client available but no EC2 client -> AF_UNSPEC."""
    config = _get_mock_config()
    clients = [MOCK_EFS_AGENT, None]
    mocker.patch("mount_efs.dns_resolver.get_botocore_client", side_effect=clients)
    mocker.patch("mount_efs.dns_resolver.get_target_az", return_value=DEFAULT_AZ)

    family = dns_resolver.get_mount_target_address_family(config, {}, FS_ID)

    assert family == socket.AF_UNSPEC


def test_get_mount_target_address_family_api_error_returns_unspec(mocker):
    """An exception raised by the mount target lookup is swallowed and yields AF_UNSPEC."""
    config = _get_mock_config()
    clients = [MOCK_EFS_AGENT, MOCK_EC2_AGENT]
    mocker.patch("mount_efs.dns_resolver.get_botocore_client", side_effect=clients)
    mocker.patch("mount_efs.dns_resolver.get_target_az", return_value=DEFAULT_AZ)
    failure = Exception("API error")
    mocker.patch("mount_efs.dns_resolver.get_mount_target_in_az", side_effect=failure)

    family = dns_resolver.get_mount_target_address_family(config, {}, FS_ID)

    assert family == socket.AF_UNSPEC


def test_get_dns_name_with_address_family_passed(mocker):
    """The address_family argument reaches socket.getaddrinfo as its third positional argument."""
    section = efs_utils_common.constants.CONFIG_SECTION
    fallback_item = efs_utils_common.constants.FALLBACK_TO_MOUNT_TARGET_IP_ADDRESS_ITEM

    try:
        config = ConfigParser.SafeConfigParser()
    except AttributeError:
        config = ConfigParser()
    config.add_section(section)
    config.set(section, "dns_name_format", "{fs_id}.efs.{region}.{dns_name_suffix}")
    config.set(section, "dns_name_suffix", "amazonaws.com")
    config.set(section, fallback_item, "false")

    mocker.patch("mount_efs.dns_resolver.get_target_region", return_value=DEFAULT_REGION)
    fake_addr_info = [("", "", "", "", ("1.2.3.4", 0))]
    dns_mock = mocker.patch(
        "efs_utils_common.network_utils.socket.getaddrinfo",
        return_value=fake_addr_info,
    )

    resolved = dns_resolver.get_dns_name_and_fallback_mount_target_ip_address(
        config, FS_ID, {}, address_family=socket.AF_INET
    )
    dns_name, ip = resolved

    assert dns_name is not None
    assert ip is None
    # call_args[0] holds the positional arguments of the last getaddrinfo call.
    requested_family = dns_mock.call_args[0][2]
    assert requested_family == socket.AF_INET