diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -732,6 +732,7 @@ def LLVM_FCeilOp : LLVM_UnaryIntrinsicOp<"ceil">; def LLVM_CosOp : LLVM_UnaryIntrinsicOp<"cos">; def LLVM_CopySignOp : LLVM_BinarySameArgsIntrinsicOp<"copysign">; +def LLVM_FMAOp : LLVM_TernarySameArgsIntrinsicOp<"fma">; def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrinsicOp<"fmuladd">; def LLVM_SqrtOp : LLVM_UnaryIntrinsicOp<"sqrt">; diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -567,21 +567,25 @@ Results<(outs AnyVector)> { let summary = "vector outerproduct with optional fused add"; let description = [{ - Takes 2 1-D vectors and returns the 2-D vector containing the outer product. + Takes 2 1-D vectors and returns the 2-D vector containing the outer-product. An optional extra 2-D vector argument may be specified in which case the - operation returns the sum of the outer product and the extra vector. When - lowered to the LLVMIR dialect, this form emits `llvm.intr.fmuladd`, which - can lower to actual `fma` instructions in LLVM. + operation returns the sum of the outer-product and the extra vector. In this + multiply-accumulate scenario, the multiply and the add are fused and rounded + only once, as in a fused multiply-add. When lowered to the LLVMIR dialect, + this form emits `llvm.intr.fma`, which is always fused: it becomes an `fma` + instruction on targets with FMA support (e.g. x86 with FMA3) and a libcall otherwise.
- Examples + Examples: + ``` %2 = vector.outerproduct %0, %1: vector<4xf32>, vector<8xf32> return %2: vector<4x8xf32> %3 = vector.outerproduct %0, %1, %2: vector<4xf32>, vector<8xf32>, vector<4x8xf32> return %3: vector<4x8xf32> + ``` }]; let extraClassDeclaration = [{ VectorType getOperandVectorTypeLHS() { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -674,9 +674,9 @@ loc, vRHS, acc, rewriter.getI64ArrayAttr(d)); // 3. Compute aD outer b (plus accD, if relevant). Value aOuterbD = - accD ? rewriter.create(loc, vRHS, aD, b, accD) - .getResult() - : rewriter.create(loc, aD, b).getResult(); + accD + ? rewriter.create(loc, vRHS, aD, b, accD).getResult() + : rewriter.create(loc, aD, b).getResult(); // 4. Insert as value `d` in the descriptor. desc = rewriter.create(loc, llvmArrayOfVectType, desc, aOuterbD, diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -222,11 +222,11 @@ // CHECK: llvm.mlir.undef : !llvm<"[2 x <3 x float>]"> // CHECK: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> // CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> -// CHECK: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> +// CHECK: "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> // CHECK: llvm.insertvalue {{.*}}[0] : !llvm<"[2 x <3 x float>]"> // CHECK: llvm.shufflevector {{.*}} [1 : i32, 1 : i32, 1 : i32] : !llvm<"<2 x float>">, !llvm<"<2 x float>"> // CHECK: llvm.extractvalue {{.*}}[1] : 
!llvm<"[2 x <3 x float>]"> -// CHECK: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> +// CHECK: "llvm.intr.fma"({{.*}}) : (!llvm<"<3 x float>">, !llvm<"<3 x float>">, !llvm<"<3 x float>">) -> !llvm<"<3 x float>"> // CHECK: llvm.insertvalue {{.*}}[1] : !llvm<"[2 x <3 x float>]"> // CHECK: llvm.return {{.*}} : !llvm<"[2 x <3 x float>]"> diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -171,7 +171,7 @@ // LLVM-LOOPS: llvm.shufflevector {{.*}} [2 : i32, 2 : i32, 2 : i32, 2 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> // LLVM-LOOPS: llvm.shufflevector {{.*}} [3 : i32, 3 : i32, 3 : i32, 3 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> // LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]"> -// LLVM-LOOPS-NEXT: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> +// LLVM-LOOPS-NEXT: "llvm.intr.fma"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> // LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]"> diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -9,6 +9,10 @@ "llvm.intr.fmuladd"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float // CHECK: call <8 x float> @llvm.fmuladd.v8f32 "llvm.intr.fmuladd"(%arg2, %arg2, %arg2) : (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>"> + // CHECK: call float @llvm.fma.f32 + "llvm.intr.fma"(%arg0, %arg1, %arg0) : (!llvm.float, !llvm.float, !llvm.float) -> !llvm.float + // CHECK: call <8 x float> @llvm.fma.v8f32 + "llvm.intr.fma"(%arg2, %arg2, %arg2) : (!llvm<"<8 x 
float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) -> !llvm<"<8 x float>"> // CHECK: call void @llvm.prefetch.p0i8(i8* %3, i32 0, i32 3, i32 1) "llvm.intr.prefetch"(%arg3, %c0, %c3, %c1) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm.i32) -> () llvm.return @@ -96,23 +100,25 @@ } // Check that intrinsics are declared with appropriate types. -// CHECK: declare float @llvm.fmuladd.f32(float, float, float) -// CHECK: declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 -// CHECK: declare void @llvm.prefetch.p0i8(i8* nocapture readonly, i32 immarg, i32 immarg, i32) -// CHECK: declare float @llvm.exp.f32(float) -// CHECK: declare <8 x float> @llvm.exp.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.log.f32(float) -// CHECK: declare <8 x float> @llvm.log.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.log10.f32(float) -// CHECK: declare <8 x float> @llvm.log10.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.log2.f32(float) -// CHECK: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.fabs.f32(float) -// CHECK: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.sqrt.f32(float) -// CHECK: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.ceil.f32(float) -// CHECK: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.cos.f32(float) -// CHECK: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0 -// CHECK: declare float @llvm.copysign.f32(float, float) +// CHECK-DAG: declare float @llvm.fma.f32(float, float, float) +// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 +// CHECK-DAG: declare float @llvm.fmuladd.f32(float, float, float) +// CHECK-DAG: declare <8 x float> @llvm.fmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 +// CHECK-DAG: declare void @llvm.prefetch.p0i8(i8* nocapture readonly, i32 immarg, i32 immarg, i32) +// CHECK-DAG: declare float 
@llvm.exp.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.exp.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.log.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.log.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.log10.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.log10.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.log2.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.log2.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.fabs.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.fabs.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.sqrt.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.ceil.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.cos.f32(float) +// CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0 +// CHECK-DAG: declare float @llvm.copysign.f32(float, float)