diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -103,7 +103,7 @@ // Base class for standard arithmetic operations. Requires operands and // results to be of the same type, but does not constrain them to specific -// types. Individual classes will have `lhs` and `rhs` accessor to operands. +// types. class ArithmeticOp traits = []> : Op traits = []> : + ArithmeticOp { + + let parser = [{ + return impl::parseOneResultSameOperandTypeOp(parser, result); + }]; + + let printer = [{ + return printStandardBinaryOp(this->getOperation(), p); + }]; +} + +// Base class for standard ternary arithmetic operations. +class ArithmeticTernaryOp traits = []> : + ArithmeticOp { + + let parser = [{ + return impl::parseOneResultSameOperandTypeOp(parser, result); + }]; + + let printer = [{ + return printStandardTernaryOp(this->getOperation(), p); + }]; +} + // Base class for standard arithmetic operations on integers, vectors and // tensors thereof. This operation takes two operands and returns one result, // each of these is required to be of the same type. This type may be an @@ -130,8 +156,8 @@ // // i %0, %1 : i32 // -class IntArithmeticOp traits = []> : - ArithmeticOp traits = []> : + ArithmeticBinaryOp])>, Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>; @@ -145,12 +171,27 @@ // // f %0, %1 : f32 // -class FloatArithmeticOp traits = []> : - ArithmeticOp traits = []> : + ArithmeticBinaryOp])>, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; +// Base class for standard arithmetic ternary operations on floats, vectors and +// tensors thereof. This operation takes three operands and returns one result, +// each of these is required to be of the same type. This type may be a +// floating point scalar type, a vector whose element type is a floating point +// type, or a floating point tensor.
The custom assembly form of the operation +// is as follows +// +// %0, %1, %2 : f32 +// +class FloatTernaryOp traits = []> : + ArithmeticTernaryOp])>, + Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>; + // Base class for memref allocating ops: alloca and alloc. // // %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)> @@ -257,7 +298,7 @@ // AddFOp //===----------------------------------------------------------------------===// -def AddFOp : FloatArithmeticOp<"addf"> { +def AddFOp : FloatBinaryOp<"addf"> { let summary = "floating point addition operation"; let description = [{ Syntax: @@ -294,7 +335,7 @@ // AddIOp //===----------------------------------------------------------------------===// -def AddIOp : IntArithmeticOp<"addi", [Commutative]> { +def AddIOp : IntBinaryOp<"addi", [Commutative]> { let summary = "integer addition operation"; let description = [{ Syntax: @@ -418,7 +459,7 @@ // AndOp //===----------------------------------------------------------------------===// -def AndOp : IntArithmeticOp<"and", [Commutative]> { +def AndOp : IntBinaryOp<"and", [Commutative]> { let summary = "integer binary and"; let description = [{ Syntax: @@ -1269,7 +1310,7 @@ // CopySignOp //===----------------------------------------------------------------------===// -def CopySignOp : FloatArithmeticOp<"copysign"> { +def CopySignOp : FloatBinaryOp<"copysign"> { let summary = "A copysign operation"; let description = [{ Syntax: @@ -1384,11 +1425,49 @@ // DivFOp //===----------------------------------------------------------------------===// -def DivFOp : FloatArithmeticOp<"divf"> { +def DivFOp : FloatBinaryOp<"divf"> { let summary = "floating point division operation"; let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// FmaFOp +//===----------------------------------------------------------------------===// + +def FmaFOp : FloatTernaryOp<"fmaf"> { + let summary = "floating point fused
multiply-add operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type + ``` + + The `fmaf` operation takes three operands and returns one result, each of + these is required to be of the same type. This type may be a floating point + scalar type, a vector whose element type is a floating point type, or a + floating point tensor. + + Example: + + ```mlir + // Scalar fused multiply-add: d = a*b + c + %d = fmaf %a, %b, %c : f64 + + // SIMD vector fused multiply-add, e.g. for x86 FMA3. + %i = fmaf %f, %g, %h : vector<4xf32> + + // Tensor fused multiply-add. + %w = fmaf %x, %y, %z : tensor<4x?xbf16> + ``` + + The semantics of the operation correspond to those of the `llvm.fma` + [intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the + particular case of lowering to LLVM, this is guaranteed to lower + to the `llvm.fma.*` intrinsic. + }]; +} + //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// @@ -1854,7 +1933,7 @@ // MulFOp //===----------------------------------------------------------------------===// -def MulFOp : FloatArithmeticOp<"mulf"> { +def MulFOp : FloatBinaryOp<"mulf"> { let summary = "floating point multiplication operation"; let description = [{ Syntax: @@ -1891,7 +1970,7 @@ // MulIOp //===----------------------------------------------------------------------===// -def MulIOp : IntArithmeticOp<"muli", [Commutative]> { +def MulIOp : IntBinaryOp<"muli", [Commutative]> { let summary = "integer multiplication operation"; let hasFolder = 1; } @@ -1933,7 +2012,7 @@ // OrOp //===----------------------------------------------------------------------===// -def OrOp : IntArithmeticOp<"or", [Commutative]> { +def OrOp : IntBinaryOp<"or", [Commutative]> { let summary = "integer binary or"; let description = [{ Syntax: @@ -2040,7 +2119,7 @@ // RemFOp
//===----------------------------------------------------------------------===// -def RemFOp : FloatArithmeticOp<"remf"> { +def RemFOp : FloatBinaryOp<"remf"> { let summary = "floating point division remainder operation"; } @@ -2141,7 +2220,7 @@ // ShiftLeftOp //===----------------------------------------------------------------------===// -def ShiftLeftOp : IntArithmeticOp<"shift_left"> { +def ShiftLeftOp : IntBinaryOp<"shift_left"> { let summary = "integer left-shift"; let description = [{ The shift_left operation shifts an integer value to the left by a variable @@ -2161,7 +2240,7 @@ // SignedDivIOp //===----------------------------------------------------------------------===// -def SignedDivIOp : IntArithmeticOp<"divi_signed"> { +def SignedDivIOp : IntBinaryOp<"divi_signed"> { let summary = "signed integer division operation"; let description = [{ Syntax: @@ -2196,7 +2275,7 @@ // SignedFloorDivIOp //===----------------------------------------------------------------------===// -def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> { +def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> { let summary = "signed floor integer division operation"; let description = [{ Syntax: @@ -2225,7 +2304,7 @@ // SignedCeilDivIOp //===----------------------------------------------------------------------===// -def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> { +def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> { let summary = "signed ceil integer division operation"; let description = [{ Syntax: @@ -2253,7 +2332,7 @@ // SignedRemIOp //===----------------------------------------------------------------------===// -def SignedRemIOp : IntArithmeticOp<"remi_signed"> { +def SignedRemIOp : IntBinaryOp<"remi_signed"> { let summary = "signed integer division remainder operation"; let description = [{ Syntax: @@ -2288,7 +2367,7 @@ // SignedShiftRightOp //===----------------------------------------------------------------------===// -def SignedShiftRightOp :
IntArithmeticOp<"shift_right_signed"> { +def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> { let summary = "signed integer right-shift"; let description = [{ The shift_right_signed operation shifts an integer value to the right by @@ -2488,7 +2567,7 @@ // SubFOp //===----------------------------------------------------------------------===// -def SubFOp : FloatArithmeticOp<"subf"> { +def SubFOp : FloatBinaryOp<"subf"> { let summary = "floating point subtraction operation"; let hasFolder = 1; } @@ -2497,7 +2576,7 @@ // SubIOp //===----------------------------------------------------------------------===// -def SubIOp : IntArithmeticOp<"subi"> { +def SubIOp : IntBinaryOp<"subi"> { let summary = "integer subtraction operation"; let hasFolder = 1; } @@ -3173,7 +3252,7 @@ // UnsignedDivIOp //===----------------------------------------------------------------------===// -def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> { +def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> { let summary = "unsigned integer division operation"; let description = [{ Syntax: @@ -3208,7 +3287,7 @@ // UnsignedRemIOp //===----------------------------------------------------------------------===// -def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> { +def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> { let summary = "unsigned integer division remainder operation"; let description = [{ Syntax: @@ -3243,7 +3322,7 @@ // UnsignedShiftRightOp //===----------------------------------------------------------------------===// -def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> { +def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> { let summary = "unsigned integer right-shift"; let description = [{ The shift_right_unsigned operation shifts an integer value to the right by @@ -3332,7 +3411,7 @@ // XOrOp //===----------------------------------------------------------------------===// -def XOrOp :
IntBinaryOp<"xor", [Commutative]> { let summary = "integer binary xor"; let description = [{ The `xor` operation takes two operands and returns one result, each of these diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1662,6 +1662,7 @@ using ExpOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; using FloorFOpLowering = VectorConvertToLLVMPattern; +using FmaFOpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; @@ -3775,6 +3776,7 @@ ExpOpLowering, Exp2OpLowering, FloorFOpLowering, + FmaFOpLowering, GenericAtomicRMWOpLowering, LogOpLowering, Log10OpLowering, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -158,6 +158,32 @@ p << " : " << op->getResult(0).getType(); } +/// A custom ternary operation printer that omits the "std." prefix from the +/// operation names. +static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) { + assert(op->getNumOperands() == 3 && "ternary op should have three operands"); + assert(op->getNumResults() == 1 && "ternary op should have one result"); + + // If not all the operand and result types are the same, just use the + // generic assembly form to avoid omitting information in printing.
+ auto resultType = op->getResult(0).getType(); + if (op->getOperand(0).getType() != resultType || + op->getOperand(1).getType() != resultType || + op->getOperand(2).getType() != resultType) { + p.printGenericOp(op); + return; + } + + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' + << op->getOperand(0) << ", " << op->getOperand(1) << ", " + << op->getOperand(2); + p.printOptionalAttrDict(op->getAttrs()); + + // Now we can output only one type for all operands and the result. + p << " : " << resultType; +} + /// A custom cast operation printer that omits the "std." prefix from the /// operation names. static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -223,3 +223,16 @@ %0 = math.powf %arg0, %arg0 : f64 std.return } + +// ----- + +// CHECK-LABEL: func @fmaf( +// CHECK-SAME: %[[ARG0:.*]]: f32 +// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32> +func @fmaf(%arg0: f32, %arg1: vector<4xf32>) { + // CHECK: %[[S:.*]] = "llvm.intr.fma"(%[[ARG0]], %[[ARG0]], %[[ARG0]]) : (f32, f32, f32) -> f32 + %0 = fmaf %arg0, %arg0, %arg0 : f32 + // CHECK: %[[V:.*]] = "llvm.intr.fma"(%[[ARG1]], %[[ARG1]], %[[ARG1]]) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> + %1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32> + std.return +}