diff --git a/crates/core_arch/src/x86_64/amx.rs b/crates/core_arch/src/x86_64/amx.rs index b3b3e86750..08585d2067 100644 --- a/crates/core_arch/src/x86_64/amx.rs +++ b/crates/core_arch/src/x86_64/amx.rs @@ -3,12 +3,31 @@ use crate::core_arch::{simd::*, x86::*}; #[cfg(test)] use stdarch_test::assert_instr; -/// Load tile configuration from a 64-byte memory location specified by mem_addr. +/// Load tile configuration from a 64-byte memory location specified by `mem_addr`. /// The tile configuration format is specified below, and includes the tile type pallette, /// the number of bytes per row, and the number of rows. If the specified pallette_id is zero, /// that signifies the init state for both the tile config and the tile data, and the tiles are zeroed. /// Any invalid configurations will result in #GP fault. /// +/// ```intel +/// // format of memory payload. each field is a byte. +/// 0: palette +/// 1: start_row +/// 2-15: reserved, must be zero +/// 16-17: tile0.colsb +/// 18-19: tile1.colsb +/// 20-21: tile2.colsb +/// ... +/// 30-31: tile7.colsb +/// 32-47: reserved, must be zero +/// 48: tile0.rows +/// 49: tile1.rows +/// 50: tile2.rows +/// ... +/// 55: tile7.rows +/// 56-63: reserved, must be zero +/// ``` +/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875) #[inline] #[target_feature(enable = "amx-tile")] @@ -18,8 +37,8 @@ pub unsafe fn _tile_loadconfig(mem_addr: *const u8) { ldtilecfg(mem_addr); } -/// Stores the current tile configuration to a 64-byte memory location specified by mem_addr. -/// The tile configuration format is specified below, and includes the tile type pallette, +/// Stores the current tile configuration to a 64-byte memory location specified by `mem_addr`. +/// The tile configuration format is as specified in [`_tile_loadconfig`], and includes the tile type palette, /// the number of bytes per row, and the number of rows. 
If tiles are not configured, all zeroes will be stored to memory. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879) @@ -31,7 +50,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { sttilecfg(mem_addr); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration previously configured via _tile_loadconfig. +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration previously configured via [`_tile_loadconfig`]. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877) #[inline] @@ -41,7 +60,7 @@ pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloadd64(DST as i8, base, stride); + tileloadd64(DST as i8, base, stride as u64); } /// Release the tile configuration to return to the init state, which releases all storage it currently holds. @@ -55,7 +74,7 @@ pub unsafe fn _tile_release() { tilerelease(); } -/// Store the tile specified by src to memory specifieid by base address and stride using the tile configuration previously configured via _tile_loadconfig. +/// Store the tile specified by src to memory specified by base address and stride using the tile configuration previously configured via [`_tile_loadconfig`]. 
/// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881) #[inline] @@ -65,11 +84,11 @@ pub unsafe fn _tile_release() { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tilestored64(DST as i8, base, stride); + tilestored64(DST as i8, base, stride as u64); } -/// Load tile rows from memory specifieid by base address and stride into destination tile dst using the tile configuration -/// previously configured via _tile_loadconfig. This intrinsic provides a hint to the implementation that the data will +/// Load tile rows from memory specified by base address and stride into destination tile dst using the tile configuration +/// previously configured via [`_tile_loadconfig`]. This intrinsic provides a hint to the implementation that the data will /// likely not be reused in the near future and the data caching can be optimized accordingly. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883) @@ -80,10 +99,10 @@ pub unsafe fn _tile_stored(base: *mut u8, stride: usize) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loadd(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddt164(DST as i8, base, stride); + tileloaddt164(DST as i8, base, stride as u64); } -/// Zero the tile specified by tdest. +/// Zero the tile specified by `tdest`. /// /// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885) #[inline] @@ -321,7 +340,7 @@ pub unsafe fn _tile_dphf8ps() { } /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. 
+/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Additionally, this intrinsic indicates the source memory location is likely to become /// read-shared by multiple processors, i.e., read in the future by at least one other processor /// before it is written, assuming it is ever written in the future. @@ -335,11 +354,11 @@ pub unsafe fn _tile_dphf8ps() { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrs64(DST as i8, base, stride); + tileloaddrs64(DST as i8, base, stride as u64); } /// Load tile rows from memory specified by base address and stride into destination tile dst -/// using the tile configuration previously configured via _tile_loadconfig. +/// using the tile configuration previously configured via [`_tile_loadconfig`]. /// Provides a hint to the implementation that the data would be reused but does not need /// to be resident in the nearest cache levels. 
/// Additionally, this intrinsic indicates the source memory location is likely to become @@ -355,7 +374,7 @@ pub unsafe fn _tile_loaddrs(base: *const u8, stride: usize) { #[unstable(feature = "x86_amx_intrinsics", issue = "126622")] pub unsafe fn _tile_stream_loaddrs(base: *const u8, stride: usize) { static_assert_uimm_bits!(DST, 3); - tileloaddrst164(DST as i8, base, stride); + tileloaddrst164(DST as i8, base, stride as u64); } /// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) @@ -582,13 +601,13 @@ unsafe extern "C" { #[link_name = "llvm.x86.sttilecfg"] fn sttilecfg(mem_addr: *mut u8); #[link_name = "llvm.x86.tileloadd64"] - fn tileloadd64(dst: i8, base: *const u8, stride: usize); + fn tileloadd64(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloaddt164"] - fn tileloaddt164(dst: i8, base: *const u8, stride: usize); + fn tileloaddt164(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tilerelease"] fn tilerelease(); #[link_name = "llvm.x86.tilestored64"] - fn tilestored64(dst: i8, base: *mut u8, stride: usize); + fn tilestored64(dst: i8, base: *mut u8, stride: u64); #[link_name = "llvm.x86.tilezero"] fn tilezero(dst: i8); #[link_name = "llvm.x86.tdpbf16ps"] @@ -616,9 +635,9 @@ unsafe extern "C" { #[link_name = "llvm.x86.tdphf8ps"] fn tdphf8ps(dst: i8, a: i8, b: i8); #[link_name = "llvm.x86.tileloaddrs64"] - fn tileloaddrs64(dst: i8, base: *const u8, stride: usize); + fn tileloaddrs64(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tileloaddrst164"] - fn tileloaddrst164(dst: i8, base: *const u8, stride: usize); + fn tileloaddrst164(dst: i8, base: *const u8, stride: u64); #[link_name = "llvm.x86.tmmultf32ps"] fn tmmultf32ps(dst: i8, a: i8, b: i8); #[link_name = "llvm.x86.tcvtrowd2ps"] @@ -651,7 +670,7 @@ unsafe extern "C" { mod tests { use crate::core_arch::x86::_mm_cvtness_sbh; use crate::core_arch::x86_64::*; - use core::{array, mem::transmute}; + use core::array; 
use stdarch_test::simd_test; #[cfg(target_os = "linux")] use syscalls::{Sysno, syscall}; @@ -704,19 +723,23 @@ mod tests { #[cfg(target_os = "linux")] #[target_feature(enable = "amx-tile")] #[inline] - unsafe fn _init_amx() { + fn _init_amx() { let mut ret: usize; let mut xfeatures: usize = 0; - ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize) - .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1022, &raw mut xfeatures) + .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to get XFEATURES"); } else { match 0b11 & (xfeatures >> 17) { 0 => panic!("AMX is not available"), 1 => { - ret = syscall!(Sysno::arch_prctl, 0x1023, 18) - .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed"); + ret = unsafe { + syscall!(Sysno::arch_prctl, 0x1023, 18) + .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed") + }; if ret != 0 { panic!("Failed to enable AMX"); } @@ -759,7 +782,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -776,7 +799,7 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mut out = [[1_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[0; 64]; 16]); } @@ -793,9 +816,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -812,9 +835,9 @@ mod tests { 
_tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loadd::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -827,14 +850,15 @@ mod tests { } } - #[simd_test(enable = "amx-bf16,avx512f")] + const BF16_1: u16 = 0x3f80; + const BF16_2: u16 = 0x4000; + + #[simd_test(enable = "amx-bf16")] fn test_tile_dpbf16ps() { unsafe { _init_amx(); - let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits(); - let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits(); - let ones: [u8; 1024] = transmute([bf16_1; 512]); - let twos: [u8; 1024] = transmute([bf16_2; 512]); + let ones = [BF16_1; 512]; + let twos = [BF16_2; 512]; let mut res = [[0f32; 16]; 16]; let mut config = __tilecfg::default(); config.palette = 1; @@ -844,10 +868,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbf16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -868,10 +892,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbssd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -892,10 +916,10 @@ mod tests { }); 
_tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const i8 as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbsud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -916,10 +940,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const i8 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpbusd::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[-128_i32; 16]; 16]); } @@ -940,10 +964,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbuud::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[128_i32; 16]; 16]); } @@ -964,10 +988,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_dpfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -988,10 +1012,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const 
f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmimfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[64f32; 16]; 16]); } @@ -1012,10 +1036,10 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const f16 as *const u8, 64); - _tile_loadd::<2>(&twos as *const f16 as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr().cast(), 64); + _tile_loadd::<2>(twos.as_ptr().cast(), 64); _tile_cmmrlfp16ps::<0, 1, 2>(); - _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64); + _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(res, [[0f32; 16]; 16]); } @@ -1041,8 +1065,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1065,8 +1089,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dpbhf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1089,8 +1113,8 @@ mod tests { }); _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphbf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1113,8 +1137,8 @@ mod tests { }); 
_tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); - _tile_loadd::<1>(&ones as *const u8, 64); - _tile_loadd::<2>(&twos as *const u8, 64); + _tile_loadd::<1>(ones.as_ptr(), 64); + _tile_loadd::<2>(twos.as_ptr(), 64); _tile_dphf8ps::<0, 1, 2>(); _tile_stored::<0>(res.as_mut_ptr().cast(), 64); _tile_release(); @@ -1133,9 +1157,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); } @@ -1152,9 +1176,9 @@ mod tests { _tile_loadconfig(config.as_ptr()); _tile_zero::<0>(); let mat = [1_i8; 1024]; - _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64); + _tile_stream_loaddrs::<0>(mat.as_ptr().cast(), 64); let mut out = [[0_i8; 64]; 16]; - _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64); + _tile_stored::<0>(out.as_mut_ptr().cast(), 64); _tile_release(); assert_eq!(out, [[1; 64]; 16]); }