From 771e44ebd32df8e342efa1f246f5d5070af04ec4 Mon Sep 17 00:00:00 2001 From: beetrees Date: Sat, 15 Jun 2024 22:30:25 +0100 Subject: [PATCH] Add `f16` inline ASM support for RISC-V --- compiler/rustc_codegen_llvm/src/asm.rs | 55 +++++++++++++++++++++++--- compiler/rustc_span/src/symbol.rs | 2 + compiler/rustc_target/src/asm/riscv.rs | 7 ++-- tests/assembly/asm/riscv-types.rs | 55 +++++++++++++++++++++++++- 4 files changed, 108 insertions(+), 11 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/asm.rs b/compiler/rustc_codegen_llvm/src/asm.rs index 60e63b956db..34a0f9973f6 100644 --- a/compiler/rustc_codegen_llvm/src/asm.rs +++ b/compiler/rustc_codegen_llvm/src/asm.rs @@ -13,7 +13,7 @@ use rustc_codegen_ssa::traits::*; use rustc_data_structures::fx::FxHashMap; use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::{bug, span_bug, ty::Instance}; -use rustc_span::{Pos, Span}; +use rustc_span::{sym, Pos, Span, Symbol}; use rustc_target::abi::*; use rustc_target::asm::*; use tracing::debug; @@ -64,7 +64,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let mut layout = None; let ty = if let Some(ref place) = place { layout = Some(&place.layout); - llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout) + llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout, instance) } else if matches!( reg.reg_class(), InlineAsmRegClass::X86( @@ -112,7 +112,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { // so we just use the type of the input. 
&in_value.layout }; - let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout); + let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout, instance); output_types.push(ty); op_idx.insert(idx, constraints.len()); let prefix = if late { "=" } else { "=&" }; @@ -127,8 +127,13 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { for (idx, op) in operands.iter().enumerate() { match *op { InlineAsmOperandRef::In { reg, value } => { - let llval = - llvm_fixup_input(self, value.immediate(), reg.reg_class(), &value.layout); + let llval = llvm_fixup_input( + self, + value.immediate(), + reg.reg_class(), + &value.layout, + instance, + ); inputs.push(llval); op_idx.insert(idx, constraints.len()); constraints.push(reg_to_llvm(reg, Some(&value.layout))); @@ -139,6 +144,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { in_value.immediate(), reg.reg_class(), &in_value.layout, + instance, ); inputs.push(value); @@ -341,7 +347,8 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { } else { self.extract_value(result, op_idx[&idx] as u64) }; - let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout); + let value = + llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance); OperandValue::Immediate(value).store(self, place); } } @@ -913,12 +920,22 @@ fn llvm_asm_scalar_type<'ll>(cx: &CodegenCx<'ll, '_>, scalar: Scalar) -> &'ll Ty } } +fn any_target_feature_enabled( + cx: &CodegenCx<'_, '_>, + instance: Instance<'_>, + features: &[Symbol], +) -> bool { + let enabled = cx.tcx.asm_target_features(instance.def_id()); + features.iter().any(|feat| enabled.contains(feat)) +} + /// Fix up an input value to work around LLVM bugs. 
fn llvm_fixup_input<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, mut value: &'ll Value, reg: InlineAsmRegClass, layout: &TyAndLayout<'tcx>, + instance: Instance<'_>, ) -> &'ll Value { let dl = &bx.tcx.data_layout; match (reg, layout.abi) { @@ -1029,6 +1046,16 @@ fn llvm_fixup_input<'ll, 'tcx>( _ => value, } } + (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s)) + if s.primitive() == Primitive::Float(Float::F16) + && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) => + { + // Smaller floats are always "NaN-boxed" inside larger floats on RISC-V. + let value = bx.bitcast(value, bx.type_i16()); + let value = bx.zext(value, bx.type_i32()); + let value = bx.or(value, bx.const_u32(0xFFFF_0000)); + bx.bitcast(value, bx.type_f32()) + } _ => value, } } @@ -1039,6 +1066,7 @@ fn llvm_fixup_output<'ll, 'tcx>( mut value: &'ll Value, reg: InlineAsmRegClass, layout: &TyAndLayout<'tcx>, + instance: Instance<'_>, ) -> &'ll Value { match (reg, layout.abi) { (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => { @@ -1140,6 +1168,14 @@ fn llvm_fixup_output<'ll, 'tcx>( _ => value, } } + (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s)) + if s.primitive() == Primitive::Float(Float::F16) + && !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) => + { + let value = bx.bitcast(value, bx.type_i32()); + let value = bx.trunc(value, bx.type_i16()); + bx.bitcast(value, bx.type_f16()) + } _ => value, } } @@ -1149,6 +1185,7 @@ fn llvm_fixup_output_type<'ll, 'tcx>( cx: &CodegenCx<'ll, 'tcx>, reg: InlineAsmRegClass, layout: &TyAndLayout<'tcx>, + instance: Instance<'_>, ) -> &'ll Type { match (reg, layout.abi) { (InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => { @@ -1242,6 +1279,12 @@ fn llvm_fixup_output_type<'ll, 'tcx>( _ => layout.llvm_type(cx), } } + (InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s)) + if s.primitive() == 
Primitive::Float(Float::F16) + && !any_target_feature_enabled(cx, instance, &[sym::zfhmin, sym::zfh]) => + { + cx.type_f32() + } _ => layout.llvm_type(cx), } } diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index f44fa1bcb4f..018442af144 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -2054,6 +2054,8 @@ symbols! { yes, yield_expr, ymm_reg, + zfh, + zfhmin, zmm_reg, } } diff --git a/compiler/rustc_target/src/asm/riscv.rs b/compiler/rustc_target/src/asm/riscv.rs index 3845a0e14af..02a4a5e2ece 100644 --- a/compiler/rustc_target/src/asm/riscv.rs +++ b/compiler/rustc_target/src/asm/riscv.rs @@ -40,12 +40,13 @@ impl RiscVInlineAsmRegClass { match self { Self::reg => { if arch == InlineAsmArch::RiscV64 { - types! { _: I8, I16, I32, I64, F32, F64; } + types! { _: I8, I16, I32, I64, F16, F32, F64; } } else { - types! { _: I8, I16, I32, F32; } + types! { _: I8, I16, I32, F16, F32; } } } - Self::freg => types! { f: F32; d: F64; }, + // FIXME(f16_f128): Add `q: F128;` once LLVM supports the `Q` extension. + Self::freg => types! 
{ f: F16, F32; d: F64; }, Self::vreg => &[], } } diff --git a/tests/assembly/asm/riscv-types.rs b/tests/assembly/asm/riscv-types.rs index 0d1f8305d37..51b3aaf99d9 100644 --- a/tests/assembly/asm/riscv-types.rs +++ b/tests/assembly/asm/riscv-types.rs @@ -1,12 +1,34 @@ -//@ revisions: riscv64 riscv32 +//@ revisions: riscv64 riscv32 riscv64-zfhmin riscv32-zfhmin riscv64-zfh riscv32-zfh //@ assembly-output: emit-asm + //@[riscv64] compile-flags: --target riscv64imac-unknown-none-elf //@[riscv64] needs-llvm-components: riscv + //@[riscv32] compile-flags: --target riscv32imac-unknown-none-elf //@[riscv32] needs-llvm-components: riscv + +//@[riscv64-zfhmin] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64 +//@[riscv64-zfhmin] needs-llvm-components: riscv +//@[riscv64-zfhmin] compile-flags: -C target-feature=+zfhmin +//@[riscv64-zfhmin] filecheck-flags: --check-prefix riscv64 + +//@[riscv32-zfhmin] compile-flags: --target riscv32imac-unknown-none-elf +//@[riscv32-zfhmin] needs-llvm-components: riscv +//@[riscv32-zfhmin] compile-flags: -C target-feature=+zfhmin + +//@[riscv64-zfh] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64 +//@[riscv64-zfh] needs-llvm-components: riscv +//@[riscv64-zfh] compile-flags: -C target-feature=+zfh +//@[riscv64-zfh] filecheck-flags: --check-prefix riscv64 --check-prefix zfhmin + +//@[riscv32-zfh] compile-flags: --target riscv32imac-unknown-none-elf +//@[riscv32-zfh] needs-llvm-components: riscv +//@[riscv32-zfh] compile-flags: -C target-feature=+zfh +//@[riscv32-zfh] filecheck-flags: --check-prefix zfhmin + //@ compile-flags: -C target-feature=+d -#![feature(no_core, lang_items, rustc_attrs)] +#![feature(no_core, lang_items, rustc_attrs, f16)] #![crate_type = "rlib"] #![no_core] #![allow(asm_sub_register)] @@ -33,6 +55,7 @@ type ptr = *mut u8; impl Copy for i8 {} impl Copy for i16 {} +impl Copy for f16 {} impl Copy for i32 {} impl Copy for f32 {} impl Copy for i64 {} @@ -103,6 +126,12 @@ macro_rules! 
check_reg { // CHECK: #NO_APP check!(reg_i8 i8 reg "mv"); +// CHECK-LABEL: reg_f16: +// CHECK: #APP +// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}} +// CHECK: #NO_APP +check!(reg_f16 f16 reg "mv"); + // CHECK-LABEL: reg_i16: // CHECK: #APP // CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}} @@ -141,6 +170,14 @@ check!(reg_f64 f64 reg "mv"); // CHECK: #NO_APP check!(reg_ptr ptr reg "mv"); +// CHECK-LABEL: freg_f16: +// zfhmin-NOT: or +// CHECK: #APP +// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}} +// CHECK: #NO_APP +// zfhmin-NOT: or +check!(freg_f16 f16 freg "fmv.s"); + // CHECK-LABEL: freg_f32: // CHECK: #APP // CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}} @@ -165,6 +202,12 @@ check_reg!(a0_i8 i8 "a0" "mv"); // CHECK: #NO_APP check_reg!(a0_i16 i16 "a0" "mv"); +// CHECK-LABEL: a0_f16: +// CHECK: #APP +// CHECK: mv a0, a0 +// CHECK: #NO_APP +check_reg!(a0_f16 f16 "a0" "mv"); + // CHECK-LABEL: a0_i32: // CHECK: #APP // CHECK: mv a0, a0 @@ -197,6 +240,14 @@ check_reg!(a0_f64 f64 "a0" "mv"); // CHECK: #NO_APP check_reg!(a0_ptr ptr "a0" "mv"); +// CHECK-LABEL: fa0_f16: +// zfhmin-NOT: or +// CHECK: #APP +// CHECK: fmv.s fa0, fa0 +// CHECK: #NO_APP +// zfhmin-NOT: or +check_reg!(fa0_f16 f16 "fa0" "fmv.s"); + // CHECK-LABEL: fa0_f32: // CHECK: #APP // CHECK: fmv.s fa0, fa0