diff --git a/dist/efs-utils.conf b/dist/efs-utils.conf index 36e0f66e..6c297c10 100644 --- a/dist/efs-utils.conf +++ b/dist/efs-utils.conf @@ -45,6 +45,11 @@ fall_back_to_mount_target_ip_address_enabled = true # By default, we use IMDSv2 to get the instance metadata, set this to true if you want to disable IMDSv2 usage disable_fetch_ec2_metadata_token = false +# Set to true to force IPv4 when resolving EFS DNS names and connecting to mount targets. +# Enable this on dual-stack or IPv6-capable hosts where you want to ensure the NFS connection +# uses IPv4. When false (default), the system resolver picks the address family. +prefer_ipv4 = false + # By default, we enable efs-utils to retry failed mount.nfs command that due to (1) connection reset by peer (2) the # mount.nfs is not finished within 'retry_nfs_mount_command_timeout_sec'. If the retry count is set as N, initial N - 1 # mount attempts will timeout if the command does not finish within 'retry_nfs_mount_command_timeout_sec' sec. diff --git a/dist/s3files-utils.conf b/dist/s3files-utils.conf index feb60d5f..b518c73d 100644 --- a/dist/s3files-utils.conf +++ b/dist/s3files-utils.conf @@ -44,6 +44,11 @@ fall_back_to_mount_target_ip_address_enabled = true # By default, we use IMDSv2 to get the instance metadata, set this to true if you want to disable IMDSv2 usage disable_fetch_ec2_metadata_token = false +# Set to true to force IPv4 when resolving S3 Files DNS names and connecting to mount targets. +# Enable this on dual-stack or IPv6-capable hosts where you want to ensure the NFS connection +# uses IPv4. When false (default), the system resolver picks the address family. +prefer_ipv4 = false + # By default, we enable efs-utils to retry failed mount.nfs command that due to (1) connection reset by peer (2) the # mount.nfs is not finished within 'retry_nfs_mount_command_timeout_sec'. If the retry count is set as N, initial N - 1 # mount attempts will timeout if the command does not finish within 'retry_nfs_mount_command_timeout_sec' sec. diff --git a/src/efs_utils_common/network_utils.py b/src/efs_utils_common/network_utils.py index 40f08e35..764f53d0 100644 --- a/src/efs_utils_common/network_utils.py +++ b/src/efs_utils_common/network_utils.py @@ -78,9 +78,10 @@ def get_ipv6_addresses(hostname): return [] -def dns_name_can_be_resolved(dns_name): +def dns_name_can_be_resolved(dns_name, prefer_ipv4=False): try: - addr_info = socket.getaddrinfo(dns_name, None, socket.AF_UNSPEC) + family = socket.AF_INET if prefer_ipv4 else socket.AF_UNSPEC + addr_info = socket.getaddrinfo(dns_name, None, family) return len(addr_info) > 0 except socket.gaierror: return False diff --git a/src/efs_utils_common/proxy.py b/src/efs_utils_common/proxy.py index dd38edb3..abb41058 100644 --- a/src/efs_utils_common/proxy.py +++ b/src/efs_utils_common/proxy.py @@ -424,6 +424,10 @@ def write_stunnel_config_file( efs_config["fs_id"] = fs_id efs_config["region"] = region efs_config["efs_utils_version"] = VERSION + if get_boolean_config_item_value( + config, CONFIG_SECTION, "prefer_ipv4", default_value=False + ): + efs_config["prefer_ipv4"] = "true" stunnel_config = "\n".join( serialize_stunnel_config(global_config) diff --git a/src/mount_efs/dns_resolver.py b/src/mount_efs/dns_resolver.py index 6e9ec50b..432de2d9 100644 --- a/src/mount_efs/dns_resolver.py +++ b/src/mount_efs/dns_resolver.py @@ -124,7 +124,9 @@ def _validate_replacement_field_count(format_str, expected_ct): ip_address=ip_address, fallback_message=fallback_message ) - if dns_name_can_be_resolved(dns_name): + if dns_name_can_be_resolved(dns_name, get_boolean_config_item_value( + config, CONFIG_SECTION, "prefer_ipv4", default_value=False + )): return dns_name, None logging.info( @@ -204,6 +206,12 @@ def get_fallback_mount_target_ip_address_helper(config, options, fs_id): if "IpAddress" in mount_target: return mount_target.get("IpAddress") elif "Ipv6Address" in mount_target: + if get_boolean_config_item_value( + config, CONFIG_SECTION, "prefer_ipv4", default_value=False + ): + raise FallbackException( + "Mount target has only an IPv6 address but prefer_ipv4 is enabled." + ) return mount_target.get("Ipv6Address") @@ -256,8 +264,12 @@ def match_device(config, device, options): return remote, path, None try: + prefer_ipv4 = get_boolean_config_item_value( + config, CONFIG_SECTION, "prefer_ipv4", default_value=False + ) addrinfo = socket.getaddrinfo( - remote, None, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_CANONNAME + remote, None, socket.AF_INET if prefer_ipv4 else socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_CANONNAME ) hostnames = list( set( diff --git a/src/mount_s3files/dns_resolver.py b/src/mount_s3files/dns_resolver.py index bd74ede4..7fb3dbe1 100644 --- a/src/mount_s3files/dns_resolver.py +++ b/src/mount_s3files/dns_resolver.py @@ -10,6 +10,7 @@ import socket from efs_utils_common.cloudwatch import create_default_cloudwatchlog_agent_if_not_exist +from efs_utils_common.config_utils import get_boolean_config_item_value from efs_utils_common.constants import CONFIG_SECTION, FS_ID_REGEX_PATTERN from efs_utils_common.context import MountContext from efs_utils_common.error_reporting import fatal_error @@ -90,7 +91,9 @@ def _validate_replacement_field_count(format_str, expected_ct): ip_address=ip_address, fallback_message=fallback_message ) - if dns_name_can_be_resolved(dns_name): + if dns_name_can_be_resolved(dns_name, get_boolean_config_item_value( + config, CONFIG_SECTION, "prefer_ipv4", default_value=False + )): return dns_name, None logging.info( diff --git a/src/proxy/src/config_parser.rs b/src/proxy/src/config_parser.rs index 951344c7..5c6d62da 100644 --- a/src/proxy/src/config_parser.rs +++ b/src/proxy/src/config_parser.rs @@ -188,6 +188,10 @@ pub struct EfsConfig { /// efs-utils version string for channel init #[serde(alias = "efs_utils_version", default)] pub efs_utils_version: String, + + /// When true, only connect to IPv4 addresses when resolving the mount target hostname + #[serde(alias = "prefer_ipv4", deserialize_with = "deserialize_bool", default)] + pub prefer_ipv4: bool, } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct ReadBypassConfig { @@ -487,6 +491,7 @@ checkHost = fs-12341234.efs.us-east-1.amazonaws.com proxy_logging_max_bytes: DEFAULT_PROXY_LOGGING_MAX_BYTES(), proxy_logging_file_count: DEFAULT_PROXY_LOGGING_FILE_COUNT(), efs_utils_version: String::new(), + prefer_ipv4: false, }, }; @@ -584,6 +589,7 @@ jwt_path = baz proxy_logging_max_bytes: DEFAULT_PROXY_LOGGING_MAX_BYTES(), proxy_logging_file_count: DEFAULT_PROXY_LOGGING_FILE_COUNT(), efs_utils_version: String::new(), + prefer_ipv4: false, }, }; @@ -689,6 +695,7 @@ readahead_max_window_size_bytes = {test_value} proxy_logging_max_bytes: DEFAULT_PROXY_LOGGING_MAX_BYTES(), proxy_logging_file_count: DEFAULT_PROXY_LOGGING_FILE_COUNT(), efs_utils_version: String::new(), + prefer_ipv4: false, }, }; diff --git a/src/proxy/src/connections.rs b/src/proxy/src/connections.rs index d6db5dd9..2b845daf 100644 --- a/src/proxy/src/connections.rs +++ b/src/proxy/src/connections.rs @@ -10,8 +10,10 @@ use async_trait::async_trait; use futures::future::{self, BoxFuture}; use log::{debug, info, warn}; use s2n_tls_tokio::TlsStream; +use std::net::SocketAddr; use std::sync::Arc; use std::{collections::HashMap, time::Duration}; +use tokio::net::lookup_host; use tokio::task::JoinHandle; use tokio::time::{timeout_at, Instant}; use tokio::{ @@ -334,14 +336,21 @@ pub fn get_bind_response_string(bind_response: &BindResponse) -> String { #[derive(Clone)] pub struct PlainTextPartitionFinder { pub mount_target_addr: String, + pub prefer_ipv4: bool, } #[async_trait] impl PartitionFinder for PlainTextPartitionFinder { async fn create_connect_future(&self) -> BoxFuture<'static, Result> { let mount_target_address = self.mount_target_addr.clone(); + let prefer_ipv4 = self.prefer_ipv4; Box::pin(async move { - match TcpStream::connect(mount_target_address).await { + let stream = if prefer_ipv4 { + TcpStream::connect(resolve_addr(&mount_target_address).await?).await + } else { + TcpStream::connect(mount_target_address).await + }; + match stream { Ok(tcp_stream) => Ok(configure_stream(tcp_stream)), Err(e) => Err(ConnectError::IoError(e)), } @@ -351,11 +360,15 @@ impl PartitionFinder for PlainTextPartitionFinder { pub struct TlsPartitionFinder { tls_config: Arc>, + pub prefer_ipv4: bool, } impl TlsPartitionFinder { - pub fn new(tls_config: Arc>) -> Self { - TlsPartitionFinder { tls_config } + pub fn new(tls_config: Arc>, prefer_ipv4: bool) -> Self { + TlsPartitionFinder { + tls_config, + prefer_ipv4, + } } } @@ -364,11 +377,28 @@ impl PartitionFinder> for TlsPartitionFinder { async fn create_connect_future( &self, ) -> BoxFuture<'static, Result, ConnectError>> { - let tls_config_copy = self.tls_config.lock().await.clone(); - Box::pin(establish_tls_stream(tls_config_copy)) + let mut tls_config_copy = self.tls_config.lock().await.clone(); + let prefer_ipv4 = self.prefer_ipv4; + Box::pin(async move { + if prefer_ipv4 { + let addr = resolve_addr(&tls_config_copy.remote_addr).await?; + tls_config_copy.remote_addr = addr.to_string(); + } + establish_tls_stream(tls_config_copy).await + }) } } +/// Resolve `addr` (host:port) to the first IPv4 `SocketAddr`. +/// Returns `ConnectError::IoError` if no IPv4 address is found. +async fn resolve_addr(addr: &str) -> Result { + lookup_host(addr) + .await + .map_err(ConnectError::IoError)? + .find(|a| a.is_ipv4()) + .ok_or_else(|| ConnectError::IoError(std::io::Error::other("no IPv4 address found"))) +} + #[cfg(test)] mod tests { use super::*; @@ -399,6 +429,7 @@ mod tests { let error = tokio::spawn(async move { let partition_finder = PlainTextPartitionFinder { mount_target_addr: format!("127.0.0.1:{}", port.clone()), + prefer_ipv4: false, }; partition_finder .establish_connection(create_deadline(test_single_connection_timeout), PROXY_ID) @@ -421,6 +452,7 @@ mod tests { let partition_finder = PlainTextPartitionFinder { mount_target_addr: format!("127.0.0.1:{}", port.clone()), + prefer_ipv4: false, }; partition_finder .inner_establish_multiplex_connection( @@ -453,6 +485,7 @@ mod tests { let task = tokio::spawn(async move { let partition_finder = PlainTextPartitionFinder { mount_target_addr: format!("127.0.0.1:{}", port.clone()), + prefer_ipv4: false, }; partition_finder .inner_establish_multiplex_connection( @@ -565,6 +598,7 @@ mod tests { }); let tls_partition_finder = TlsPartitionFinder { tls_config: tls_config_ptr.clone(), + prefer_ipv4: false, }; let _ = kill(nix::unistd::Pid::this(), Signal::SIGHUP); rx.await.unwrap(); @@ -573,4 +607,39 @@ mod tests { tls_partition_finder.tls_config.lock().await.client_cert ); } + + #[tokio::test] + async fn test_resolve_addr_returns_ipv4() { + let result = resolve_addr("127.0.0.1:2049").await; + assert!(result.is_ok(), "expected Ok, got {:?}", result); + assert!(result.unwrap().is_ipv4()); + } + + #[tokio::test] + async fn test_resolve_addr_no_ipv4_returns_error() { + let result = resolve_addr("[::1]:2049").await; + assert!( + matches!(result, Err(ConnectError::IoError(_))), + "expected IoError for IPv6-only address, got {:?}", + result + ); + } + + #[tokio::test] + async fn test_plain_text_finder_prefer_ipv4_connects() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + tokio::spawn(async move { listener.accept().await }); + + let finder = PlainTextPartitionFinder { + mount_target_addr: format!("127.0.0.1:{}", port), + prefer_ipv4: true, + }; + let stream = finder.create_connect_future().await.await; + assert!( + stream.is_ok(), + "expected connection to succeed, got {:?}", + stream + ); + } } diff --git a/src/proxy/src/main.rs b/src/proxy/src/main.rs index d1432c6a..338e2036 100644 --- a/src/proxy/src/main.rs +++ b/src/proxy/src/main.rs @@ -96,7 +96,10 @@ async fn main() { let controller = Controller::new( &proxy_config.nested_config.listen_addr, proxy_config.clone(), - Arc::new(TlsPartitionFinder::new(tls_config)), + Arc::new(TlsPartitionFinder::new( + tls_config, + proxy_config.nested_config.prefer_ipv4, + )), status_reporter, cw_publisher.clone(), ) @@ -112,6 +115,7 @@ async fn main() { proxy_config.clone(), Arc::new(PlainTextPartitionFinder { mount_target_addr: proxy_config.nested_config.mount_target_addr.clone(), + prefer_ipv4: proxy_config.nested_config.prefer_ipv4, }), status_reporter, cw_publisher.clone(), diff --git a/test/mount_common_test/test_get_fallback_mount_target_ip_address.py b/test/mount_common_test/test_get_fallback_mount_target_ip_address.py index f1a0bd18..2dbe19a2 100644 --- a/test/mount_common_test/test_get_fallback_mount_target_ip_address.py +++ b/test/mount_common_test/test_get_fallback_mount_target_ip_address.py @@ -367,3 +367,24 @@ def test_get_fallback_mount_target_ip_address_helper_prefer_ipv4(mocker): ) assert ip_address == ipv4_address + + +def test_get_fallback_mount_target_ip_address_helper_ipv6_only_prefer_ipv4(mocker): + config = _get_mock_config() + config.set(efs_utils_common.constants.CONFIG_SECTION, "prefer_ipv4", "true") + + mocker.patch( + "mount_efs.dns_resolver.get_botocore_client", + side_effect=[MOCK_EFS_AGENT, MOCK_EC2_AGENT], + ) + + ipv6_address = "2001:db8:3333:4444:5555:6666:7777:8888" + mocker.patch( + "mount_efs.dns_resolver.get_mount_target_in_az", + return_value={"Ipv6Address": ipv6_address}, + ) + + with pytest.raises(efs_utils_common.exceptions.FallbackException) as excinfo: + dns_resolver.get_fallback_mount_target_ip_address_helper(config, {}, FS_ID) + + assert "prefer_ipv4" in str(excinfo.value)