diff --git a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
--- a/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
+++ b/mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
@@ -36,10 +37,12 @@
                                 PatternRewriter &rewriter) const override {
     // Early exit if there is no condition.
     if (!op.getIfCond())
-      return success();
+      return failure();
 
-    // Condition is not a constant.
-    if (!op.getIfCond().template getDefiningOp<arith::ConstantOp>()) {
+    // Non-constant condition: move the operation, minus its `if` operand,
+    // into the then-region of a new scf.if on that condition.
+    IntegerAttr constAttr;
+    if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
       auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
                                              op.getIfCond(), false);
       rewriter.updateRootInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
@@ -47,6 +50,14 @@
       thenBodyBuilder.setListener(rewriter.getListener());
       thenBodyBuilder.clone(*op.getOperation());
       rewriter.eraseOp(op);
+    } else {
+      // Constant condition: when true, just drop the `if` operand; when
+      // false, erase the operation entirely.
+      if (constAttr.getInt())
+        rewriter.updateRootInPlace(op,
+                                   [&]() { op.getIfCondMutable().erase(0); });
+      else
+        rewriter.eraseOp(op);
     }
 
     return success();
diff --git a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
--- a/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
+++ b/mlir/test/Conversion/OpenACCToSCF/convert-openacc-to-scf.mlir
@@ -33,3 +33,74 @@
 // CHECK:         scf.if [[IFCOND]] {
 // CHECK-NEXT:      acc.update host(%{{.*}} : memref<10xf32>)
 // CHECK-NEXT:    }
+
+// -----
+
+func.func @update_true(%arg0: memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %true = arith.constant true
+  acc.update if(%true) host(%arg0 : memref<10xf32, #spirv.storage_class<StorageBuffer>>)
+  return
+}
+
+// CHECK-LABEL: func.func @update_true
+// CHECK-NOT:if
+// CHECK:acc.update host
+
+// -----
+
+func.func @update_false(%arg0: memref<10xf32, #spirv.storage_class<StorageBuffer>>) {
+  %false = arith.constant false
+  acc.update if(%false) host(%arg0 : memref<10xf32, #spirv.storage_class<StorageBuffer>>)
+  return
+}
+
+// CHECK-LABEL: func.func @update_false
+// CHECK-NOT:acc.update
+
+// -----
+
+func.func @enter_data_true(%d1 : memref<10xf32>) {
+  %true = arith.constant true
+  acc.enter_data if(%true) create(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @enter_data_true
+// CHECK-NOT:if
+// CHECK:acc.enter_data create
+
+// -----
+
+func.func @enter_data_false(%d1 : memref<10xf32>) {
+  %false = arith.constant false
+  acc.enter_data if(%false) create(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @enter_data_false
+// CHECK-NOT:acc.enter_data
+
+// -----
+
+func.func @exit_data_true(%d1 : memref<10xf32>) {
+  %true = arith.constant true
+  acc.exit_data if(%true) delete(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @exit_data_true
+// CHECK-NOT:if
+// CHECK:acc.exit_data delete
+
+// -----
+
+func.func @exit_data_false(%d1 : memref<10xf32>) {
+  %false = arith.constant false
+  acc.exit_data if(%false) delete(%d1 : memref<10xf32>) attributes {async}
+  return
+}
+
+// CHECK-LABEL: func.func @exit_data_false
+// CHECK-NOT:acc.exit_data
+
+// -----