Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dist/efs-utils.conf
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ fall_back_to_mount_target_ip_address_enabled = true
# By default, we use IMDSv2 to get the instance metadata, set this to true if you want to disable IMDSv2 usage
disable_fetch_ec2_metadata_token = false

# Set to true to force IPv4 when resolving EFS DNS names and connecting to mount targets.
# Enable this on dual-stack or IPv6-capable hosts where you want to ensure the NFS connection
# uses IPv4. When false (default), the system resolver picks the address family.
prefer_ipv4 = false

# By default, we enable efs-utils to retry failed mount.nfs command that due to (1) connection reset by peer (2) the
# mount.nfs is not finished within 'retry_nfs_mount_command_timeout_sec'. If the retry count is set as N, initial N - 1
# mount attempts will timeout if the command does not finish within 'retry_nfs_mount_command_timeout_sec' sec.
Expand Down
5 changes: 5 additions & 0 deletions dist/s3files-utils.conf
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ fall_back_to_mount_target_ip_address_enabled = true
# By default, we use IMDSv2 to get the instance metadata, set this to true if you want to disable IMDSv2 usage
disable_fetch_ec2_metadata_token = false

# Set to true to force IPv4 when resolving S3 Files DNS names and connecting to mount targets.
# Enable this on dual-stack or IPv6-capable hosts where you want to ensure the NFS connection
# uses IPv4. When false (default), the system resolver picks the address family.
prefer_ipv4 = false

# By default, we enable efs-utils to retry failed mount.nfs command that due to (1) connection reset by peer (2) the
# mount.nfs is not finished within 'retry_nfs_mount_command_timeout_sec'. If the retry count is set as N, initial N - 1
# mount attempts will timeout if the command does not finish within 'retry_nfs_mount_command_timeout_sec' sec.
Expand Down
5 changes: 3 additions & 2 deletions src/efs_utils_common/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,10 @@ def get_ipv6_addresses(hostname):
return []


def dns_name_can_be_resolved(dns_name):
def dns_name_can_be_resolved(dns_name, prefer_ipv4=False):
try:
addr_info = socket.getaddrinfo(dns_name, None, socket.AF_UNSPEC)
family = socket.AF_INET if prefer_ipv4 else socket.AF_UNSPEC
addr_info = socket.getaddrinfo(dns_name, None, family)
return len(addr_info) > 0
except socket.gaierror:
return False
Expand Down
4 changes: 4 additions & 0 deletions src/efs_utils_common/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ def write_stunnel_config_file(
efs_config["fs_id"] = fs_id
efs_config["region"] = region
efs_config["efs_utils_version"] = VERSION
if get_boolean_config_item_value(
config, CONFIG_SECTION, "prefer_ipv4", default_value=False
):
efs_config["prefer_ipv4"] = "true"

stunnel_config = "\n".join(
serialize_stunnel_config(global_config)
Expand Down
16 changes: 14 additions & 2 deletions src/mount_efs/dns_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def _validate_replacement_field_count(format_str, expected_ct):
ip_address=ip_address, fallback_message=fallback_message
)

if dns_name_can_be_resolved(dns_name):
if dns_name_can_be_resolved(dns_name, get_boolean_config_item_value(
config, CONFIG_SECTION, "prefer_ipv4", default_value=False
)):
return dns_name, None

logging.info(
Expand Down Expand Up @@ -204,6 +206,12 @@ def get_fallback_mount_target_ip_address_helper(config, options, fs_id):
if "IpAddress" in mount_target:
return mount_target.get("IpAddress")
elif "Ipv6Address" in mount_target:
if get_boolean_config_item_value(
config, CONFIG_SECTION, "prefer_ipv4", default_value=False
):
raise FallbackException(
"Mount target has only an IPv6 address but prefer_ipv4 is enabled."
)
return mount_target.get("Ipv6Address")


Expand Down Expand Up @@ -256,8 +264,12 @@ def match_device(config, device, options):
return remote, path, None

try:
prefer_ipv4 = get_boolean_config_item_value(
config, CONFIG_SECTION, "prefer_ipv4", default_value=False
)
addrinfo = socket.getaddrinfo(
remote, None, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_CANONNAME
remote, None, socket.AF_INET if prefer_ipv4 else socket.AF_UNSPEC,
socket.SOCK_STREAM, 0, socket.AI_CANONNAME
)
hostnames = list(
set(
Expand Down
5 changes: 4 additions & 1 deletion src/mount_s3files/dns_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import socket

from efs_utils_common.cloudwatch import create_default_cloudwatchlog_agent_if_not_exist
from efs_utils_common.config_utils import get_boolean_config_item_value
from efs_utils_common.constants import CONFIG_SECTION, FS_ID_REGEX_PATTERN
from efs_utils_common.context import MountContext
from efs_utils_common.error_reporting import fatal_error
Expand Down Expand Up @@ -90,7 +91,9 @@ def _validate_replacement_field_count(format_str, expected_ct):
ip_address=ip_address, fallback_message=fallback_message
)

if dns_name_can_be_resolved(dns_name):
if dns_name_can_be_resolved(dns_name, get_boolean_config_item_value(
config, CONFIG_SECTION, "prefer_ipv4", default_value=False
)):
return dns_name, None

logging.info(
Expand Down
7 changes: 7 additions & 0 deletions src/proxy/src/config_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ pub struct EfsConfig {
/// efs-utils version string for channel init
#[serde(alias = "efs_utils_version", default)]
pub efs_utils_version: String,

/// When true, only connect to IPv4 addresses when resolving the mount target hostname
#[serde(alias = "prefer_ipv4", deserialize_with = "deserialize_bool", default)]
pub prefer_ipv4: bool,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ReadBypassConfig {
Expand Down Expand Up @@ -487,6 +491,7 @@ checkHost = fs-12341234.efs.us-east-1.amazonaws.com
proxy_logging_max_bytes: DEFAULT_PROXY_LOGGING_MAX_BYTES(),
proxy_logging_file_count: DEFAULT_PROXY_LOGGING_FILE_COUNT(),
efs_utils_version: String::new(),
prefer_ipv4: false,
},
};

Expand Down Expand Up @@ -584,6 +589,7 @@ jwt_path = baz
proxy_logging_max_bytes: DEFAULT_PROXY_LOGGING_MAX_BYTES(),
proxy_logging_file_count: DEFAULT_PROXY_LOGGING_FILE_COUNT(),
efs_utils_version: String::new(),
prefer_ipv4: false,
},
};

Expand Down Expand Up @@ -689,6 +695,7 @@ readahead_max_window_size_bytes = {test_value}
proxy_logging_max_bytes: DEFAULT_PROXY_LOGGING_MAX_BYTES(),
proxy_logging_file_count: DEFAULT_PROXY_LOGGING_FILE_COUNT(),
efs_utils_version: String::new(),
prefer_ipv4: false,
},
};

