From e899f116225ac510981f9409befdc60f37085d60 Mon Sep 17 00:00:00 2001 From: sayantn Date: Sun, 24 May 2026 23:50:47 +0530 Subject: [PATCH 1/3] Some documentation fixes --- crates/core_arch/src/x86_64/amx.rs | 39 ++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index b3b3e86750..31c1762ca2 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -3,12 +3,31 @@ use crate::core_arch::{simd::*, x86::*}; #[cfg(test)] use stdarch_test::assert_instr; -/// Load tile configuration from a 64-byte memory location specified by mem_addr. +/// Load tile configuration from a 64-byte memory location specified by `mem_addr`. /// The tile configuration format is specified below, and includes the tile type pallette, /// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, /// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. /// Any invalid configurations will result in #GP fault. /// +/// ```intel +/// // format of memory payload. each field is a byte. +/// 0: palette +/// 1: start_row +/// 2-15: reserved, must be zero +/// 16-17: tile0.colsb +/// 18-19: tile1.colsb +/// 20-21: tile2.colsb +/// ... +/// 30-31: tile7.colsb +/// 32-47: reserved, must be zero +/// 48: tile0.rows +/// 49: tile1.rows +/// 50: tile2.rows +/// ... +/// 55: tile7.rows +/// 56-63: reserved, must be zero +/// ``` +/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) #[inline] #[target_feature(enable = "amx-tile")] @@ -18,8 +37,8 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { ldtilecfg(mem_addr); } -/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. -/// The tile configuration format is specified below, and includes the tile type pallette, +/// Stores the current tile configuration to a 64-byte memory location specified by `mem_addr`. +/// The tile configuration format is as specified in [`_tile_loadconfig`], and includes the tile type pallette, /// the number of bytes per row, and the number of rows. If tiles are not configured, all zeroes will be stored to memory. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) @@ -31,7 +50,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { sttilecfg(mem_addr); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration previously configured via [`_tile_loadconfig`]. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) #[inline] @@ -55,7 +74,7 @@ pub unsafe fn _tile_release() { tilerelease(); } -/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. +/// Store the tile specified by src to memory specified by base address and stride using the tile configuration previously configured via [`_tile_loadconfig`]. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) #[inline] @@ -68,8 +87,8 @@ pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { tilestored64(DST as i8, base, stride); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration -/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration +/// previously configured via [`_tile_loadconfig`]. This intrinsic provides a hint to the implementation that the data will /// likely not be reused in the near future and the data caching can be optimized accordingly. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) @@ -83,7 +102,7 @@ pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) tileloaddt164(DST as i8, base, stride); } -/// Zero the tile specified by tdest. +/// Zero the tile specified by `tdest`. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) #[inline] @@ -321,7 +340,7 @@ pub unsafe fn _tile_dphf8ps() { } /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Additionally, this intrinsic indicates the source memory location is likely to become /// read-shared by multiple processors, i.e., read in the future by at least one other processor /// before it is written, assuming it is ever written in the future. @@ -339,7 +358,7 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { } /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Provides a hint to the implementation that the data would be reused but does not need /// to be resident in the nearest cache levels. /// Additionally, this intrinsic indicates the source memory location is likely to become From d271d548781fbabe2172beea14cd0e53e15f529d Mon Sep 17 00:00:00 2001 From: sayantn Date: Sun, 24 May 2026 23:53:42 +0530 Subject: [PATCH 2/3] Use correct LLVM intrinsic signature --- crates/core_arch/src/x86_64/amx.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index 31c1762ca2..ddd7948061 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -60,7 +60,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloadd64(DST as i8, base, stride); + tileloadd64(DST as i8, base, stride as u64); } /// Release the tile configuration to return to the init state, which releases all storage it currently holds. @@ -84,7 +84,7 @@ pub unsafe fn _tile_release() { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tilestored64(DST as i8, base, stride); + tilestored64(DST as i8, base, stride as u64); } /// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration @@ -99,7 +99,7 @@ pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddt164(DST as i8, base, stride); + tileloaddt164(DST as i8, base, stride as u64); } /// Zero the tile specified by `tdest`. @@ -354,7 +354,7 @@ pub unsafe fn _tile_dphf8ps() { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrs64(DST as i8, base, stride); + tileloaddrs64(DST as i8, base, stride as u64); } /// Load tile rows from memory specified by base address and stride into destination tile dst @@ -374,7 +374,7 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrst164(DST as i8, base, stride); + tileloaddrst164(DST as i8, base, stride as u64); } /// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) @@ -601,13 +601,13 @@ unsafe extern "C" { #[link_name = "llvm.x86.sttilecfg"] fn sttilecfg(mem_addr: *mut u8); #[link_name = "llvm.x86.tileloadd64"] - fn tileloadd64(dst: i8, base: *const u8, stride: usize); + fn tileloadd64(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloaddt164"] - fn tileloaddt164(dst: i8, base: *const u8, stride: usize); + fn tileloaddt164(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tilerelease"] fn tilerelease(); #[link_name = "llvm.x86.tilestored64"] - fn tilestored64(dst: i8, base: *mut u8, stride: usize); + fn tilestored64(dst: i8, base: *mut u8, stride: u64); #[link_name = "llvm.x86.tilezero"] fn tilezero(dst: i8); #[link_name = "llvm.x86.tdpbf16ps"] @@ -635,9 +635,9 @@ unsafe extern "C" { #[link_name = "llvm.x86.tdphf8ps"] fn tdphf8ps(dst: i8, a: i8, b: i8); #[link_name = "llvm.x86.tileloaddrs64"] - fn tileloaddrs64(dst: i8, base: *const u8, stride: usize); + fn tileloaddrs64(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloaddrst164"] - fn tileloaddrst164(dst: i8, base: *const u8, stride: usize); + fn tileloaddrst164(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tmmultf32ps"] fn tmmultf32ps(dst: i8, a: i8, b: i8); #[link_name = "llvm.x86.tcvtrowd2ps"] From 3a7d435bf5018a988da37a69c6d4af88bf18accf Mon Sep 17 00:00:00 2001 From: sayantn Date: Sun, 24 May 2026 23:56:39 +0530 Subject: [PATCH 3/3] Some refactorings in AMX tests --- crates/core_arch/src/x86_64/amx.rs | 111 +++++++++++++++-------------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index ddd7948061..08585d2067 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -670,7 +670,7 @@ unsafe extern "C" { mod tests { use crate::core_arch::x86::_mm_cvtness_sbh; use crate::core_arch::x86_64::*; - use core::{array, mem::transmute}; + use core::array; use stdarch_test::simd_test; #[cfg(target_os = "linux")] use syscalls::{Sysno, syscall}; @@ -723,19 +723,23 @@ mod tests { #[cfg(target_os = "linux")] #[target_feature(enable = "amx-tile")] #[inline] - unsafe fn _init_amx() { + fn _init_amx() { let mut ret: usize; let mut xfeatures: usize = 0; - ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize) - .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1022, &raw mut xfeatures) + .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to get XFEATURES"); } else { match 0b11 & (xfeatures >> 17) { 0 => panic!("AMX is not available"), 1 => { - ret = syscall!(Sysno::arch_prctl, 0x1023, 18) - .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1023, 18) + .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to enable AMX"); } @@ -778,7 +782,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -795,7 +799,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -812,9 +816,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -831,9 +835,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -846,14 +850,15 @@ mod tests { } } - #[simd_test(enable = "amx-bf16,avx512f")] + const BF16_1: u16 = 0x3f80; + const BF16_2: u16 = 0x4000; + + #[simd_test(enable = "amx-bf16")] fn test_tile_dpbf16ps() { unsafe { _init_amx(); - let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); - let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); - let ones: [u8; 1024] = transmute([bf16_1; 512]); - let twos: [u8; 1024] = transmute([bf16_2; 512]); + let ones = [BF16_1; 512]; + let twos = [BF16_2; 512]; let mut res = [[0f32; 16]; 16]; let mut config = __tilecfg::default(); config.palette = 1; @@ -863,10 +868,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbf16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -887,10 +892,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbssd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -911,10 +916,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbsud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -935,10 +940,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbusd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -959,10 +964,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbuud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -983,10 +988,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -1007,10 +1012,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmimfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -1031,10 +1036,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmrlfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[0f32; 16]; 16]); } @@ -1060,8 +1065,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1084,8 +1089,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbhf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1108,8 +1113,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1132,8 +1137,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1152,9 +1157,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -1171,9 +1176,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); }