diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
--- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -34,20 +34,26 @@
 
   LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    // Early exit if there is no condition.
-    if (!op.getIfCond())
-      return success();
-
-    // Condition is not a constant.
-    if (!op.getIfCond().template getDefiningOp<arith::ConstantOp>()) {
-      auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
-                                             op.getIfCond(), false);
-      rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
-      auto thenBodyBuilder = ifOp.getThenBodyBuilder();
-      thenBodyBuilder.setListener(rewriter.getListener());
-      thenBodyBuilder.clone(*op.getOperation());
-      rewriter.eraseOp(op);
-    }
+    // Only ops with a non-constant `if` condition are expanded. Ops without a
+    // condition need no rewrite, and an `arith.constant` condition is left to
+    // the acc canonicalization patterns registered next to this one (they drop
+    // a `true` condition and erase the op on `false`). Failing the match
+    // instead of asserting keeps the pattern safe under any driver and keeps
+    // it from competing with those canonicalizations for the same op.
+    Value ifCond = op.getIfCond();
+    if (!ifCond || ifCond.getDefiningOp<arith::ConstantOp>())
+      return failure();
+
+    // Rewrite into `scf.if %cond { op }`: strip the condition from the op,
+    // clone the now-unconditional op into the `then` region, then erase the
+    // original.
+    auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(), ifCond,
+                                           /*withElseRegion=*/false);
+    rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
+    auto thenBodyBuilder = ifOp.getThenBodyBuilder();
+    thenBodyBuilder.setListener(rewriter.getListener());
+    thenBodyBuilder.clone(*op.getOperation());
+    rewriter.eraseOp(op);
 
     return success();
   }
@@ -58,6 +64,12 @@
   patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
   patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
   patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
+  // Ops whose `if` condition is a constant are resolved by the ops' own
+  // canonicalization patterns instead of being wrapped in `scf.if`.
+  acc::EnterDataOp::getCanonicalizationPatterns(patterns,
+                                                patterns.getContext());
+  acc::ExitDataOp::getCanonicalizationPatterns(patterns, patterns.getContext());
+  acc::UpdateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
 }
 
 namespace {
diff --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
--- a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -33,3 +33,72 @@
 // CHECK: scf.if [[IFCOND]] {
 // CHECK-NEXT: acc.update host(%{{.*}} : memref<10xf32>)
 // CHECK-NEXT: }
+
+// -----
+
+func.func @update_true(%arg0: memref<10xf32>) {
+  %true = arith.constant true
+  acc.update if(%true) host(%arg0 : memref<10xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @update_true
+// CHECK-NOT: if
+// CHECK: acc.update host
+
+// -----
+
+func.func @update_false(%arg0: memref<10xf32>) {
+  %false = arith.constant false
+  acc.update if(%false) host(%arg0 : memref<10xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @update_false
+// CHECK-NOT: acc.update
+
+// -----
+
+func.func @enter_data_true(%d1 : memref<10xf32>) {
+  %true = arith.constant true
+  acc.enter_data if(%true) create(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @enter_data_true
+// CHECK-NOT: if
+// CHECK: acc.enter_data create
+
+// -----
+
+func.func @enter_data_false(%d1 : memref<10xf32>) {
+  %false = arith.constant false
+  acc.enter_data if(%false) create(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @enter_data_false
+// CHECK-NOT: acc.enter_data
+
+// -----
+
+func.func @exit_data_true(%d1 : memref<10xf32>) {
+  %true = arith.constant true
+  acc.exit_data if(%true) delete(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @exit_data_true
+// CHECK-NOT: if
+// CHECK: acc.exit_data delete
+
+// -----
+
+func.func @exit_data_false(%d1 : memref<10xf32>) {
+  %false = arith.constant false
+  acc.exit_data if(%false) delete(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @exit_data_false
+// CHECK-NOT: acc.exit_data