diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -16,18 +16,23 @@ using namespace mlir; namespace { -struct ParallelOpConversion : public ConvertToLLVMPattern { - explicit ParallelOpConversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context, +/// A pattern that converts the region arguments in a single-region OpenMP +/// operation to the LLVM dialect. The body of the region is not modified and is +/// expected to either be processed by the conversion infrastructure or already +/// contain ops compatible with LLVM dialect types. +template +struct RegionOpConversion : public ConvertToLLVMPattern { + explicit RegionOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(OpType::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto curOp = cast(op); - auto newOp = rewriter.create(curOp.getLoc(), TypeRange(), - operands, curOp.getAttrs()); + auto curOp = cast(op); + auto newOp = rewriter.create(curOp.getLoc(), TypeRange(), operands, + curOp.getAttrs()); rewriter.inlineRegionBefore(curOp.region(), newOp.region(), newOp.region().end()); if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter))) @@ -42,7 +47,8 @@ void mlir::populateOpenMPToLLVMConversionPatterns( MLIRContext *context, LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert(context, converter); + patterns.insert, + RegionOpConversion>(context, converter); } namespace { @@ -63,8 +69,8 @@ populateOpenMPToLLVMConversionPatterns(context, converter, patterns); LLVMConversionTarget target(getContext()); - target.addDynamicallyLegalOp( - [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); }); + target.addDynamicallyLegalOp( + [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); }); target.addLegalOp(); if (failed(applyPartialConversion(module, target, std::move(patterns)))) diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -28,3 +28,22 @@ } return } + +// CHECK-LABEL: @wsloop +// CHECK: (%[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64, %[[ARG3:.*]]: !llvm.i64, %[[ARG4:.*]]: !llvm.i64, %[[ARG5:.*]]: !llvm.i64) +func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { + // CHECK: omp.parallel + omp.parallel { + // CHECK: omp.wsloop + // CHECK: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) + "omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( { + // CHECK: ^{{.*}}(%[[ARG6:.*]]: !llvm.i64, %[[ARG7:.*]]: !llvm.i64): + ^bb0(%arg6: index, %arg7: index): // no predecessors + // CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (!llvm.i64, !llvm.i64) -> () + "test.payload"(%arg6, %arg7) : (index, index) -> () + omp.yield + }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (index, index, index, index, index, index) -> () + omp.terminator + } + return +}