diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -746,10 +746,16 @@ // SCFToOpenMP //===----------------------------------------------------------------------===// -def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> { +def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> { let summary = "Convert SCF parallel loop to OpenMP parallel + workshare " "constructs."; - let constructor = "mlir::createConvertSCFToOpenMPPass()"; + + let options = [ + Option<"useOpaquePointers", "use-opaque-pointers", "bool", + /*default=*/"false", "Generate LLVM IR using opaque pointers " + "instead of typed pointers"> + ]; + let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect", "memref::MemRefDialect"]; } diff --git a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h --- a/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h +++ b/mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h @@ -12,15 +12,11 @@ #include namespace mlir { -class ModuleOp; -template -class OperationPass; +class Pass; -#define GEN_PASS_DECL_CONVERTSCFTOOPENMP +#define GEN_PASS_DECL_CONVERTSCFTOOPENMPPASS #include "mlir/Conversion/Passes.h.inc" -std::unique_ptr> createConvertSCFToOpenMPPass(); - } // namespace mlir #endif // MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -26,7 +26,7 @@ #include "mlir/Transforms/DialectConversion.h" namespace mlir { -#define GEN_PASS_DEF_CONVERTSCFTOOPENMP +#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir @@ -212,22 +212,32 @@ return decl; } +/// Returns an LLVM pointer type with the given element type, or an opaque +/// pointer if 'useOpaquePointers' is true. +static LLVM::LLVMPointerType getPointerType(Type elementType, + bool useOpaquePointers) { + if (useOpaquePointers) + return LLVM::LLVMPointerType::get(elementType.getContext()); + return LLVM::LLVMPointerType::get(elementType); +} + /// Adds an atomic reduction combiner to the given OpenMP reduction declaration /// using llvm.atomicrmw of the given kind. static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder, LLVM::AtomicBinOp atomicKind, omp::ReductionDeclareOp decl, - scf::ReduceOp reduce) { + scf::ReduceOp reduce, + bool useOpaquePointers) { OpBuilder::InsertionGuard guard(builder); Type type = reduce.getOperand().getType(); - Type ptrType = LLVM::LLVMPointerType::get(type); + Type ptrType = getPointerType(type, useOpaquePointers); Location reduceOperandLoc = reduce.getOperand().getLoc(); builder.createBlock(&decl.getAtomicReductionRegion(), decl.getAtomicReductionRegion().end(), {ptrType, ptrType}, {reduceOperandLoc, reduceOperandLoc}); Block *atomicBlock = &decl.getAtomicReductionRegion().back(); builder.setInsertionPointToEnd(atomicBlock); - Value loaded = builder.create(reduce.getLoc(), + Value loaded = builder.create(reduce.getLoc(), decl.getType(), atomicBlock->getArgument(1)); builder.create(reduce.getLoc(), atomicKind, atomicBlock->getArgument(0), loaded, @@ -241,7 +251,8 @@ /// the neutral value, necessary for the OpenMP declaration. If the reduction /// cannot be recognized, returns null. static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder, - scf::ReduceOp reduce) { + scf::ReduceOp reduce, + bool useOpaquePointers) { Operation *container = SymbolTable::getNearestSymbolTable(reduce); SymbolTable symbolTable(container); @@ -262,29 +273,34 @@ if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getFloatAttr(type, 0.0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce); + return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce, + useOpaquePointers); } if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce); + return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce, + useOpaquePointers); } if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce, + useOpaquePointers); } if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce, builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce, + useOpaquePointers); } if (matchSimpleReduction(reduction)) { omp::ReductionDeclareOp decl = createDecl( builder, symbolTable, reduce, builder.getIntegerAttr( type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth()))); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce); + return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce, + useOpaquePointers); } // Match simple binary reductions that cannot be expressed with atomicrmw. @@ -316,7 +332,7 @@ builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin)); return addAtomicRMW(builder, isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, - decl, reduce); + decl, reduce, useOpaquePointers); } if (matchSelectReduction( reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, @@ -328,7 +344,7 @@ builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin)); return addAtomicRMW( builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, - decl, reduce); + decl, reduce, useOpaquePointers); } return nullptr; @@ -337,7 +353,12 @@ namespace { struct ParallelOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + + bool useOpaquePointers; + + ParallelOpLowering(MLIRContext *context, bool useOpaquePointers) + : OpRewritePattern(context), + useOpaquePointers(useOpaquePointers) {} LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, PatternRewriter &rewriter) const override { @@ -346,7 +367,8 @@ // declaration and use it instead of redeclaring. SmallVector reductionDeclSymbols; for (auto reduce : parallelOp.getOps()) { - omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce); + omp::ReductionDeclareOp decl = + declareReduction(rewriter, reduce, useOpaquePointers); if (!decl) return failure(); reductionDeclSymbols.push_back( @@ -366,7 +388,8 @@ "cannot create a reduction variable if the type is not an LLVM " "pointer element"); Value storage = rewriter.create( - loc, LLVM::LLVMPointerType::get(init.getType()), one, 0); + loc, getPointerType(init.getType(), useOpaquePointers), + init.getType(), one, 0); rewriter.create(loc, init, storage); reductionVariables.push_back(storage); } @@ -426,8 +449,9 @@ // Load loop results. SmallVector results; results.reserve(reductionVariables.size()); - for (Value variable : reductionVariables) { - Value res = rewriter.create(loc, variable); + for (auto [variable, type] : + llvm::zip(reductionVariables, parallelOp.getResultTypes())) { + Value res = rewriter.create(loc, type, variable); results.push_back(res); } rewriter.replaceOp(parallelOp, results); @@ -437,29 +461,29 @@ }; /// Applies the conversion patterns in the given function. -static LogicalResult applyPatterns(ModuleOp module) { +static LogicalResult applyPatterns(ModuleOp module, bool useOpaquePointers) { ConversionTarget target(*module.getContext()); target.addIllegalOp(); target.addLegalDialect(); RewritePatternSet patterns(module.getContext()); - patterns.add(module.getContext()); + patterns.add(module.getContext(), useOpaquePointers); FrozenRewritePatternSet frozen(std::move(patterns)); return applyPartialConversion(module, target, frozen); } /// A pass converting SCF operations to OpenMP operations. -struct SCFToOpenMPPass : public impl::ConvertSCFToOpenMPBase { +struct SCFToOpenMPPass + : public impl::ConvertSCFToOpenMPPassBase { + + using Base::Base; + /// Pass entry point. void runOnOperation() override { - if (failed(applyPatterns(getOperation()))) + if (failed(applyPatterns(getOperation(), useOpaquePointers))) signalPassFailure(); } }; } // namespace - -std::unique_ptr> mlir::createConvertSCFToOpenMPPass() { - return std::make_unique(); -} diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-scf-to-openmp -split-input-file %s | FileCheck %s +// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=1' -split-input-file %s | FileCheck %s // CHECK: omp.reduction.declare @[[$REDF:.*]] : f32 @@ -12,8 +12,8 @@ // CHECK: omp.yield(%[[RES]] : f32) // CHECK: atomic -// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr): -// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] +// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr): +// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> f32 // CHECK: llvm.atomicrmw fadd %[[ARG0]], %[[RHS]] monotonic // CHECK-LABEL: @reduction1 @@ -143,8 +143,8 @@ // CHECK: omp.yield(%[[RES]] : i64) // CHECK: atomic -// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr): -// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] +// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr): +// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> i64 // CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic // CHECK-LABEL: @reduction4 @@ -187,8 +187,8 @@ // CHECK: omp.yield } // CHECK: omp.terminator - // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] - // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] + // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32 + // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64 // CHECK: return %[[RES1]], %[[RES2]] return %res#0, %res#1 : f32, i64 } diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s +// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=1' %s | FileCheck %s // CHECK-LABEL: @parallel func.func @parallel(%arg0: index, %arg1: index, %arg2: index, diff --git a/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir b/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir @@ -0,0 +1,78 @@ +// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=0' -split-input-file %s | FileCheck %s + +// CHECK: omp.reduction.declare @[[$REDF1:.*]] : f32 + +// CHECK: init +// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4 +// CHECK: omp.yield(%[[INIT]] : f32) + +// CHECK: combiner +// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +// CHECK: %[[CMP:.*]] = arith.cmpf oge, %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG0]], %[[ARG1]] +// CHECK: omp.yield(%[[RES]] : f32) + +// CHECK-NOT: atomic + +// CHECK: omp.reduction.declare @[[$REDF2:.*]] : i64 + +// CHECK: init +// CHECK: %[[INIT:.*]] = llvm.mlir.constant +// CHECK: omp.yield(%[[INIT]] : i64) + +// CHECK: combiner +// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) +// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG0]] +// CHECK: omp.yield(%[[RES]] : i64) + +// CHECK: atomic +// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr): +// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] +// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic + +// CHECK-LABEL: @reduction4 +func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : index, %arg4 : index) -> (f32, i64) { + %step = arith.constant 1 : index + // CHECK: %[[ZERO:.*]] = arith.constant 0.0 + %zero = arith.constant 0.0 : f32 + // CHECK: %[[IONE:.*]] = arith.constant 1 + %ione = arith.constant 1 : i64 + // CHECK: %[[BUF1:.*]] = llvm.alloca %{{.*}} x f32 + // CHECK: llvm.store %[[ZERO]], %[[BUF1]] + // CHECK: %[[BUF2:.*]] = llvm.alloca %{{.*}} x i64 + // CHECK: llvm.store %[[IONE]], %[[BUF2]] + + // CHECK: omp.parallel + // CHECK: omp.wsloop + // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]] + // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]] + // CHECK: memref.alloca_scope + %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) + step (%arg4, %step) init (%zero, %ione) -> (f32, i64) { + %one = arith.constant 1.0 : f32 + // CHECK: omp.reduction %{{.*}}, %[[BUF1]] + scf.reduce(%one) : f32 { + ^bb0(%lhs : f32, %rhs: f32): + %cmp = arith.cmpf oge, %lhs, %rhs : f32 + %res = arith.select %cmp, %lhs, %rhs : f32 + scf.reduce.return %res : f32 + } + // CHECK: arith.fptosi + %1 = arith.fptosi %one : f32 to i64 + // CHECK: omp.reduction %{{.*}}, %[[BUF2]] + scf.reduce(%1) : i64 { + ^bb1(%lhs: i64, %rhs: i64): + %cmp = arith.cmpi slt, %lhs, %rhs : i64 + %res = arith.select %cmp, %rhs, %lhs : i64 + scf.reduce.return %res : i64 + } + // CHECK: omp.yield + } + // CHECK: omp.terminator + // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr + // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr + // CHECK: return %[[RES1]], %[[RES2]] + return %res#0, %res#1 : f32, i64 +}