diff --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h --- a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h +++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_ #include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/Complex/IR/Complex.h" namespace mlir { class LLVMTypeConverter; diff --git a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Complex/IR/CMakeLists.txt @@ -2,6 +2,11 @@ add_mlir_doc(ComplexOps ComplexOps Dialects/ -gen-dialect-doc) set(LLVM_TARGET_DEFINITIONS ComplexAttributes.td) -mlir_tablegen(ComplexAttributes.h.inc -gen-attrdef-decls) -mlir_tablegen(ComplexAttributes.cpp.inc -gen-attrdef-defs) +mlir_tablegen(ComplexAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=complex) +mlir_tablegen(ComplexAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=complex) add_public_tablegen_target(MLIRComplexAttributesIncGen) + +set(LLVM_TARGET_DEFINITIONS ComplexOpsInterfaces.td) +mlir_tablegen(ComplexOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ComplexOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRComplexOpsInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -21,6 +22,15 @@ #include 
"mlir/Dialect/Complex/IR/ComplexOpsDialect.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Complex/IR/ComplexAttributes.h.inc" + +//===----------------------------------------------------------------------===// +// Complex Interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Complex/IR/ComplexOpsInterfaces.h.inc" + //===----------------------------------------------------------------------===// // Complex Dialect Operations //===----------------------------------------------------------------------===// @@ -28,7 +38,4 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Complex/IR/ComplexOps.h.inc" -#define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Complex/IR/ComplexAttributes.h.inc" - #endif // MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -9,7 +9,9 @@ #ifndef COMPLEX_OPS #define COMPLEX_OPS +include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Complex/IR/ComplexBase.td" +include "mlir/Dialect/Complex/IR/ComplexOpsInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -22,10 +24,11 @@ // one result, all of which must be complex numbers of the same type. class ComplexArithmeticOp traits = []> : Complex_Op { - let arguments = (ins Complex:$lhs, Complex:$rhs); + Elementwise, DeclareOpInterfaceMethods]> { + let arguments = (ins Complex:$lhs, Complex:$rhs, DefaultValuedAttr< + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath); let results = (outs Complex:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; + let assemblyFormat = "$lhs `,` $rhs (`fastmath` `` $fastmath^)? 
attr-dict `:` type($result)"; } // Base class for standard unary operations on complex numbers with a @@ -33,8 +36,9 @@ // one result; the operand must be a complex number. class ComplexUnaryOp traits = []> : Complex_Op { - let arguments = (ins Complex:$complex); - let assemblyFormat = "$complex attr-dict `:` type($complex)"; + let arguments = (ins Complex:$complex, DefaultValuedAttr< + Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath); + let assemblyFormat = "$complex (`fastmath` `` $fastmath^)? attr-dict `:` type($complex)"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOpsInterfaces.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOpsInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOpsInterfaces.td @@ -0,0 +1,53 @@ +//===-- ComplexOpsInterfaces.td - complex op interfaces ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation interfaces of the complex dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef COMPLEX_OPS_INTERFACES +#define COMPLEX_OPS_INTERFACES + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Complex/IR/ComplexBase.td" + +def ComplexFastMathInterface : OpInterface<"ComplexFastMathInterface"> { + let description = [{ + Access to operation fastmath flags.
+ }]; + + let cppNamespace = "::mlir::complex"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation", + /*returnType=*/ "::mlir::arith::FastMathFlagsAttr", + /*methodName=*/ "getFastMathFlagsAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getFastMathAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "fastmath"; + }] + > + + ]; +} + +#endif // COMPLEX_OPS_INTERFACES diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -22,6 +22,7 @@ using namespace mlir; using namespace mlir::LLVM; +using namespace mlir::arith; // NOTE(review): seems unneeded; new code qualifies arith:: explicitly. //===----------------------------------------------------------------------===// // ComplexStructBuilder implementation. @@ -54,6 +55,25 @@ return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); } +// Map arith dialect fastmath flags to the equivalent LLVM dialect flags.
+static LLVM::FastmathFlags
+convertComplexFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
+  LLVM::FastmathFlags llvmFMF{}; // Starts empty; matching bits are OR-ed in.
+  const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
+      {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
+      {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
+      {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
+      {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
+      {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
+      {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
+      {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
+  for (const auto &[arithFlag, llvmFlag] : flags) {
+    if (bitEnumContainsAny(arithFMF, arithFlag)) // Copy each set bit over.
+      llvmFMF = llvmFMF | llvmFlag;
+  }
+  return llvmFMF;
+}
+ //===----------------------------------------------------------------------===// // Conversion patterns. //===----------------------------------------------------------------------===// @@ -180,7 +200,10 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + auto complexFMFAttr = op.getFastMathFlagsAttr(); + auto fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertComplexFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = @@ -208,7 +231,10 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + auto complexFMFAttr = op.getFastMathFlagsAttr(); + auto fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertComplexFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -253,7 +279,10 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers.
- auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + auto complexFMFAttr = op.getFastMathFlagsAttr(); + auto fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertComplexFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value rhsRe = arg.rhs.real(); Value rhsIm = arg.rhs.imag(); Value lhsRe = arg.lhs.real(); @@ -290,7 +319,10 @@ auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. - auto fmf = LLVM::FastmathFlagsAttr::get(op.getContext(), {}); + auto complexFMFAttr = op.getFastMathFlagsAttr(); + auto fmf = LLVM::FastmathFlagsAttr::get( + op.getContext(), + convertComplexFastMathFlagsToLLVM(complexFMFAttr.getValue())); Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real(), fmf); Value imag = diff --git a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Complex/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Complex/IR/CMakeLists.txt @@ -8,6 +8,7 @@ DEPENDS MLIRComplexOpsIncGen MLIRComplexAttributesIncGen + MLIRComplexOpsInterfacesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -150,3 +150,110 @@ // CHECK: %[[NORM:.*]] = llvm.intr.sqrt(%[[SQ_NORM]]) : (f32) -> f32 // CHECK: return %[[NORM]] : f32 +// CHECK-LABEL: func @complex_addition_with_fmf +// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> +// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> +// CHECK-DAG: 
%[[C_REAL:.*]] = llvm.fadd %[[A_REAL]], %[[B_REAL]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fadd %[[A_IMAG]], %[[B_IMAG]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> +// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> +func.func @complex_addition_with_fmf() { + %a_re = arith.constant 1.2 : f64 + %a_im = arith.constant 3.4 : f64 + %a = complex.create %a_re, %a_im : complex + %b_re = arith.constant 5.6 : f64 + %b_im = arith.constant 7.8 : f64 + %b = complex.create %b_re, %b_im : complex + %c = complex.add %a, %b fastmath : complex + return +} + +// CHECK-LABEL: func @complex_substraction_with_fmf +// CHECK-DAG: %[[A_REAL:.*]] = llvm.extractvalue %[[A:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_REAL:.*]] = llvm.extractvalue %[[B:.*]][0] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[A_IMAG:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[B_IMAG:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(f64, f64)> +// CHECK: %[[C0:.*]] = llvm.mlir.undef : !llvm.struct<(f64, f64)> +// CHECK-DAG: %[[C_REAL:.*]] = llvm.fsub %[[A_REAL]], %[[B_REAL]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK-DAG: %[[C_IMAG:.*]] = llvm.fsub %[[A_IMAG]], %[[B_IMAG]] {fastmathFlags = #llvm.fastmath} : f64 +// CHECK: %[[C1:.*]] = llvm.insertvalue %[[C_REAL]], %[[C0]][0] : !llvm.struct<(f64, f64)> +// CHECK: %[[C2:.*]] = llvm.insertvalue %[[C_IMAG]], %[[C1]][1] : !llvm.struct<(f64, f64)> +func.func @complex_substraction_with_fmf() { + %a_re = arith.constant 1.2 : f64 + %a_im = arith.constant 3.4 : f64 + %a = complex.create %a_re, %a_im : complex + %b_re = arith.constant 5.6 : f64 + %b_im = arith.constant 7.8 : f64 + %b = complex.create %b_re, %b_im : complex + %c = complex.sub %a, %b fastmath : complex + return +} + +// CHECK-LABEL: func @complex_div_with_fmf +// CHECK-SAME: %[[LHS:.*]]: complex, 
%[[RHS:.*]]: complex +// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] + +// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] + +// CHECK-DAG: %[[RHS_RE_SQ:.*]] = llvm.fmul %[[RHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[RHS_IM_SQ:.*]] = llvm.fmul %[[RHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[SQ_NORM:.*]] = llvm.fadd %[[RHS_RE_SQ]], %[[RHS_IM_SQ]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[REAL_TMP_2:.*]] = llvm.fadd %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[IMAG_TMP_2:.*]] = llvm.fsub %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK: %[[REAL:.*]] = llvm.fdiv %[[REAL_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] : ![[C_TY]] +// CHECK: %[[IMAG:.*]] = llvm.fdiv %[[IMAG_TMP_2]], %[[SQ_NORM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] : ![[C_TY]] +// The LLVM struct result is then cast back to the complex type. 
+//
CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// CHECK: return %[[CASTED_RESULT]] : complex +func.func @complex_div_with_fmf(%lhs: complex, %rhs: complex) -> complex { + %div = complex.div %lhs, %rhs fastmath : complex + return %div : complex +} + + +// CHECK-LABEL: func @complex_mul_with_fmf +// CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] + +// CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] +// CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] +// CHECK: %[[RHS_RE:.*]] = llvm.extractvalue %[[CASTED_RHS]][0] : ![[C_TY]] +// CHECK: %[[RHS_IM:.*]] = llvm.extractvalue %[[CASTED_RHS]][1] : ![[C_TY]] +// CHECK: %[[RESULT_0:.*]] = llvm.mlir.undef : ![[C_TY]] + +// CHECK-DAG: %[[REAL_TMP_0:.*]] = llvm.fmul %[[RHS_RE]], %[[LHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[REAL_TMP_1:.*]] = llvm.fmul %[[RHS_IM]], %[[LHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[REAL:.*]] = llvm.fsub %[[REAL_TMP_0]], %[[REAL_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK-DAG: %[[IMAG_TMP_0:.*]] = llvm.fmul %[[LHS_IM]], %[[RHS_RE]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK-DAG: %[[IMAG_TMP_1:.*]] = llvm.fmul %[[LHS_RE]], %[[RHS_IM]] {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: %[[IMAG:.*]] = llvm.fadd %[[IMAG_TMP_0]], %[[IMAG_TMP_1]] {fastmathFlags = #llvm.fastmath} : f32 + +// CHECK: %[[RESULT_1:.*]] = llvm.insertvalue %[[REAL]], %[[RESULT_0]][0] +// CHECK: %[[RESULT_2:.*]] = llvm.insertvalue %[[IMAG]], %[[RESULT_1]][1] + +// CHECK: %[[CASTED_RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_2]] : ![[C_TY]] to complex +// CHECK: return %[[CASTED_RESULT]] : complex +func.func @complex_mul_with_fmf(%lhs: 
complex, %rhs: complex) -> complex { + %mul = complex.mul %lhs, %rhs fastmath : complex + return %mul : complex +}