Expand Down
79 changes: 74 additions & 5 deletions src/proxy/src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ use async_trait::async_trait;
use futures::future::{self, BoxFuture};
use log::{debug, info, warn};
use s2n_tls_tokio::TlsStream;
use std::net::SocketAddr;
use std::sync::Arc;
use std::{collections::HashMap, time::Duration};
use tokio::net::lookup_host;
use tokio::task::JoinHandle;
use tokio::time::{timeout_at, Instant};
use tokio::{
Expand Down Expand Up @@ -334,14 +336,21 @@ pub fn get_bind_response_string(bind_response: &BindResponse) -> String {
#[derive(Clone)]
pub struct PlainTextPartitionFinder {
pub mount_target_addr: String,
pub prefer_ipv4: bool,
}

#[async_trait]
impl PartitionFinder<TcpStream> for PlainTextPartitionFinder {
async fn create_connect_future(&self) -> BoxFuture<'static, Result<TcpStream, ConnectError>> {
let mount_target_address = self.mount_target_addr.clone();
let prefer_ipv4 = self.prefer_ipv4;
Box::pin(async move {
match TcpStream::connect(mount_target_address).await {
let stream = if prefer_ipv4 {
TcpStream::connect(resolve_addr(&mount_target_address).await?).await
} else {
TcpStream::connect(mount_target_address).await
};
match stream {
Ok(tcp_stream) => Ok(configure_stream(tcp_stream)),
Err(e) => Err(ConnectError::IoError(e)),
}
Expand All @@ -351,11 +360,15 @@ impl PartitionFinder<TcpStream> for PlainTextPartitionFinder {

pub struct TlsPartitionFinder {
tls_config: Arc<tokio::sync::Mutex<TlsConfig>>,
pub prefer_ipv4: bool,
}

impl TlsPartitionFinder {
pub fn new(tls_config: Arc<tokio::sync::Mutex<TlsConfig>>) -> Self {
TlsPartitionFinder { tls_config }
pub fn new(tls_config: Arc<tokio::sync::Mutex<TlsConfig>>, prefer_ipv4: bool) -> Self {
TlsPartitionFinder {
tls_config,
prefer_ipv4,
}
}
}

Expand All @@ -364,11 +377,28 @@ impl PartitionFinder<TlsStream<TcpStream>> for TlsPartitionFinder {
async fn create_connect_future(
&self,
) -> BoxFuture<'static, Result<TlsStream<TcpStream>, ConnectError>> {
let tls_config_copy = self.tls_config.lock().await.clone();
Box::pin(establish_tls_stream(tls_config_copy))
let mut tls_config_copy = self.tls_config.lock().await.clone();
let prefer_ipv4 = self.prefer_ipv4;
Box::pin(async move {
if prefer_ipv4 {
let addr = resolve_addr(&tls_config_copy.remote_addr).await?;
tls_config_copy.remote_addr = addr.to_string();
}
establish_tls_stream(tls_config_copy).await
})
}
}

/// Resolve `addr` (host:port) to the first IPv4 `SocketAddr`.
/// Returns `ConnectError::IoError` if no IPv4 address is found.
async fn resolve_addr(addr: &str) -> Result<SocketAddr, ConnectError> {
lookup_host(addr)
.await
.map_err(ConnectError::IoError)?
.find(|a| a.is_ipv4())
.ok_or_else(|| ConnectError::IoError(std::io::Error::other("no IPv4 address found")))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -399,6 +429,7 @@ mod tests {
let error = tokio::spawn(async move {
let partition_finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port.clone()),
prefer_ipv4: false,
};
partition_finder
.establish_connection(create_deadline(test_single_connection_timeout), PROXY_ID)
Expand All @@ -421,6 +452,7 @@ mod tests {

let partition_finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port.clone()),
prefer_ipv4: false,
};
partition_finder
.inner_establish_multiplex_connection(
Expand Down Expand Up @@ -453,6 +485,7 @@ mod tests {
let task = tokio::spawn(async move {
let partition_finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port.clone()),
prefer_ipv4: false,
};
partition_finder
.inner_establish_multiplex_connection(
Expand Down Expand Up @@ -565,6 +598,7 @@ mod tests {
});
let tls_partition_finder = TlsPartitionFinder {
tls_config: tls_config_ptr.clone(),
prefer_ipv4: false,
};
let _ = kill(nix::unistd::Pid::this(), Signal::SIGHUP);
rx.await.unwrap();
Expand All @@ -573,4 +607,39 @@ mod tests {
tls_partition_finder.tls_config.lock().await.client_cert
);
}

#[tokio::test]
async fn test_resolve_addr_returns_ipv4() {
let result = resolve_addr("127.0.0.1:2049").await;
assert!(result.is_ok(), "expected Ok, got {:?}", result);
assert!(result.unwrap().is_ipv4());
}

#[tokio::test]
async fn test_resolve_addr_no_ipv4_returns_error() {
let result = resolve_addr("[::1]:2049").await;
assert!(
matches!(result, Err(ConnectError::IoError(_))),
"expected IoError for IPv6-only address, got {:?}",
result
);
}

#[tokio::test]
async fn test_plain_text_finder_prefer_ipv4_connects() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
tokio::spawn(async move { listener.accept().await });

let finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port),
prefer_ipv4: true,
};
let stream = finder.create_connect_future().await.await;
assert!(
stream.is_ok(),
"expected connection to succeed, got {:?}",
stream
);
}
}
6 changes: 5 additions & 1 deletion src/proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ async fn main() {
let controller = Controller::new(
&proxy_config.nested_config.listen_addr,
proxy_config.clone(),
Arc::new(TlsPartitionFinder::new(tls_config)),
Arc::new(TlsPartitionFinder::new(
tls_config,
proxy_config.nested_config.prefer_ipv4,
)),
status_reporter,
cw_publisher.clone(),
)
Expand All @@ -112,6 +115,7 @@ async fn main() {
proxy_config.clone(),
Arc::new(PlainTextPartitionFinder {
mount_target_addr: proxy_config.nested_config.mount_target_addr.clone(),
prefer_ipv4: proxy_config.nested_config.prefer_ipv4,
}),
status_reporter,
cw_publisher.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,3 +367,24 @@ def test_get_fallback_mount_target_ip_address_helper_prefer_ipv4(mocker):
)

assert ip_address == ipv4_address


def test_get_fallback_mount_target_ip_address_helper_ipv6_only_prefer_ipv4(mocker):
config = _get_mock_config()
config.set(efs_utils_common.constants.CONFIG_SECTION, "prefer_ipv4", "true")

mocker.patch(
"mount_efs.dns_resolver.get_botocore_client",
side_effect=[MOCK_EFS_AGENT, MOCK_EC2_AGENT],
)

ipv6_address = "2001:db8:3333:4444:5555:6666:7777:8888"
mocker.patch(
"mount_efs.dns_resolver.get_mount_target_in_az",
return_value={"Ipv6Address": ipv6_address},
)

with pytest.raises(efs_utils_common.exceptions.FallbackException) as excinfo:
dns_resolver.get_fallback_mount_target_ip_address_helper(config, {}, FS_ID)

assert "prefer_ipv4" in str(excinfo.value)
Loading