diff --git a/mlir/benchmark/python/common.py b/mlir/benchmark/python/common.py --- a/mlir/benchmark/python/common.py +++ b/mlir/benchmark/python/common.py @@ -26,7 +26,7 @@ f"sparse-tensor-conversion," f"builtin.func" f"(linalg-bufferize,convert-linalg-to-loops,convert-vector-to-scf)," - f"convert-scf-to-std," + f"convert-scf-to-cf," f"func-bufferize," f"arith-bufferize," f"builtin.func(tensor-bufferize,finalizing-bufferize)," diff --git a/mlir/docs/BufferDeallocationInternals.md b/mlir/docs/BufferDeallocationInternals.md --- a/mlir/docs/BufferDeallocationInternals.md +++ b/mlir/docs/BufferDeallocationInternals.md @@ -41,12 +41,12 @@ ```mlir func @condBranch(%arg0: i1, %arg1: memref<2xf32>) { %0 = memref.alloc() : memref<2xf32> - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: - br ^bb3() + cf.br ^bb3() ^bb2: partial_write(%0, %0) - br ^bb3() + cf.br ^bb3() ^bb3(): test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return @@ -74,13 +74,13 @@ func @mixedAllocation(%arg0: i1) { %0 = memref.alloca() : memref<2xf32> // aliases: %2 %1 = memref.alloc() : memref<2xf32> // aliases: %2 - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: use(%0) - br ^bb3(%0 : memref<2xf32>) + cf.br ^bb3(%0 : memref<2xf32>) ^bb2: use(%1) - br ^bb3(%1 : memref<2xf32>) + cf.br ^bb3(%1 : memref<2xf32>) ^bb3(%2: memref<2xf32>): ... 
} @@ -129,13 +129,13 @@ ```mlir func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: - br ^bb3(%arg1 : memref<2xf32>) + cf.br ^bb3(%arg1 : memref<2xf32>) ^bb2: %0 = memref.alloc() : memref<2xf32> // aliases: %1 use(%0) - br ^bb3(%0 : memref<2xf32>) + cf.br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1 test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return @@ -150,12 +150,12 @@ ```mlir func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { %0 = memref.alloc() : memref<2xf32> // moved to bb0 - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: - br ^bb3(%arg1 : memref<2xf32>) + cf.br ^bb3(%arg1 : memref<2xf32>) ^bb2: use(%0) - br ^bb3(%0 : memref<2xf32>) + cf.br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return @@ -175,14 +175,14 @@ %arg1: memref, %arg2: memref, %arg3: index) { - cond_br %arg0, ^bb1, ^bb2(%arg3: index) + cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index) ^bb1: - br ^bb3(%arg1 : memref) + cf.br ^bb3(%arg1 : memref) ^bb2(%0: index): %1 = memref.alloc(%0) : memref // cannot be moved upwards to the data // dependency to %0 use(%1) - br ^bb3(%1 : memref) + cf.br ^bb3(%1 : memref) ^bb3(%2: memref): test.copy(%2, %arg2) : (memref, memref) -> () return @@ -201,14 +201,14 @@ ```mlir func @branch(%arg0: i1) { %0 = memref.alloc() : memref<2xf32> // aliases: %2 - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: %1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes // aliases: %2 - br ^bb3(%1 : memref<2xf32>) + cf.br ^bb3(%1 : memref<2xf32>) ^bb2: use(%0) - br ^bb3(%0 : memref<2xf32>) + cf.br ^bb3(%0 : memref<2xf32>) ^bb3(%2: memref<2xf32>): … return @@ -233,16 +233,16 @@ ```mlir func @branch(%arg0: i1) { %0 = memref.alloc() : memref<2xf32> - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: %1 = 
memref.alloc() : memref<2xf32> %3 = bufferization.clone %1 : (memref<2xf32>) -> (memref<2xf32>) memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here - br ^bb3(%3 : memref<2xf32>) + cf.br ^bb3(%3 : memref<2xf32>) ^bb2: use(%0) %4 = bufferization.clone %0 : (memref<2xf32>) -> (memref<2xf32>) - br ^bb3(%4 : memref<2xf32>) + cf.br ^bb3(%4 : memref<2xf32>) ^bb3(%2: memref<2xf32>): … memref.dealloc %2 : memref<2xf32> // free temp buffer %2 @@ -273,23 +273,23 @@ %arg1: memref, // aliases: %3, %4 %arg2: memref, %arg3: index) { - cond_br %arg0, ^bb1, ^bb2(%arg3: index) + cf.cond_br %arg0, ^bb1, ^bb2(%arg3: index) ^bb1: - br ^bb6(%arg1 : memref) + cf.br ^bb6(%arg1 : memref) ^bb2(%0: index): %1 = memref.alloc(%0) : memref // cannot be moved upwards due to the data // dependency to %0 // aliases: %2, %3, %4 use(%1) - cond_br %arg0, ^bb3, ^bb4 + cf.cond_br %arg0, ^bb3, ^bb4 ^bb3: - br ^bb5(%1 : memref) + cf.br ^bb5(%1 : memref) ^bb4: - br ^bb5(%1 : memref) + cf.br ^bb5(%1 : memref) ^bb5(%2: memref): // non-crit. alias of %1, since %1 dominates %2 - br ^bb6(%2 : memref) + cf.br ^bb6(%2 : memref) ^bb6(%3: memref): // crit. alias of %arg1 and %2 (in other words %1) - br ^bb7(%3 : memref) + cf.br ^bb7(%3 : memref) ^bb7(%4: memref): // non-crit. 
alias of %3, since %3 dominates %4 test.copy(%4, %arg2) : (memref, memref) -> () return @@ -306,25 +306,25 @@ %arg1: memref, %arg2: memref, %arg3: index) { - cond_br %arg0, ^bb1, ^bb2(%arg3 : index) + cf.cond_br %arg0, ^bb1, ^bb2(%arg3 : index) ^bb1: // temp buffer required due to alias %3 %5 = bufferization.clone %arg1 : (memref) -> (memref) - br ^bb6(%5 : memref) + cf.br ^bb6(%5 : memref) ^bb2(%0: index): %1 = memref.alloc(%0) : memref use(%1) - cond_br %arg0, ^bb3, ^bb4 + cf.cond_br %arg0, ^bb3, ^bb4 ^bb3: - br ^bb5(%1 : memref) + cf.br ^bb5(%1 : memref) ^bb4: - br ^bb5(%1 : memref) + cf.br ^bb5(%1 : memref) ^bb5(%2: memref): %6 = bufferization.clone %1 : (memref) -> (memref) memref.dealloc %1 : memref - br ^bb6(%6 : memref) + cf.br ^bb6(%6 : memref) ^bb6(%3: memref): - br ^bb7(%3 : memref) + cf.br ^bb7(%3 : memref) ^bb7(%4: memref): test.copy(%4, %arg2) : (memref, memref) -> () memref.dealloc %3 : memref // free %3, since %4 is a non-crit. alias of %3 diff --git a/mlir/docs/Diagnostics.md b/mlir/docs/Diagnostics.md --- a/mlir/docs/Diagnostics.md +++ b/mlir/docs/Diagnostics.md @@ -295,7 +295,7 @@ ```mlir // Expect an error on the same line. func @bad_branch() { - br ^missing // expected-error {{reference to an undefined block}} + cf.br ^missing // expected-error {{reference to an undefined block}} } // Expect an error on an adjacent line. diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -114,8 +114,8 @@ /// All operations within the GPU dialect are illegal. addIllegalDialect(); - /// Mark `std.br` and `std.cond_br` as illegal. - addIllegalOp(); + /// Mark `cf.br` and `cf.cond_br` as illegal. 
+ addIllegalOp(); } /// Implement the default legalization handler to handle operations marked as diff --git a/mlir/docs/Dialects/emitc.md b/mlir/docs/Dialects/emitc.md --- a/mlir/docs/Dialects/emitc.md +++ b/mlir/docs/Dialects/emitc.md @@ -23,10 +23,11 @@ Besides operations part of the EmitC dialect, the Cpp targets supports translating the following operations: +* 'cf' Dialect + * `cf.br` + * `cf.cond_br` * 'std' Dialect - * `std.br` * `std.call` - * `std.cond_br` * `std.constant` * `std.return` * 'scf' Dialect diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -391,21 +391,21 @@ ```mlir func @simple(i64, i1) -> i64 { ^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a - cond_br %cond, ^bb1, ^bb2 + cf.cond_br %cond, ^bb1, ^bb2 ^bb1: - br ^bb3(%a: i64) // Branch passes %a as the argument + cf.br ^bb3(%a: i64) // Branch passes %a as the argument ^bb2: %b = arith.addi %a, %a : i64 - br ^bb3(%b: i64) // Branch passes %b as the argument + cf.br ^bb3(%b: i64) // Branch passes %b as the argument // ^bb3 receives an argument, named %c, from predecessors // and passes it on to bb4 along with %a. %a is referenced // directly from its defining operation and is not passed through // an argument of ^bb3. 
^bb3(%c: i64): - br ^bb4(%c, %a : i64, i64) + cf.br ^bb4(%c, %a : i64, i64) ^bb4(%d : i64, %e : i64): %0 = arith.addi %d, %e : i64 @@ -525,12 +525,12 @@ ```mlir func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region ^bb0(%a: i64, %cond: i1): // Code dominated by ^bb0 may refer to %a - cond_br %cond, ^bb1, ^bb2 + cf.cond_br %cond, ^bb1, ^bb2 ^bb1: // This def for %value does not dominate ^bb2 %value = "op.convert"(%a) : (i64) -> i64 - br ^bb3(%a: i64) // Branch passes %a as the argument + cf.br ^bb3(%a: i64) // Branch passes %a as the argument ^bb2: accelerator.launch() { // An SSACFG region diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -356,24 +356,24 @@ ``` //===-------------------------------------------===// -Processing operation : 'std.cond_br'(0x60f000001120) { - "std.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () +Processing operation : 'cf.cond_br'(0x60f000001120) { + "cf.cond_br"(%arg0)[^bb2, ^bb2] {operand_segment_sizes = dense<[1, 0, 0]> : vector<3xi32>} : (i1) -> () - * Pattern SimplifyConstCondBranchPred : 'std.cond_br -> ()' { + * Pattern SimplifyConstCondBranchPred : 'cf.cond_br -> ()' { } -> failure : pattern failed to match - * Pattern SimplifyCondBranchIdenticalSuccessors : 'std.cond_br -> ()' { - ** Insert : 'std.br'(0x60b000003690) - ** Replace : 'std.cond_br'(0x60f000001120) + * Pattern SimplifyCondBranchIdenticalSuccessors : 'cf.cond_br -> ()' { + ** Insert : 'cf.br'(0x60b000003690) + ** Replace : 'cf.cond_br'(0x60f000001120) } -> success : pattern applied successfully } -> success : pattern matched //===-------------------------------------------===// ``` -This output is describing the processing of a `std.cond_br` operation. We first +This output is describing the processing of a `cf.cond_br` operation. We first try to apply the `SimplifyConstCondBranchPred`, which fails. 
From there, another pattern (`SimplifyCondBranchIdenticalSuccessors`) is applied that matches the -`std.cond_br` and replaces it with a `std.br`. +`cf.cond_br` and replaces it with a `cf.br`. ## Debugging diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md --- a/mlir/docs/Rationale/Rationale.md +++ b/mlir/docs/Rationale/Rationale.md @@ -560,24 +560,24 @@ func @search_body(%A: memref, %S: memref, %key: i32, %i : i32) { %nj = memref.dim %A, 1 : memref - br ^bb1(0) + cf.br ^bb1(0) ^bb1(%j: i32) %p1 = arith.cmpi "lt", %j, %nj : i32 - cond_br %p1, ^bb2, ^bb5 + cf.cond_br %p1, ^bb2, ^bb5 ^bb2: %v = affine.load %A[%i, %j] : memref %p2 = arith.cmpi "eq", %v, %key : i32 - cond_br %p2, ^bb3(%j), ^bb4 + cf.cond_br %p2, ^bb3(%j), ^bb4 ^bb3(%j: i32) affine.store %j, %S[%i] : memref - br ^bb5 + cf.br ^bb5 ^bb4: %jinc = arith.addi %j, 1 : i32 - br ^bb1(%jinc) + cf.br ^bb1(%jinc) ^bb5: return diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -94,10 +94,11 @@ ```c++ mlir::RewritePatternSet patterns(&getContext()); mlir::populateAffineToStdConversionPatterns(patterns, &getContext()); - mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); + mlir::populateSCFToControlFlowConversionPatterns(patterns); mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); // The only remaining operation, to lower from the `toy` dialect, is the // PrintOp. @@ -207,7 +208,7 @@ %109 = memref.load double, double* %108 %110 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @frmt_spec, i64 0, i64 0), double %109) %111 = add i64 %100, 1 - br label %99 + br label %99 ... 
diff --git a/mlir/docs/includes/img/branch_example_post_move.svg b/mlir/docs/includes/img/branch_example_post_move.svg --- a/mlir/docs/includes/img/branch_example_post_move.svg +++ b/mlir/docs/includes/img/branch_example_post_move.svg @@ -361,7 +361,7 @@ br bb3(%0) + style="font-size:5.64444px">cf.br bb3(%0) br bb3(%0) + style="font-size:5.64444px">cf.br bb3(%0) + +namespace mlir { +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +namespace cf { +/// Collect the patterns to convert from the ControlFlow dialect to LLVM. The +/// conversion patterns capture the LLVMTypeConverter by reference meaning the +/// references have to remain alive during the entire pattern lifetime. +void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect. +std::unique_ptr createConvertControlFlowToLLVMPass(); +} // namespace cf +} // namespace mlir + +#endif // MLIR_CONVERSION_CONTROLFLOWTOLLVM_CONTROLFLOWTOLLVM_H diff --git a/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h @@ -0,0 +1,28 @@ +//===- ControlFlowToSPIRV.h - CF to SPIR-V Patterns --------*- C++ ------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns to convert ControlFlow dialect to SPIR-V dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H +#define MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class SPIRVTypeConverter; +namespace cf { +/// Appends to a pattern list additional patterns for translating ControlFlow +/// ops to SPIR-V ops. +void populateControlFlowToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns); +} // namespace cf +} // namespace mlir + +#endif // MLIR_CONVERSION_CONTROLFLOWTOSPIRV_CONTROLFLOWTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -17,6 +17,8 @@ #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" @@ -35,10 +37,10 @@ #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRVPass.h" -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include 
"mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include
"mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -181,6 +181,28 @@ let dependentDialects = ["math::MathDialect"]; } +//===----------------------------------------------------------------------===// +// ControlFlowToLLVM +//===----------------------------------------------------------------------===// + +def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> { + let summary = "Convert ControlFlow operations to the LLVM dialect"; + let description = [{ + Convert ControlFlow operations into LLVM IR dialect operations. + + If other operations are present and their results are required by the LLVM + IR dialect operations, the pass will fail. Any LLVM IR operations or types + already present in the IR will be kept as is. + }]; + let constructor = "mlir::cf::createConvertControlFlowToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; + let options = [ + Option<"indexBitwidth", "index-bitwidth", "unsigned", + /*default=kDeriveIndexBitwidthFromDataLayout*/"0", + "Bitwidth of the index type, 0 to use size of machine word">, + ]; +} + //===----------------------------------------------------------------------===// // GPUCommon //===----------------------------------------------------------------------===// @@ -460,6 +482,17 @@ let constructor = "mlir::createReconcileUnrealizedCastsPass()"; } +//===----------------------------------------------------------------------===// +// SCFToControlFlow +//===----------------------------------------------------------------------===// + +def SCFToControlFlow : Pass<"convert-scf-to-cf"> { + let summary = "Convert SCF dialect to ControlFlow dialect, replacing structured" + " control flow with a CFG"; + let constructor = "mlir::createConvertSCFToCFPass()"; + let dependentDialects = ["cf::ControlFlowDialect"]; +} + 
//===----------------------------------------------------------------------===// // SCFToOpenMP //===----------------------------------------------------------------------===// @@ -488,17 +521,6 @@ let dependentDialects = ["spirv::SPIRVDialect"]; } -//===----------------------------------------------------------------------===// -// SCFToStandard -//===----------------------------------------------------------------------===// - -def SCFToStandard : Pass<"convert-scf-to-std"> { - let summary = "Convert SCF dialect to Standard dialect, replacing structured" - " control flow with a CFG"; - let constructor = "mlir::createLowerToCFGPass()"; - let dependentDialects = ["StandardOpsDialect"]; -} - //===----------------------------------------------------------------------===// // SCFToGPU //===----------------------------------------------------------------------===// @@ -547,7 +569,7 @@ computation lowering. }]; let constructor = "mlir::createConvertShapeConstraintsPass()"; - let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"]; + let dependentDialects = ["cf::ControlFlowDialect", "scf::SCFDialect"]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h @@ -0,0 +1,31 @@ +//===- SCFToControlFlow.h - Pass entrypoint ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_ +#define MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_ + +#include +#include + +namespace mlir { +struct LogicalResult; +class Pass; + +class RewritePatternSet; + +/// Collect a set of patterns to convert SCF operations to CFG branch-based +/// operations within the ControlFlow dialect. +void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns); + +/// Creates a pass to convert SCF operations to CFG branch-based operations in +/// the ControlFlow dialect. +std::unique_ptr createConvertSCFToCFPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_SCFTOCONTROLFLOW_SCFTOCONTROLFLOW_H_ diff --git a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h b/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h deleted file mode 100644 --- a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h +++ /dev/null @@ -1,31 +0,0 @@ -//===- ConvertSCFToStandard.h - Pass entrypoint -----------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_ -#define MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_ - -#include -#include - -namespace mlir { -struct LogicalResult; -class Pass; - -class RewritePatternSet; - -/// Collect a set of patterns to lower from scf.for, scf.if, and -/// loop.terminator to CFG operations within the Standard dialect, in particular -/// convert structured control flow into CFG branch-based control flow.
-void populateLoopToStdConversionPatterns(RewritePatternSet &patterns); - -/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG. -std::unique_ptr createLowerToCFGPass(); - -} // namespace mlir - -#endif // MLIR_CONVERSION_SCFTOSTANDARD_SCFTOSTANDARD_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -26,9 +26,9 @@ #map0 = affine_map<(d0) -> (d0)> module { func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: - br ^bb3(%arg1 : memref<2xf32>) + cf.br ^bb3(%arg1 : memref<2xf32>) ^bb2: %0 = memref.alloc() : memref<2xf32> linalg.generic { @@ -40,7 +40,7 @@ %tmp1 = exp %gen1_arg0 : f32 linalg.yield %tmp1 : f32 }: memref<2xf32>, memref<2xf32> - br ^bb3(%0 : memref<2xf32>) + cf.br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return @@ -55,11 +55,11 @@ #map0 = affine_map<(d0) -> (d0)> module { func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - cond_br %arg0, ^bb1, ^bb2 + cf.cond_br %arg0, ^bb1, ^bb2 ^bb1: // pred: ^bb0 %0 = memref.alloc() : memref<2xf32> memref.copy(%arg1, %0) : memref<2xf32>, memref<2xf32> - br ^bb3(%0 : memref<2xf32>) + cf.br ^bb3(%0 : memref<2xf32>) ^bb2: // pred: ^bb0 %1 = memref.alloc() : memref<2xf32> linalg.generic { @@ -74,7 +74,7 @@ %2 = memref.alloc() : memref<2xf32> memref.copy(%1, %2) : memref<2xf32>, memref<2xf32> dealloc %1 : memref<2xf32> - br ^bb3(%2 : memref<2xf32>) + cf.br ^bb3(%2 : memref<2xf32>) ^bb3(%3: memref<2xf32>): // 2 preds: ^bb1, ^bb2 memref.copy(%3, %arg2) : memref<2xf32>, memref<2xf32> dealloc %3 : memref<2xf32> diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt 
--- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(AMX) add_subdirectory(Bufferization) add_subdirectory(Complex) +add_subdirectory(ControlFlow) add_subdirectory(DLTI) add_subdirectory(EmitC) add_subdirectory(GPU) diff --git a/mlir/include/mlir/Dialect/ControlFlow/CMakeLists.txt b/mlir/include/mlir/Dialect/ControlFlow/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ControlFlow/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(ControlFlowOps cf ControlFlowOps) +add_mlir_doc(ControlFlowOps ControlFlowDialect Dialects/ -gen-dialect-doc) diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h @@ -0,0 +1,21 @@ +//===- ControlFlow.h - ControlFlow Dialect ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ControlFlow dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H +#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/IR/Dialect.h" + +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc" + +#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h @@ -0,0 +1,30 @@ +//===- ControlFlowOps.h - ControlFlow Operations ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operations of the ControlFlow dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H +#define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H + +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace mlir { +class PatternRewriter; +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h.inc" + +#endif // MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td copy from mlir/include/mlir/Dialect/StandardOps/IR/Ops.td copy to mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td @@ -1,4 +1,4 @@ -//===- Ops.td - Standard operation definitions -------------*- tablegen -*-===// +//===- ControlFlowOps.td - ControlFlow operations ----------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// Defines some MLIR standard operations. +// This file contains definitions for the operations within the ControlFlow +// dialect. 
// //===----------------------------------------------------------------------===// @@ -14,39 +15,29 @@ #define STANDARD_OPS include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/SymbolInterfaces.td" -include "mlir/Interfaces/CallInterfaces.td" -include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/VectorInterfaces.td" -def StandardOps_Dialect : Dialect { - let name = "std"; - let cppNamespace = "::mlir"; +def ControlFlow_Dialect : Dialect { + let name = "cf"; + let cppNamespace = "::mlir::cf"; let dependentDialects = ["arith::ArithmeticDialect"]; - let hasConstantMaterializer = 1; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let description = [{ + This dialect contains low-level, i.e. non-region based, control flow + constructs. These constructs generally represent control flow directly + on SSA blocks of a control flow graph. + }]; } -// Base class for Standard dialect ops. -class Std_Op traits = []> : - Op { - // For every standard op, there needs to be a: - // * void print(OpAsmPrinter &p, ${C++ class of Op} op) - // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, - // OperationState &result) - // functions. - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parse$cppClass(parser, result); }]; -} +class CF_Op traits = []> : + Op; //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// -def AssertOp : Std_Op<"assert"> { +def AssertOp : CF_Op<"assert"> { let summary = "Assert operation with message attribute"; let description = [{ Assert operation with single boolean operand and an error message attribute. 
@@ -71,22 +62,23 @@ // BranchOp //===----------------------------------------------------------------------===// -def BranchOp : Std_Op<"br", - [DeclareOpInterfaceMethods, - NoSideEffect, Terminator]> { +def BranchOp : CF_Op<"br", [ + DeclareOpInterfaceMethods, + NoSideEffect, Terminator + ]> { let summary = "branch operation"; let description = [{ - The `br` operation represents a branch operation in a function. - The operation takes variable number of operands and produces no results. - The operand number and types for each successor must match the arguments of - the block successor. + The `cf.br` operation represents a direct branch operation to a given + block. The operands of this operation are forwarded to the successor block, + and the number and type of the operands must match the arguments of the + target block. Example: ```mlir ^bb2: %2 = call @someFn() - br ^bb3(%2 : tensor<*xf32>) + cf.br ^bb3(%2 : tensor<*xf32>) ^bb3(%3: tensor<*xf32>): ``` }]; @@ -96,7 +88,7 @@ let builders = [ OpBuilder<(ins "Block *":$dest, - CArg<"ValueRange", "{}">:$destOperands), [{ + CArg<"ValueRange", "{}">:$destOperands), [{ $_state.addSuccessors(dest); $_state.addOperands(destOperands); }]>]; @@ -114,143 +106,11 @@ }]; } -//===----------------------------------------------------------------------===// -// CallOp -//===----------------------------------------------------------------------===// - -def CallOp : Std_Op<"call", - [CallOpInterface, MemRefsNormalizable, - DeclareOpInterfaceMethods]> { - let summary = "call operation"; - let description = [{ - The `call` operation represents a direct call to a function that is within - the same symbol scope as the call. The operands and result types of the - call must match the specified function type. The callee is encoded as a - symbol reference attribute named "callee". 
- - Example: - - ```mlir - %2 = call @my_add(%0, %1) : (f32, f32) -> f32 - ``` - }]; - - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); - let results = (outs Variadic); - - let builders = [ - OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ - $_state.addOperands(operands); - $_state.addAttribute("callee", SymbolRefAttr::get(callee)); - $_state.addTypes(callee.getType().getResults()); - }]>, - OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, - CArg<"ValueRange", "{}">:$operands), [{ - $_state.addOperands(operands); - $_state.addAttribute("callee", callee); - $_state.addTypes(results); - }]>, - OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, - CArg<"ValueRange", "{}">:$operands), [{ - build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); - }]>, - OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, - CArg<"ValueRange", "{}">:$operands), [{ - build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), - results, operands); - }]>]; - - let extraClassDeclaration = [{ - FunctionType getCalleeType(); - - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - - /// Return the callee of this operation. 
- CallInterfaceCallable getCallableForCallee() { - return (*this)->getAttrOfType("callee"); - } - }]; - - let assemblyFormat = [{ - $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) - }]; -} - -//===----------------------------------------------------------------------===// -// CallIndirectOp -//===----------------------------------------------------------------------===// - -def CallIndirectOp : Std_Op<"call_indirect", [ - CallOpInterface, - TypesMatchWith<"callee input types match argument types", - "callee", "callee_operands", - "$_self.cast().getInputs()">, - TypesMatchWith<"callee result types match result types", - "callee", "results", - "$_self.cast().getResults()"> - ]> { - let summary = "indirect call operation"; - let description = [{ - The `call_indirect` operation represents an indirect call to a value of - function type. Functions are first class types in MLIR, and may be passed as - arguments and merged together with block arguments. The operands and result - types of the call must match the specified function type. - - Function values can be created with the - [`constant` operation](#stdconstant-constantop). - - Example: - - ```mlir - %31 = call_indirect %15(%0, %1) - : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> - ``` - }]; - - let arguments = (ins FunctionType:$callee, - Variadic:$callee_operands); - let results = (outs Variadic:$results); - - let builders = [ - OpBuilder<(ins "Value":$callee, CArg<"ValueRange", "{}">:$operands), [{ - $_state.operands.push_back(callee); - $_state.addOperands(operands); - $_state.addTypes(callee.getType().cast().getResults()); - }]>]; - - let extraClassDeclaration = [{ - // TODO: Remove once migrated callers. - ValueRange operands() { return getCalleeOperands(); } - - /// Get the argument operands to the called function. 
- operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return ++operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - - /// Return the callee of this operation. - CallInterfaceCallable getCallableForCallee() { return getCallee(); } - }]; - - let hasCanonicalizeMethod = 1; - - let assemblyFormat = - "$callee `(` $callee_operands `)` attr-dict `:` type($callee)"; -} - //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// -def CondBranchOp : Std_Op<"cond_br", +def CondBranchOp : CF_Op<"cond_br", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect, Terminator]> { @@ -361,93 +221,11 @@ }]; } -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -def ConstantOp : Std_Op<"constant", - [ConstantLike, NoSideEffect, - DeclareOpInterfaceMethods]> { - let summary = "constant"; - let description = [{ - Syntax: - - ``` - operation ::= ssa-id `=` `std.constant` attribute-value `:` type - ``` - - The `constant` operation produces an SSA value from a symbol reference to a - `builtin.func` operation - - Example: - - ```mlir - // Reference to function @myfn. - %2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32> - - // Equivalent generic forms - %2 = "std.constant"() {value = @myfn} - : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>) - ``` - - MLIR does not allow direct references to functions in SSA operands because - the compiler is multithreaded, and disallowing SSA values to directly - reference a function simplifies this - ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). 
- }]; - - let arguments = (ins FlatSymbolRefAttr:$value); - let results = (outs AnyType); - let assemblyFormat = "attr-dict $value `:` type(results)"; - - let extraClassDeclaration = [{ - /// Returns true if a constant operation can be built with the given value - /// and result type. - static bool isBuildableWith(Attribute value, Type type); - }]; - - let hasFolder = 1; - let hasVerifier = 1; -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, - MemRefsNormalizable, ReturnLike, Terminator]> { - let summary = "return operation"; - let description = [{ - The `return` operation represents a return operation within a function. - The operation takes variable number of operands and produces no results. - The operand number and types must match the signature of the function - that contains the operation. - - Example: - - ```mlir - func @foo() : (i32, f8) { - ... 
- return %0, %1 : i32, f8 - } - ``` - }]; - - let arguments = (ins Variadic:$operands); - - let builders = [ - OpBuilder<(ins), - [{ build($_builder, $_state, llvm::None); }]>]; - - let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; - let hasVerifier = 1; -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// -def SwitchOp : Std_Op<"switch", +def SwitchOp : CF_Op<"switch", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect, Terminator]> { diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -84,15 +84,15 @@ affine.for %i = 0 to 100 { "foo"() : () -> () %v = scf.execute_region -> i64 { - cond_br %cond, ^bb1, ^bb2 + cf.cond_br %cond, ^bb1, ^bb2 ^bb1: %c1 = arith.constant 1 : i64 - br ^bb3(%c1 : i64) + cf.br ^bb3(%c1 : i64) ^bb2: %c2 = arith.constant 2 : i64 - br ^bb3(%c2 : i64) + cf.br ^bb3(%c2 : i64) ^bb3(%x : i64): scf.yield %x : i64 diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -14,7 +14,7 @@ #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -24,7 +24,6 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Interfaces/VectorInterfaces.h" // Pull in all enum type definitions and utility function declarations. 
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -20,12 +20,11 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/VectorInterfaces.td" def StandardOps_Dialect : Dialect { let name = "std"; let cppNamespace = "::mlir"; - let dependentDialects = ["arith::ArithmeticDialect"]; + let dependentDialects = ["cf::ControlFlowDialect"]; let hasConstantMaterializer = 1; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; } @@ -42,78 +41,6 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } -//===----------------------------------------------------------------------===// -// AssertOp -//===----------------------------------------------------------------------===// - -def AssertOp : Std_Op<"assert"> { - let summary = "Assert operation with message attribute"; - let description = [{ - Assert operation with single boolean operand and an error message attribute. - If the argument is `true` this operation has no effect. Otherwise, the - program execution will abort. The provided error message may be used by a - runtime to propagate the error to the user. - - Example: - - ```mlir - assert %b, "Expected ... 
to be true" - ``` - }]; - - let arguments = (ins I1:$arg, StrAttr:$msg); - - let assemblyFormat = "$arg `,` $msg attr-dict"; - let hasCanonicalizeMethod = 1; -} - -//===----------------------------------------------------------------------===// -// BranchOp -//===----------------------------------------------------------------------===// - -def BranchOp : Std_Op<"br", - [DeclareOpInterfaceMethods, - NoSideEffect, Terminator]> { - let summary = "branch operation"; - let description = [{ - The `br` operation represents a branch operation in a function. - The operation takes variable number of operands and produces no results. - The operand number and types for each successor must match the arguments of - the block successor. - - Example: - - ```mlir - ^bb2: - %2 = call @someFn() - br ^bb3(%2 : tensor<*xf32>) - ^bb3(%3: tensor<*xf32>): - ``` - }]; - - let arguments = (ins Variadic:$destOperands); - let successors = (successor AnySuccessor:$dest); - - let builders = [ - OpBuilder<(ins "Block *":$dest, - CArg<"ValueRange", "{}">:$destOperands), [{ - $_state.addSuccessors(dest); - $_state.addOperands(destOperands); - }]>]; - - let extraClassDeclaration = [{ - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); - }]; - - let hasCanonicalizeMethod = 1; - let assemblyFormat = [{ - $dest (`(` $destOperands^ `:` type($destOperands) `)`)? 
attr-dict - }]; -} - //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// @@ -246,121 +173,6 @@ "$callee `(` $callee_operands `)` attr-dict `:` type($callee)"; } -//===----------------------------------------------------------------------===// -// CondBranchOp -//===----------------------------------------------------------------------===// - -def CondBranchOp : Std_Op<"cond_br", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - NoSideEffect, Terminator]> { - let summary = "conditional branch operation"; - let description = [{ - The `cond_br` terminator operation represents a conditional branch on a - boolean (1-bit integer) value. If the bit is set, then the first destination - is jumped to; if it is false, the second destination is chosen. The count - and types of operands must align with the arguments in the corresponding - target blocks. - - The MLIR conditional branch operation is not allowed to target the entry - block for a region. The two destinations of the conditional branch operation - are allowed to be the same. - - The following example illustrates a function with a conditional branch - operation that targets the same block. 
- - Example: - - ```mlir - func @select(%a: i32, %b: i32, %flag: i1) -> i32 { - // Both targets are the same, operands differ - cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32) - - ^bb1(%x : i32) : - return %x : i32 - } - ``` - }]; - - let arguments = (ins I1:$condition, - Variadic:$trueDestOperands, - Variadic:$falseDestOperands); - let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); - - let builders = [ - OpBuilder<(ins "Value":$condition, "Block *":$trueDest, - "ValueRange":$trueOperands, "Block *":$falseDest, - "ValueRange":$falseOperands), [{ - build($_builder, $_state, condition, trueOperands, falseOperands, trueDest, - falseDest); - }]>, - OpBuilder<(ins "Value":$condition, "Block *":$trueDest, - "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{ - build($_builder, $_state, condition, trueDest, ValueRange(), falseDest, - falseOperands); - }]>]; - - let extraClassDeclaration = [{ - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - // Accessors for operands to the 'true' destination. - Value getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - - void setTrueOperand(unsigned idx, Value value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - unsigned getNumTrueOperands() { return getTrueOperands().size(); } - - /// Erase the operand at 'index' from the true operand list. - void eraseTrueOperand(unsigned index) { - getTrueDestOperandsMutable().erase(index); - } - - // Accessors for operands to the 'false' destination. 
- Value getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - void setFalseOperand(unsigned idx, Value value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_range getTrueOperands() { return getTrueDestOperands(); } - operand_range getFalseOperands() { return getFalseDestOperands(); } - - unsigned getNumFalseOperands() { return getFalseOperands().size(); } - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index) { - getFalseDestOperandsMutable().erase(index); - } - - private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } - }]; - - let hasCanonicalizer = 1; - let assemblyFormat = [{ - $condition `,` - $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,` - $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)? - attr-dict - }]; -} - //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -443,93 +255,4 @@ let hasVerifier = 1; } -//===----------------------------------------------------------------------===// -// SwitchOp -//===----------------------------------------------------------------------===// - -def SwitchOp : Std_Op<"switch", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - NoSideEffect, Terminator]> { - let summary = "switch operation"; - let description = [{ - The `switch` terminator operation represents a switch on a signless integer - value. If the flag matches one of the specified cases, then the - corresponding destination is jumped to. 
If the flag does not match any of - the cases, the default destination is jumped to. The count and types of - operands must align with the arguments in the corresponding target blocks. - - Example: - - ```mlir - switch %flag : i32, [ - default: ^bb1(%a : i32), - 42: ^bb1(%b : i32), - 43: ^bb3(%c : i32) - ] - ``` - }]; - - let arguments = (ins - AnyInteger:$flag, - Variadic:$defaultOperands, - VariadicOfVariadic:$caseOperands, - OptionalAttr:$case_values, - I32ElementsAttr:$case_operand_segments - ); - let successors = (successor - AnySuccessor:$defaultDestination, - VariadicSuccessor:$caseDestinations - ); - let builders = [ - OpBuilder<(ins "Value":$flag, - "Block *":$defaultDestination, - "ValueRange":$defaultOperands, - CArg<"ArrayRef", "{}">:$caseValues, - CArg<"BlockRange", "{}">:$caseDestinations, - CArg<"ArrayRef", "{}">:$caseOperands)>, - OpBuilder<(ins "Value":$flag, - "Block *":$defaultDestination, - "ValueRange":$defaultOperands, - CArg<"ArrayRef", "{}">:$caseValues, - CArg<"BlockRange", "{}">:$caseDestinations, - CArg<"ArrayRef", "{}">:$caseOperands)>, - OpBuilder<(ins "Value":$flag, - "Block *":$defaultDestination, - "ValueRange":$defaultOperands, - CArg<"DenseIntElementsAttr", "{}">:$caseValues, - CArg<"BlockRange", "{}">:$caseDestinations, - CArg<"ArrayRef", "{}">:$caseOperands)> - ]; - - let assemblyFormat = [{ - $flag `:` type($flag) `,` `[` `\n` - custom(ref(type($flag)),$defaultDestination, - $defaultOperands, - type($defaultOperands), - $case_values, - $caseDestinations, - $caseOperands, - type($caseOperands)) - `]` - attr-dict - }]; - - let extraClassDeclaration = [{ - /// Return the operands for the case destination block at the given index. - OperandRange getCaseOperands(unsigned index) { - return getCaseOperands()[index]; - } - - /// Return a mutable range of operands for the case destination block at the - /// given index. 
- MutableOperandRange getCaseOperandsMutable(unsigned index) { - return getCaseOperandsMutable()[index]; - } - }]; - - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - #endif // STANDARD_OPS diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -22,6 +22,7 @@ #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/GPU/GPUDialect.h" @@ -61,6 +62,7 @@ arm_neon::ArmNeonDialect, async::AsyncDialect, bufferization::BufferizationDialect, + cf::ControlFlowDialect, complex::ComplexDialect, DLTIDialect, emitc::EmitCDialect, diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -6,6 +6,8 @@ add_subdirectory(BufferizationToMemRef) add_subdirectory(ComplexToLLVM) add_subdirectory(ComplexToStandard) +add_subdirectory(ControlFlowToLLVM) +add_subdirectory(ControlFlowToSPIRV) add_subdirectory(GPUCommon) add_subdirectory(GPUToNVVM) add_subdirectory(GPUToROCDL) @@ -25,10 +27,10 @@ add_subdirectory(OpenMPToLLVM) add_subdirectory(PDLToPDLInterp) add_subdirectory(ReconcileUnrealizedCasts) +add_subdirectory(SCFToControlFlow) add_subdirectory(SCFToGPU) add_subdirectory(SCFToOpenMP) add_subdirectory(SCFToSPIRV) -add_subdirectory(SCFToStandard) add_subdirectory(ShapeToStandard) add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ControlFlowToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ControlFlowToLLVM/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_conversion_library(MLIRControlFlowToLLVM 
+ ControlFlowToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ControlFlowToLLVM + + DEPENDS + MLIRConversionPassIncGen + intrinsics_gen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRControlFlow + MLIRLLVMCommonConversion + MLIRLLVMIR + MLIRPass + MLIRTransformUtils + ) diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -0,0 +1,148 @@ +//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to convert MLIR standard and builtin dialects +// into the LLVM IR dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include + +using namespace mlir; + +#define PASS_NAME "convert-cf-to-llvm" + +namespace { +/// Lower `std.assert`. The default lowering calls the `abort` function if the +/// assertion is violated and has no effect otherwise. 
The failure message is +/// ignored by the default lowering but should be propagated by any custom +/// lowering. +struct AssertOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // Insert the `abort` declaration if necessary. + auto module = op->getParentOfType(); + auto abortFunc = module.lookupSymbol("abort"); + if (!abortFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); + abortFunc = rewriter.create(rewriter.getUnknownLoc(), + "abort", abortFuncTy); + } + + // Split block at `assert` operation. + Block *opBlock = rewriter.getInsertionBlock(); + auto opPosition = rewriter.getInsertionPoint(); + Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); + + // Generate IR to call `abort`. + Block *failureBlock = rewriter.createBlock(opBlock->getParent()); + rewriter.create(loc, abortFunc, llvm::None); + rewriter.create(loc); + + // Generate assertion test. + rewriter.setInsertionPointToEnd(opBlock); + rewriter.replaceOpWithNewOp( + op, adaptor.getArg(), continuationBlock, failureBlock); + + return success(); + } +}; + +// Base class for LLVM IR lowering terminator operations with successors. +template +struct OneToOneLLVMTerminatorLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Super = OneToOneLLVMTerminatorLowering; + + LogicalResult + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), + op->getSuccessors(), op->getAttrs()); + return success(); + } +}; + +// FIXME: this should be tablegen'ed as well. 
+struct BranchOpLowering + : public OneToOneLLVMTerminatorLowering { + using Super::Super; +}; +struct CondBranchOpLowering + : public OneToOneLLVMTerminatorLowering { + using Super::Super; +}; +struct SwitchOpLowering + : public OneToOneLLVMTerminatorLowering { + using Super::Super; +}; + +} // namespace + +void mlir::cf::populateControlFlowToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // clang-format off + patterns.add< + AssertOpLowering, + BranchOpLowering, + CondBranchOpLowering, + SwitchOpLowering>(converter); + // clang-format on +} + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { +/// A pass converting MLIR operations into the LLVM IR dialect. +struct ConvertControlFlowToLLVM + : public ConvertControlFlowToLLVMBase { + ConvertControlFlowToLLVM() = default; + + /// Run the dialect converter on the module. 
+ void runOnOperation() override { + LLVMConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); + + LowerToLLVMOptions options(&getContext()); + if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) + options.overrideIndexBitwidth(indexBitwidth); + + LLVMTypeConverter converter(&getContext(), options); + mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr mlir::cf::createConvertControlFlowToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt copy from mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt copy to mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/CMakeLists.txt @@ -1,6 +1,5 @@ -add_mlir_conversion_library(MLIRStandardToSPIRV - StandardToSPIRV.cpp - StandardToSPIRVPass.cpp +add_mlir_conversion_library(MLIRControlFlowToSPIRV + ControlFlowToSPIRV.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV @@ -10,15 +9,11 @@ MLIRConversionPassIncGen LINK_LIBS PUBLIC - MLIRArithmeticToSPIRV MLIRIR - MLIRMathToSPIRV - MLIRMemRef + MLIRControlFlow MLIRPass MLIRSPIRV MLIRSPIRVConversion MLIRSupport MLIRTransformUtils - MLIRStandard - MLIRTensor ) diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.cpp @@ -0,0 +1,73 @@ +//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements patterns to convert standard dialect to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" +#include "../SPIRVCommon/Pattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "cf-to-spirv-pattern" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Operation conversion +//===----------------------------------------------------------------------===// + +namespace { + +/// Converts cf.br to spv.Branch. +struct BranchOpPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::BranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getDest(), + adaptor.getDestOperands()); + return success(); + } +}; + +/// Converts cf.cond_br to spv.BranchConditional. 
+struct CondBranchOpPattern final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cf::CondBranchOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(), + op.getFalseDest(), adaptor.getFalseDestOperands()); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern population +//===----------------------------------------------------------------------===// + +void mlir::cf::populateControlFlowToSPIRVPatterns( + SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + + patterns.add(typeConverter, context); +} diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -172,8 +173,8 @@ populateGpuRewritePatterns(patterns); (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); - mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, - llvmPatterns); + arith::populateArithmeticToLLVMConversionPatterns(converter, llvmPatterns); + cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); 
populateGpuToNVVMConversionPatterns(converter, llvmPatterns); diff --git a/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt b/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/LinalgToLLVM/CMakeLists.txt @@ -19,7 +19,7 @@ MLIRLLVMCommonConversion MLIRLLVMIR MLIRMemRefToLLVM - MLIRSCFToStandard + MLIRSCFToControlFlow MLIRTransforms MLIRVectorToLLVM MLIRVectorToSCF diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -14,7 +14,7 @@ #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -433,7 +433,7 @@ /// +---------------------------------+ /// | | /// | | -/// | br loop(%loaded) | +/// | cf.br loop(%loaded) | /// +---------------------------------+ /// | /// -------| | @@ -444,7 +444,7 @@ /// | | %pair = cmpxchg | /// | | %ok = %pair[0] | /// | | %new = %pair[1] | -/// | | cond_br %ok, end, loop(%new) | +/// | | cf.cond_br %ok, end, loop(%new) | /// | +--------------------------------+ /// | | | /// |----------- | diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ 
b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -10,6 +10,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -66,7 +67,8 @@ // Convert to OpenMP operations with LLVM IR dialect RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); - mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); + arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); + cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); populateMemRefToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); populateOpenMPToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -29,6 +29,10 @@ class ArithmeticDialect; } // namespace arith +namespace cf { +class ControlFlowDialect; +} // namespace cf + namespace complex { class ComplexDialect; } // namespace complex diff --git a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt rename from mlir/lib/Conversion/SCFToStandard/CMakeLists.txt rename to mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt --- a/mlir/lib/Conversion/SCFToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToControlFlow/CMakeLists.txt @@ -1,8 +1,8 @@ -add_mlir_conversion_library(MLIRSCFToStandard - SCFToStandard.cpp +add_mlir_conversion_library(MLIRSCFToControlFlow + SCFToControlFlow.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToStandard + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToControlFlow DEPENDS MLIRConversionPassIncGen @@ -12,6 +12,7 
@@ LINK_LIBS PUBLIC MLIRArithmetic + MLIRControlFlow MLIRSCF MLIRTransforms ) diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp rename from mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp rename to mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -1,4 +1,4 @@ -//===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===// +//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -11,11 +11,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "../PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/SCF.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -29,7 +29,8 @@ namespace { -struct SCFToStandardPass : public SCFToStandardBase { +struct SCFToControlFlowPass + : public SCFToControlFlowBase { void runOnOperation() override; }; @@ -57,7 +58,7 @@ // | | // | | // | | -// | br cond(%iv, %init...) | +// | cf.br cond(%iv, %init...) 
| // +---------------------------------+ // | // -------| | @@ -65,7 +66,7 @@ // | +--------------------------------+ // | | cond(%iv, %init...): | // | | | -// | | cond_br %r, body, end | +// | | cf.cond_br %r, body, end | // | +--------------------------------+ // | | | // | | -------------| @@ -83,7 +84,7 @@ // | | | | // | | | | // | | %new_iv = | | -// | | br cond(%new_iv, %yields) | | +// | | cf.br cond(%new_iv, %yields) | | // | +--------------------------------+ | // | | | // |----------- |-------------------- @@ -125,7 +126,7 @@ // // +--------------------------------+ // | | -// | cond_br %cond, %then, %else | +// | cf.cond_br %cond, %then, %else | // +--------------------------------+ // | | // | --------------| @@ -133,7 +134,7 @@ // +--------------------------------+ | // | then: | | // | | | -// | br continue | | +// | cf.br continue | | // +--------------------------------+ | // | | // |---------- |------------- @@ -141,7 +142,7 @@ // | +--------------------------------+ // | | else: | // | | | -// | | br continue | +// | | cf.br continue | // | +--------------------------------+ // | | // ------| | @@ -155,7 +156,7 @@ // // +--------------------------------+ // | | -// | cond_br %cond, %then, %else | +// | cf.cond_br %cond, %then, %else | // +--------------------------------+ // | | // | --------------| @@ -163,7 +164,7 @@ // +--------------------------------+ | // | then: | | // | | | -// | br dom(%args...) | | +// | cf.br dom(%args...) | | // +--------------------------------+ | // | | // |---------- |------------- @@ -171,14 +172,14 @@ // | +--------------------------------+ // | | else: | // | | | -// | | br dom(%args...) | +// | | cf.br dom(%args...) 
| // | +--------------------------------+ // | | // ------| | // v v // +--------------------------------+ // | dom(%args...): | -// | br continue | +// | cf.br continue | // +--------------------------------+ // | // v @@ -218,7 +219,7 @@ /// /// +---------------------------------+ /// | | -/// | br ^before(%operands...) | +/// | cf.br ^before(%operands...) | /// +---------------------------------+ /// | /// -------| | @@ -233,7 +234,7 @@ /// | +--------------------------------+ /// | | ^before-last: /// | | %cond = | -/// | | cond_br %cond, | +/// | | cf.cond_br %cond, | /// | | ^after(%vals...), ^cont | /// | +--------------------------------+ /// | | | @@ -249,7 +250,7 @@ /// | +--------------------------------+ | /// | | ^after-last: | | /// | | %yields... = | | -/// | | br ^before(%yields...) | | +/// | | cf.br ^before(%yields...) | | /// | +--------------------------------+ | /// | | | /// |----------- |-------------------- @@ -321,7 +322,7 @@ SmallVector loopCarried; loopCarried.push_back(stepped); loopCarried.append(terminator->operand_begin(), terminator->operand_end()); - rewriter.create(loc, conditionBlock, loopCarried); + rewriter.create(loc, conditionBlock, loopCarried); rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. @@ -337,15 +338,16 @@ destOperands.push_back(lowerBound); auto iterOperands = forOp.getIterOperands(); destOperands.append(iterOperands.begin(), iterOperands.end()); - rewriter.create(loc, conditionBlock, destOperands); + rewriter.create(loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. 
rewriter.setInsertionPointToEnd(conditionBlock); auto comparison = rewriter.create( loc, arith::CmpIPredicate::slt, iv, upperBound); - rewriter.create(loc, comparison, firstBodyBlock, - ArrayRef(), endBlock, ArrayRef()); + rewriter.create(loc, comparison, firstBodyBlock, + ArrayRef(), endBlock, + ArrayRef()); // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); @@ -369,7 +371,7 @@ continueBlock = rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), SmallVector(ifOp.getNumResults(), loc)); - rewriter.create(loc, remainingOpsBlock); + rewriter.create(loc, remainingOpsBlock); } // Move blocks from the "then" region to the region containing 'scf.if', @@ -379,7 +381,7 @@ Operation *thenTerminator = thenRegion.back().getTerminator(); ValueRange thenTerminatorOperands = thenTerminator->getOperands(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, continueBlock, thenTerminatorOperands); + rewriter.create(loc, continueBlock, thenTerminatorOperands); rewriter.eraseOp(thenTerminator); rewriter.inlineRegionBefore(thenRegion, continueBlock); @@ -393,15 +395,15 @@ Operation *elseTerminator = elseRegion.back().getTerminator(); ValueRange elseTerminatorOperands = elseTerminator->getOperands(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, continueBlock, elseTerminatorOperands); + rewriter.create(loc, continueBlock, elseTerminatorOperands); rewriter.eraseOp(elseTerminator); rewriter.inlineRegionBefore(elseRegion, continueBlock); } rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, ifOp.getCondition(), thenBlock, - /*trueArgs=*/ArrayRef(), elseBlock, - /*falseArgs=*/ArrayRef()); + rewriter.create(loc, ifOp.getCondition(), thenBlock, + /*trueArgs=*/ArrayRef(), elseBlock, + /*falseArgs=*/ArrayRef()); // Ok, we're done! 
rewriter.replaceOp(ifOp, continueBlock->getArguments()); @@ -419,13 +421,13 @@ auto ®ion = op.getRegion(); rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, ®ion.front()); + rewriter.create(loc, ®ion.front()); for (Block &block : region) { if (auto terminator = dyn_cast(block.getTerminator())) { ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(&block); - rewriter.create(loc, remainingOpsBlock, terminatorOperands); + rewriter.create(loc, remainingOpsBlock, terminatorOperands); rewriter.eraseOp(terminator); } } @@ -538,20 +540,21 @@ // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, before, whileOp.getInits()); + rewriter.create(loc, before, whileOp.getInits()); // Replace terminators with branches. Assuming bodies are SESE, which holds // given only the patterns from this file, we only need to look at the last // block. This should be reconsidered if we allow break/continue in SCF. rewriter.setInsertionPointToEnd(beforeLast); auto condOp = cast(beforeLast->getTerminator()); - rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), - after, condOp.getArgs(), - continuation, ValueRange()); + rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), + after, condOp.getArgs(), + continuation, ValueRange()); rewriter.setInsertionPointToEnd(afterLast); auto yieldOp = cast(afterLast->getTerminator()); - rewriter.replaceOpWithNewOp(yieldOp, before, yieldOp.getResults()); + rewriter.replaceOpWithNewOp(yieldOp, before, + yieldOp.getResults()); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. @@ -593,14 +596,14 @@ // Branch to the "before" region. rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(whileOp.getLoc(), before, whileOp.getInits()); + rewriter.create(whileOp.getLoc(), before, whileOp.getInits()); // Loop around the "before" region based on condition. 
rewriter.setInsertionPointToEnd(beforeLast); auto condOp = cast(beforeLast->getTerminator()); - rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), - before, condOp.getArgs(), - continuation, ValueRange()); + rewriter.replaceOpWithNewOp(condOp, condOp.getCondition(), + before, condOp.getArgs(), + continuation, ValueRange()); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. @@ -609,17 +612,18 @@ return success(); } -void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) { +void mlir::populateSCFToControlFlowConversionPatterns( + RewritePatternSet &patterns) { patterns.add(patterns.getContext()); patterns.add(patterns.getContext(), /*benefit=*/2); } -void SCFToStandardPass::runOnOperation() { +void SCFToControlFlowPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - populateLoopToStdConversionPatterns(patterns); - // Configure conversion to lower out scf.for, scf.if, scf.parallel and - // scf.while. Anything else is fine. + populateSCFToControlFlowConversionPatterns(patterns); + + // Configure conversion to lower out SCF operations. 
ConversionTarget target(getContext()); target.addIllegalOp(); @@ -629,6 +633,6 @@ signalPassFailure(); } -std::unique_ptr mlir::createLowerToCFGPass() { - return std::make_unique(); +std::unique_ptr mlir::createConvertSCFToCFPass() { + return std::make_unique(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "../PassDetail.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -29,7 +30,7 @@ using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(shape::CstrRequireOp op, PatternRewriter &rewriter) const override { - rewriter.create(op.getLoc(), op.getPred(), op.getMsgAttr()); + rewriter.create(op.getLoc(), op.getPred(), op.getMsgAttr()); rewriter.replaceOpWithNewOp(op, true); return success(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -14,6 +14,7 @@ LINK_LIBS PUBLIC MLIRAnalysis MLIRArithmeticToLLVM + MLIRControlFlowToLLVM MLIRDataLayoutInterfaces MLIRLLVMCommonConversion MLIRLLVMIR diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -14,6 +14,7 @@ #include "../PassDetail.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include 
"mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -387,48 +388,6 @@ } }; -/// Lower `std.assert`. The default lowering calls the `abort` function if the -/// assertion is violated and has no effect otherwise. The failure message is -/// ignored by the default lowering but should be propagated by any custom -/// lowering. -struct AssertOpLowering : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(AssertOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - // Insert the `abort` declaration if necessary. - auto module = op->getParentOfType(); - auto abortFunc = module.lookupSymbol("abort"); - if (!abortFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); - abortFunc = rewriter.create(rewriter.getUnknownLoc(), - "abort", abortFuncTy); - } - - // Split block at `assert` operation. - Block *opBlock = rewriter.getInsertionBlock(); - auto opPosition = rewriter.getInsertionPoint(); - Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); - - // Generate IR to call `abort`. - Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - rewriter.create(loc, abortFunc, llvm::None); - rewriter.create(loc); - - // Generate assertion test. - rewriter.setInsertionPointToEnd(opBlock); - rewriter.replaceOpWithNewOp( - op, adaptor.getArg(), continuationBlock, failureBlock); - - return success(); - } -}; - struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -550,22 +509,6 @@ } }; -// Base class for LLVM IR lowering terminator operations with successors. 
-template -struct OneToOneLLVMTerminatorLowering - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Super = OneToOneLLVMTerminatorLowering; - - LogicalResult - matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), - op->getSuccessors(), op->getAttrs()); - return success(); - } -}; - // Special lowering pattern for `ReturnOps`. Unlike all other operations, // `ReturnOp` interacts with the function signature and must have as many // operands as the function has return values. Because in LLVM IR, functions @@ -633,21 +576,6 @@ return success(); } }; - -// FIXME: this should be tablegen'ed as well. -struct BranchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Super::Super; -}; -struct CondBranchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Super::Super; -}; -struct SwitchOpLowering - : public OneToOneLLVMTerminatorLowering { - using Super::Super; -}; - } // namespace void mlir::populateStdToLLVMFuncOpConversionPattern( @@ -663,14 +591,10 @@ populateStdToLLVMFuncOpConversionPattern(converter, patterns); // clang-format off patterns.add< - AssertOpLowering, - BranchOpLowering, CallIndirectOpLowering, CallOpLowering, - CondBranchOpLowering, ConstantOpLowering, - ReturnOpLowering, - SwitchOpLowering>(converter); + ReturnOpLowering>(converter); // clang-format on } @@ -721,6 +645,7 @@ RewritePatternSet patterns(&getContext()); populateStdToLLVMConversionPatterns(typeConverter, patterns); arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); + cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt --- 
a/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToSPIRV/CMakeLists.txt @@ -11,6 +11,7 @@ LINK_LIBS PUBLIC MLIRArithmeticToSPIRV + MLIRControlFlowToSPIRV MLIRIR MLIRMathToSPIRV MLIRMemRef diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -46,24 +46,6 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts std.br to spv.Branch. -struct BranchOpPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(BranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Converts std.cond_br to spv.BranchConditional. -struct CondBranchOpPattern final : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CondBranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - /// Converts tensor.extract into loading using access chains from SPIR-V local /// variables. 
class TensorExtractPattern final @@ -146,31 +128,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// BranchOpPattern -//===----------------------------------------------------------------------===// - -LogicalResult -BranchOpPattern::matchAndRewrite(BranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp(op, op.getDest(), - adaptor.getDestOperands()); - return success(); -} - -//===----------------------------------------------------------------------===// -// CondBranchOpPattern -//===----------------------------------------------------------------------===// - -LogicalResult CondBranchOpPattern::matchAndRewrite( - CondBranchOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - rewriter.replaceOpWithNewOp( - op, op.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(), - op.getFalseDest(), adaptor.getFalseDestOperands()); - return success(); -} - //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// @@ -189,8 +146,7 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, - ReturnOpPattern, BranchOpPattern, CondBranchOpPattern>(typeConverter, - context); + ReturnOpPattern>(typeConverter, context); } void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "../PassDetail.h" #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" +#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h" #include 
"mlir/Conversion/MathToSPIRV/MathToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" @@ -40,9 +41,11 @@ options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; SPIRVTypeConverter typeConverter(targetAttr, options); - // TODO ArithmeticToSPIRV cannot be applied separately to StandardToSPIRV + // TODO ArithmeticToSPIRV/ControlFlowToSPIRV cannot be applied separately to + // StandardToSPIRV RewritePatternSet patterns(context); arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); + cf::populateControlFlowToSPIRVPatterns(typeConverter, patterns); populateMathToSPIRVPatterns(typeConverter, patterns); populateStandardToSPIRVPatterns(typeConverter, patterns); populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" @@ -169,11 +170,11 @@ /// /// ^entry: /// %token = async.runtime.create : !async.token - /// cond_br %cond, ^bb1, ^bb2 + /// cf.cond_br %cond, ^bb1, ^bb2 /// ^bb1: /// async.runtime.await %token /// async.runtime.drop_ref %token - /// br ^bb2 + /// cf.br ^bb2 /// ^bb2: /// return /// @@ -185,14 +186,14 @@ /// /// ^entry: /// %token = async.runtime.create : !async.token - /// cond_br %cond, ^bb1, ^reference_counting + /// cf.cond_br %cond, ^bb1, ^reference_counting /// ^bb1: /// async.runtime.await %token /// async.runtime.drop_ref %token - /// br ^bb2 + /// cf.br ^bb2 /// 
^reference_counting: /// async.runtime.drop_ref %token - /// br ^bb2 + /// cf.br ^bb2 /// ^bb2: /// return /// @@ -208,7 +209,7 @@ /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup /// ^resume: /// %0 = async.runtime.load %value - /// br ^cleanup + /// cf.br ^cleanup /// ^cleanup: /// ... /// ^suspend: @@ -406,7 +407,7 @@ refCountingBlock = &successor->getParent()->emplaceBlock(); refCountingBlock->moveBefore(successor); OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); - builder.create(value.getLoc(), successor); + builder.create(value.getLoc(), successor); } OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -12,10 +12,11 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -105,18 +106,18 @@ /// %value = : !async.value // create async value /// %id = async.coro.id // create a coroutine id /// %hdl = async.coro.begin %id // create a coroutine handle -/// br ^preexisting_entry_block +/// cf.br ^preexisting_entry_block /// /// /* preexisting blocks modified to branch to the cleanup block */ /// /// ^set_error: // this block created lazily only if needed (see code below) /// async.runtime.set_error %token : !async.token /// async.runtime.set_error %value : !async.value -/// br ^cleanup +/// cf.br 
^cleanup /// /// ^cleanup: /// async.coro.free %hdl // delete the coroutine state -/// br ^suspend +/// cf.br ^suspend /// /// ^suspend: /// async.coro.end %hdl // marks the end of a coroutine @@ -147,7 +148,7 @@ auto coroIdOp = builder.create(CoroIdType::get(ctx)); auto coroHdlOp = builder.create(CoroHandleType::get(ctx), coroIdOp.id()); - builder.create(originalEntryBlock); + builder.create(originalEntryBlock); Block *cleanupBlock = func.addBlock(); Block *suspendBlock = func.addBlock(); @@ -159,7 +160,7 @@ builder.create(coroIdOp.id(), coroHdlOp.handle()); // Branch into the suspend block. - builder.create(suspendBlock); + builder.create(suspendBlock); // ------------------------------------------------------------------------ // // Coroutine suspend block: mark the end of a coroutine and return allocated @@ -186,7 +187,7 @@ Operation *terminator = block.getTerminator(); if (auto yield = dyn_cast(terminator)) { builder.setInsertionPointToEnd(&block); - builder.create(cleanupBlock); + builder.create(cleanupBlock); } } @@ -227,7 +228,7 @@ builder.create(retValue); // Branch into the cleanup block. - builder.create(coro.cleanup); + builder.create(coro.cleanup); return coro.setError; } @@ -305,7 +306,7 @@ // Async resume operation (execution will be resumed in a thread managed by // the async runtime). { - BranchOp branch = cast(coro.entry->getTerminator()); + cf::BranchOp branch = cast(coro.entry->getTerminator()); builder.setInsertionPointToEnd(coro.entry); // Save the coroutine state: async.coro.save @@ -419,8 +420,8 @@ isError, builder.create( loc, i1, builder.getIntegerAttr(i1, 1))); - builder.create(notError, - "Awaited async operand is in error state"); + builder.create(notError, + "Awaited async operand is in error state"); } // Inside the coroutine we convert await operation into coroutine suspension @@ -452,11 +453,11 @@ // Check if the awaited value is in the error state. 
builder.setInsertionPointToStart(resume); auto isError = builder.create(loc, i1, operand); - builder.create(isError, - /*trueDest=*/setupSetErrorBlock(coro), - /*trueArgs=*/ArrayRef(), - /*falseDest=*/continuation, - /*falseArgs=*/ArrayRef()); + builder.create(isError, + /*trueDest=*/setupSetErrorBlock(coro), + /*trueArgs=*/ArrayRef(), + /*falseDest=*/continuation, + /*falseArgs=*/ArrayRef()); // Make sure that replacement value will be constructed in the // continuation block. @@ -560,18 +561,18 @@ }; //===----------------------------------------------------------------------===// -// Convert std.assert operation to cond_br into `set_error` block. +// Convert std.assert operation to cf.cond_br into `set_error` block. //===----------------------------------------------------------------------===// -class AssertOpLowering : public OpConversionPattern { +class AssertOpLowering : public OpConversionPattern { public: AssertOpLowering(MLIRContext *ctx, llvm::DenseMap &outlinedFunctions) - : OpConversionPattern(ctx), + : OpConversionPattern(ctx), outlinedFunctions(outlinedFunctions) {} LogicalResult - matchAndRewrite(AssertOp op, OpAdaptor adaptor, + matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if assert operation is inside the async coroutine function. 
auto func = op->template getParentOfType(); @@ -585,11 +586,11 @@ Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); - rewriter.create(loc, adaptor.getArg(), - /*trueDest=*/cont, - /*trueArgs=*/ArrayRef(), - /*falseDest=*/setupSetErrorBlock(coro), - /*falseArgs=*/ArrayRef()); + rewriter.create(loc, adaptor.getArg(), + /*trueDest=*/cont, + /*trueArgs=*/ArrayRef(), + /*falseDest=*/setupSetErrorBlock(coro), + /*falseArgs=*/ArrayRef()); rewriter.eraseOp(op); return success(); @@ -765,7 +766,7 @@ // and we have to make sure that structured control flow operations with async // operations in nested regions will be converted to branch-based control flow // before we add the coroutine basic blocks. - populateLoopToStdConversionPatterns(asyncPatterns); + populateSCFToControlFlowConversionPatterns(asyncPatterns); // Async lowering does not use type converter because it must preserve all // types for async.runtime operations. @@ -792,14 +793,15 @@ }); return !walkResult.wasInterrupted(); }); - runtimeTarget.addLegalOp(); + runtimeTarget.addLegalOp(); // Assertions must be converted to runtime errors inside async functions. 
- runtimeTarget.addDynamicallyLegalOp([&](AssertOp op) -> bool { - auto func = op->getParentOfType(); - return outlinedFunctions.find(func) == outlinedFunctions.end(); - }); + runtimeTarget.addDynamicallyLegalOp( + [&](cf::AssertOp op) -> bool { + auto func = op->getParentOfType(); + return outlinedFunctions.find(func) == outlinedFunctions.end(); + }); if (eliminateBlockingAwaitOps) runtimeTarget.addDynamicallyLegalOp( diff --git a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Async/Transforms/CMakeLists.txt @@ -17,7 +17,7 @@ MLIRIR MLIRPass MLIRSCF - MLIRSCFToStandard + MLIRSCFToControlFlow MLIRStandard MLIRTransforms MLIRTransformUtils diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -18,12 +18,12 @@ // (using the BufferViewFlowAnalysis class). Consider the following example: // // ^bb0(%arg0): -// cond_br %cond, ^bb1, ^bb2 +// cf.cond_br %cond, ^bb1, ^bb2 // ^bb1: -// br ^exit(%arg0) +// cf.br ^exit(%arg0) // ^bb2: // %new_value = ... 
-// br ^exit(%new_value) +// cf.br ^exit(%new_value) // ^exit(%arg1): // return %arg1; // diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(AMX) add_subdirectory(Bufferization) add_subdirectory(Complex) +add_subdirectory(ControlFlow) add_subdirectory(DLTI) add_subdirectory(EmitC) add_subdirectory(GPU) diff --git a/mlir/lib/Dialect/ControlFlow/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_dialect_library(MLIRControlFlow + ControlFlowOps.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/IR + + DEPENDS + MLIRControlFlowOpsIncGen + + LINK_LIBS PUBLIC + MLIRArithmetic + MLIRControlFlowInterfaces + MLIRIR + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp copy from mlir/lib/Dialect/StandardOps/IR/Ops.cpp copy to mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -1,4 +1,4 @@ -//===- Ops.cpp - Standard MLIR Operations ---------------------------------===// +//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,10 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -31,94 +30,46 @@ #include "llvm/Support/raw_ostream.h" #include -#include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc" - -// Pull in all enum type definitions and utility function declarations. -#include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc" using namespace mlir; +using namespace mlir::cf; //===----------------------------------------------------------------------===// -// StandardOpsDialect Interfaces +// ControlFlowDialect Interfaces //===----------------------------------------------------------------------===// namespace { -/// This class defines the interface for handling inlining with standard +/// This class defines the interface for handling inlining with control flow /// operations. -struct StdInlinerInterface : public DialectInlinerInterface { +struct ControlFlowInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; + ~ControlFlowInlinerInterface() override = default; - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - /// All call operations within standard ops can be inlined. + /// All control flow operations can be inlined. bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final { return true; } - - /// All operations within standard ops can be inlined. 
bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, Block *newDest) const final { - // Only "std.return" needs to be handled here. - auto returnOp = dyn_cast(op); - if (!returnOp) - return; - - // Replace the return with a branch to the dest. - OpBuilder builder(op); - builder.create(op->getLoc(), newDest, returnOp.getOperands()); - op->erase(); - } - - /// Handle the given inlined terminator by replacing it with a new operation - /// as necessary. - void handleTerminator(Operation *op, - ArrayRef valuesToRepl) const final { - // Only "std.return" needs to be handled here. - auto returnOp = cast(op); - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } + /// ControlFlow terminator operations don't really need any special handing. + void handleTerminator(Operation *op, Block *newDest) const final {} }; } // namespace //===----------------------------------------------------------------------===// -// StandardOpsDialect +// ControlFlowDialect //===----------------------------------------------------------------------===// -void StandardOpsDialect::initialize() { +void ControlFlowDialect::initialize() { addOperations< #define GET_OP_LIST -#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" >(); - addInterfaces(); -} - -/// Materialize a single constant operation from a given attribute value with -/// the desired resultant type. 
-Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, value); - if (ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, - value.cast()); - return nullptr; + addInterfaces(); } //===----------------------------------------------------------------------===// @@ -245,76 +196,14 @@ return getDest(); } -//===----------------------------------------------------------------------===// -// CallOp -//===----------------------------------------------------------------------===// - -LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // Check that the callee attribute was specified. - auto fnAttr = (*this)->getAttrOfType("callee"); - if (!fnAttr) - return emitOpError("requires a 'callee' symbol reference attribute"); - FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); - if (!fn) - return emitOpError() << "'" << fnAttr.getValue() - << "' does not reference a valid function"; - - // Verify that the operand and result types match the callee. 
- auto fnType = fn.getType(); - if (fnType.getNumInputs() != getNumOperands()) - return emitOpError("incorrect number of operands for callee"); - - for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (getOperand(i).getType() != fnType.getInput(i)) - return emitOpError("operand type mismatch: expected operand type ") - << fnType.getInput(i) << ", but provided " - << getOperand(i).getType() << " for operand number " << i; - - if (fnType.getNumResults() != getNumResults()) - return emitOpError("incorrect number of results for callee"); - - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (getResult(i).getType() != fnType.getResult(i)) { - auto diag = emitOpError("result type mismatch at index ") << i; - diag.attachNote() << " op result types: " << getResultTypes(); - diag.attachNote() << "function result types: " << fnType.getResults(); - return diag; - } - - return success(); -} - -FunctionType CallOp::getCalleeType() { - return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); -} - -//===----------------------------------------------------------------------===// -// CallIndirectOp -//===----------------------------------------------------------------------===// - -/// Fold indirect calls that have a constant function as the callee operand. -LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, - PatternRewriter &rewriter) { - // Check that the callee is a constant callee. - SymbolRefAttr calledFn; - if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) - return failure(); - - // Replace with a direct call. 
- rewriter.replaceOpWithNewOp(indirectCall, calledFn, - indirectCall.getResultTypes(), - indirectCall.getArgOperands()); - return success(); -} - //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// namespace { -/// cond_br true, ^bb1, ^bb2 +/// cf.cond_br true, ^bb1, ^bb2 /// -> br ^bb1 -/// cond_br false, ^bb1, ^bb2 +/// cf.cond_br false, ^bb1, ^bb2 /// -> br ^bb2 /// struct SimplifyConstCondBranchPred : public OpRewritePattern { @@ -338,13 +227,13 @@ } }; -/// cond_br %cond, ^bb1, ^bb2 +/// cf.cond_br %cond, ^bb1, ^bb2 /// ^bb1 /// br ^bbN(...) /// ^bb2 /// br ^bbK(...) /// -/// -> cond_br %cond, ^bbN(...), ^bbK(...) +/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...) /// struct SimplifyPassThroughCondBranch : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -372,10 +261,10 @@ } }; -/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) +/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) /// -> br ^bb1(A, ..., N) /// -/// cond_br %cond, ^bb1(A), ^bb1(B) +/// cf.cond_br %cond, ^bb1(A), ^bb1(B) /// -> %select = arith.select %cond, A, B /// br ^bb1(%select) /// @@ -422,16 +311,16 @@ }; /// ... -/// cond_br %cond, ^bb1(...), ^bb2(...) +/// cf.cond_br %cond, ^bb1(...), ^bb2(...) /// ... /// ^bb1: // has single predecessor /// ... -/// cond_br %cond, ^bb3(...), ^bb4(...) +/// cf.cond_br %cond, ^bb3(...), ^bb4(...) /// /// -> /// /// ... -/// cond_br %cond, ^bb1(...), ^bb2(...) +/// cf.cond_br %cond, ^bb1(...), ^bb2(...) /// ... /// ^bb1: // has single predecessor /// ... @@ -466,7 +355,7 @@ } }; -/// cond_br %arg0, ^trueB, ^falseB +/// cf.cond_br %arg0, ^trueB, ^falseB /// /// ^trueB: /// "test.consumer1"(%arg0) : (i1) -> () @@ -478,7 +367,7 @@ /// /// -> /// -/// cond_br %arg0, ^trueB, ^falseB +/// cf.cond_br %arg0, ^trueB, ^falseB /// ^trueB: /// "test.consumer1"(%true) : (i1) -> () /// ... 
@@ -561,66 +450,6 @@ return nullptr; } -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -LogicalResult ConstantOp::verify() { - StringRef fnName = getValue(); - Type type = getType(); - - // Try to find the referenced function. - auto fn = (*this)->getParentOfType().lookupSymbol(fnName); - if (!fn) - return emitOpError() << "reference to undefined function '" << fnName - << "'"; - - // Check that the referenced function has the correct type. - if (fn.getType() != type) - return emitOpError("reference to function with mismatched type"); - - return success(); -} - -OpFoldResult ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); - return getValueAttr(); -} - -void ConstantOp::getAsmResultNames( - function_ref setNameFn) { - setNameFn(getResult(), "f"); -} - -bool ConstantOp::isBuildableWith(Attribute value, Type type) { - return value.isa() && type.isa(); -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -LogicalResult ReturnOp::verify() { - auto function = cast((*this)->getParentOp()); - - // The operand number and types must match the function signature. 
- const auto &results = function.getType().getResults(); - if (getNumOperands() != results.size()) - return emitOpError("has ") - << getNumOperands() << " operands, but enclosing function (@" - << function.getName() << ") returns " << results.size(); - - for (unsigned i = 0, e = results.size(); i != e; ++i) - if (getOperand(i).getType() != results[i]) - return emitError() << "type of return operand " << i << " (" - << getOperand(i).getType() - << ") doesn't match function result type (" - << results[i] << ")" - << " in function @" << function.getName(); - - return success(); -} - //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// @@ -1059,4 +888,4 @@ //===----------------------------------------------------------------------===// #define GET_OP_CLASSES -#include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc" diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -12,10 +12,10 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" @@ -44,14 +44,14 @@ /// workgroup memory. 
/// /// %subgroup_reduce = `createSubgroupReduce(%operand)` - /// cond_br %is_first_lane, ^then1, ^continue1 + /// cf.cond_br %is_first_lane, ^then1, ^continue1 /// ^then1: /// store %subgroup_reduce, %workgroup_buffer[%subgroup_id] - /// br ^continue1 + /// cf.br ^continue1 /// ^continue1: /// gpu.barrier /// %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups - /// cond_br %is_valid_subgroup, ^then2, ^continue2 + /// cf.cond_br %is_valid_subgroup, ^then2, ^continue2 /// ^then2: /// %partial_reduce = load %workgroup_buffer[%invocation_idx] /// %all_reduce = `createSubgroupReduce(%partial_reduce)` @@ -194,7 +194,7 @@ // Add branch before inserted body, into body. block = block->getNextNode(); - create(block, ValueRange()); + create(block, ValueRange()); // Replace all gpu.yield ops with branch out of body. for (; block != split; block = block->getNextNode()) { @@ -202,7 +202,7 @@ if (!isa(terminator)) continue; rewriter.setInsertionPointToEnd(block); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( terminator, split, ValueRange(terminator->getOperand(0))); } @@ -285,17 +285,17 @@ Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin()); rewriter.setInsertionPointToEnd(currentBlock); - create(condition, thenBlock, - /*trueOperands=*/ArrayRef(), elseBlock, - /*falseOperands=*/ArrayRef()); + create(condition, thenBlock, + /*trueOperands=*/ArrayRef(), elseBlock, + /*falseOperands=*/ArrayRef()); rewriter.setInsertionPointToStart(thenBlock); auto thenOperands = thenOpsFactory(); - create(continueBlock, thenOperands); + create(continueBlock, thenOperands); rewriter.setInsertionPointToStart(elseBlock); auto elseOperands = elseOpsFactory(); - create(continueBlock, elseOperands); + create(continueBlock, elseOperands); assert(thenOperands.size() == elseOperands.size()); rewriter.setInsertionPointToStart(continueBlock); diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp 
b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -12,6 +12,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" @@ -186,7 +187,7 @@ Block &launchOpEntry = launchOpBody.front(); Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry); builder.setInsertionPointToEnd(&entryBlock); - builder.create(loc, clonedLaunchOpEntry); + builder.create(loc, clonedLaunchOpEntry); outlinedFunc.walk([](gpu::TerminatorOp op) { OpBuilder replacer(op); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" @@ -254,13 +255,13 @@ DenseSet &blockArgsToDetensor) override { SmallVector workList; - func->walk([&](CondBranchOp condBr) { + func->walk([&](cf::CondBranchOp condBr) { for (auto operand : condBr.getOperands()) { workList.push_back(operand); } }); - func->walk([&](BranchOp br) { + func->walk([&](cf::BranchOp br) { for (auto operand : br.getOperands()) { workList.push_back(operand); } diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include 
"mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Matchers.h" @@ -165,13 +166,13 @@ // "test.foo"() : () -> () // %v = scf.execute_region -> i64 { // %c = "test.cmp"() : () -> i1 -// cond_br %c, ^bb2, ^bb3 +// cf.cond_br %c, ^bb2, ^bb3 // ^bb2: // %x = "test.val1"() : () -> i64 -// br ^bb4(%x : i64) +// cf.br ^bb4(%x : i64) // ^bb3: // %y = "test.val2"() : () -> i64 -// br ^bb4(%y : i64) +// cf.br ^bb4(%y : i64) // ^bb4(%z : i64): // scf.yield %z : i64 // } @@ -184,13 +185,13 @@ // func @func_execute_region_elim() { // "test.foo"() : () -> () // %c = "test.cmp"() : () -> i1 -// cond_br %c, ^bb1, ^bb2 +// cf.cond_br %c, ^bb1, ^bb2 // ^bb1: // pred: ^bb0 // %x = "test.val1"() : () -> i64 -// br ^bb3(%x : i64) +// cf.br ^bb3(%x : i64) // ^bb2: // pred: ^bb0 // %y = "test.val2"() : () -> i64 -// br ^bb3(%y : i64) +// cf.br ^bb3(%y : i64) // ^bb3(%z: i64): // 2 preds: ^bb1, ^bb2 // "test.bar"(%z) : (i64) -> () // return @@ -208,13 +209,13 @@ Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator()); rewriter.setInsertionPointToEnd(prevBlock); - rewriter.create(op.getLoc(), &op.getRegion().front()); + rewriter.create(op.getLoc(), &op.getRegion().front()); for (Block &blk : op.getRegion()) { if (YieldOp yieldOp = dyn_cast(blk.getTerminator())) { rewriter.setInsertionPoint(yieldOp); - rewriter.create(yieldOp.getLoc(), postBlock, - yieldOp.getResults()); + rewriter.create(yieldOp.getLoc(), postBlock, + yieldOp.getResults()); rewriter.eraseOp(yieldOp); } } diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/CMakeLists.txt @@ -13,7 +13,7 @@ MLIRMemRefToLLVM MLIRPass MLIRReconcileUnrealizedCasts - MLIRSCFToStandard + 
MLIRSCFToControlFlow MLIRSparseTensor MLIRSparseTensorTransforms MLIRStandardOpsTransforms diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -32,7 +32,7 @@ pm.addPass(createLinalgBufferizePass()); pm.addPass(createConvertLinalgToLoopsPass()); pm.addPass(createConvertVectorToSCFPass()); - pm.addPass(createLowerToCFGPass()); // --convert-scf-to-std + pm.addPass(createConvertSCFToCFPass()); // --convert-scf-to-cf pm.addPass(createFuncBufferizePass()); pm.addPass(arith::createConstantBufferizePass()); pm.addPass(createTensorBufferizePass()); diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -11,6 +11,7 @@ MLIRArithmetic MLIRCallInterfaces MLIRCastInterfaces + MLIRControlFlow MLIRControlFlowInterfaces MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -8,9 +8,8 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -77,7 +76,7 @@ // Replace the return with a branch to the dest. 
OpBuilder builder(op); - builder.create(op->getLoc(), newDest, returnOp.getOperands()); + builder.create(op->getLoc(), newDest, returnOp.getOperands()); op->erase(); } @@ -121,130 +120,6 @@ return nullptr; } -//===----------------------------------------------------------------------===// -// AssertOp -//===----------------------------------------------------------------------===// - -LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) { - // Erase assertion if argument is constant true. - if (matchPattern(op.getArg(), m_One())) { - rewriter.eraseOp(op); - return success(); - } - return failure(); -} - -//===----------------------------------------------------------------------===// -// BranchOp -//===----------------------------------------------------------------------===// - -/// Given a successor, try to collapse it to a new destination if it only -/// contains a passthrough unconditional branch. If the successor is -/// collapsable, `successor` and `successorOperands` are updated to reference -/// the new destination and values. `argStorage` is used as storage if operands -/// to the collapsed successor need to be remapped. It must outlive uses of -/// successorOperands. -static LogicalResult collapseBranch(Block *&successor, - ValueRange &successorOperands, - SmallVectorImpl &argStorage) { - // Check that the successor only contains a unconditional branch. - if (std::next(successor->begin()) != successor->end()) - return failure(); - // Check that the terminator is an unconditional branch. - BranchOp successorBranch = dyn_cast(successor->getTerminator()); - if (!successorBranch) - return failure(); - // Check that the arguments are only used within the terminator. - for (BlockArgument arg : successor->getArguments()) { - for (Operation *user : arg.getUsers()) - if (user != successorBranch) - return failure(); - } - // Don't try to collapse branches to infinite loops. 
- Block *successorDest = successorBranch.getDest(); - if (successorDest == successor) - return failure(); - - // Update the operands to the successor. If the branch parent has no - // arguments, we can use the branch operands directly. - OperandRange operands = successorBranch.getOperands(); - if (successor->args_empty()) { - successor = successorDest; - successorOperands = operands; - return success(); - } - - // Otherwise, we need to remap any argument operands. - for (Value operand : operands) { - BlockArgument argOperand = operand.dyn_cast(); - if (argOperand && argOperand.getOwner() == successor) - argStorage.push_back(successorOperands[argOperand.getArgNumber()]); - else - argStorage.push_back(operand); - } - successor = successorDest; - successorOperands = argStorage; - return success(); -} - -/// Simplify a branch to a block that has a single predecessor. This effectively -/// merges the two blocks. -static LogicalResult -simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) { - // Check that the successor block has a single predecessor. - Block *succ = op.getDest(); - Block *opParent = op->getBlock(); - if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) - return failure(); - - // Merge the successor into the current block and erase the branch. - rewriter.mergeBlocks(succ, opParent, op.getOperands()); - rewriter.eraseOp(op); - return success(); -} - -/// br ^bb1 -/// ^bb1 -/// br ^bbN(...) -/// -/// -> br ^bbN(...) -/// -static LogicalResult simplifyPassThroughBr(BranchOp op, - PatternRewriter &rewriter) { - Block *dest = op.getDest(); - ValueRange destOperands = op.getOperands(); - SmallVector destOperandStorage; - - // Try to collapse the successor if it points somewhere other than this - // block. - if (dest == op->getBlock() || - failed(collapseBranch(dest, destOperands, destOperandStorage))) - return failure(); - - // Create a new branch with the collapsed successor. 
- rewriter.replaceOpWithNewOp(op, dest, destOperands); - return success(); -} - -LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) { - return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) || - succeeded(simplifyPassThroughBr(op, rewriter))); -} - -void BranchOp::setDest(Block *block) { return setSuccessor(block); } - -void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); } - -Optional -BranchOp::getMutableSuccessorOperands(unsigned index) { - assert(index == 0 && "invalid successor index"); - return getDestOperandsMutable(); -} - -Block *BranchOp::getSuccessorForOperands(ArrayRef) { - return getDest(); -} - //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// @@ -307,260 +182,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// CondBranchOp -//===----------------------------------------------------------------------===// - -namespace { -/// cond_br true, ^bb1, ^bb2 -/// -> br ^bb1 -/// cond_br false, ^bb1, ^bb2 -/// -> br ^bb2 -/// -struct SimplifyConstCondBranchPred : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CondBranchOp condbr, - PatternRewriter &rewriter) const override { - if (matchPattern(condbr.getCondition(), m_NonZero())) { - // True branch taken. - rewriter.replaceOpWithNewOp(condbr, condbr.getTrueDest(), - condbr.getTrueOperands()); - return success(); - } - if (matchPattern(condbr.getCondition(), m_Zero())) { - // False branch taken. - rewriter.replaceOpWithNewOp(condbr, condbr.getFalseDest(), - condbr.getFalseOperands()); - return success(); - } - return failure(); - } -}; - -/// cond_br %cond, ^bb1, ^bb2 -/// ^bb1 -/// br ^bbN(...) -/// ^bb2 -/// br ^bbK(...) -/// -/// -> cond_br %cond, ^bbN(...), ^bbK(...) 
-/// -struct SimplifyPassThroughCondBranch : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CondBranchOp condbr, - PatternRewriter &rewriter) const override { - Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest(); - ValueRange trueDestOperands = condbr.getTrueOperands(); - ValueRange falseDestOperands = condbr.getFalseOperands(); - SmallVector trueDestOperandStorage, falseDestOperandStorage; - - // Try to collapse one of the current successors. - LogicalResult collapsedTrue = - collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); - LogicalResult collapsedFalse = - collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); - if (failed(collapsedTrue) && failed(collapsedFalse)) - return failure(); - - // Create a new branch with the collapsed successors. - rewriter.replaceOpWithNewOp(condbr, condbr.getCondition(), - trueDest, trueDestOperands, - falseDest, falseDestOperands); - return success(); - } -}; - -/// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) -/// -> br ^bb1(A, ..., N) -/// -/// cond_br %cond, ^bb1(A), ^bb1(B) -/// -> %select = arith.select %cond, A, B -/// br ^bb1(%select) -/// -struct SimplifyCondBranchIdenticalSuccessors - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CondBranchOp condbr, - PatternRewriter &rewriter) const override { - // Check that the true and false destinations are the same and have the same - // operands. - Block *trueDest = condbr.getTrueDest(); - if (trueDest != condbr.getFalseDest()) - return failure(); - - // If all of the operands match, no selects need to be generated. 
- OperandRange trueOperands = condbr.getTrueOperands(); - OperandRange falseOperands = condbr.getFalseOperands(); - if (trueOperands == falseOperands) { - rewriter.replaceOpWithNewOp(condbr, trueDest, trueOperands); - return success(); - } - - // Otherwise, if the current block is the only predecessor insert selects - // for any mismatched branch operands. - if (trueDest->getUniquePredecessor() != condbr->getBlock()) - return failure(); - - // Generate a select for any operands that differ between the two. - SmallVector mergedOperands; - mergedOperands.reserve(trueOperands.size()); - Value condition = condbr.getCondition(); - for (auto it : llvm::zip(trueOperands, falseOperands)) { - if (std::get<0>(it) == std::get<1>(it)) - mergedOperands.push_back(std::get<0>(it)); - else - mergedOperands.push_back(rewriter.create( - condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); - } - - rewriter.replaceOpWithNewOp(condbr, trueDest, mergedOperands); - return success(); - } -}; - -/// ... -/// cond_br %cond, ^bb1(...), ^bb2(...) -/// ... -/// ^bb1: // has single predecessor -/// ... -/// cond_br %cond, ^bb3(...), ^bb4(...) -/// -/// -> -/// -/// ... -/// cond_br %cond, ^bb1(...), ^bb2(...) -/// ... -/// ^bb1: // has single predecessor -/// ... -/// br ^bb3(...) -/// -struct SimplifyCondBranchFromCondBranchOnSameCondition - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CondBranchOp condbr, - PatternRewriter &rewriter) const override { - // Check that we have a single distinct predecessor. - Block *currentBlock = condbr->getBlock(); - Block *predecessor = currentBlock->getSinglePredecessor(); - if (!predecessor) - return failure(); - - // Check that the predecessor terminates with a conditional branch to this - // block and that it branches on the same condition. 
- auto predBranch = dyn_cast(predecessor->getTerminator()); - if (!predBranch || condbr.getCondition() != predBranch.getCondition()) - return failure(); - - // Fold this branch to an unconditional branch. - if (currentBlock == predBranch.getTrueDest()) - rewriter.replaceOpWithNewOp(condbr, condbr.getTrueDest(), - condbr.getTrueDestOperands()); - else - rewriter.replaceOpWithNewOp(condbr, condbr.getFalseDest(), - condbr.getFalseDestOperands()); - return success(); - } -}; - -/// cond_br %arg0, ^trueB, ^falseB -/// -/// ^trueB: -/// "test.consumer1"(%arg0) : (i1) -> () -/// ... -/// -/// ^falseB: -/// "test.consumer2"(%arg0) : (i1) -> () -/// ... -/// -/// -> -/// -/// cond_br %arg0, ^trueB, ^falseB -/// ^trueB: -/// "test.consumer1"(%true) : (i1) -> () -/// ... -/// -/// ^falseB: -/// "test.consumer2"(%false) : (i1) -> () -/// ... -struct CondBranchTruthPropagation : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CondBranchOp condbr, - PatternRewriter &rewriter) const override { - // Check that we have a single distinct predecessor. - bool replaced = false; - Type ty = rewriter.getI1Type(); - - // These variables serve to prevent creating duplicate constants - // and hold constant true or false values. - Value constantTrue = nullptr; - Value constantFalse = nullptr; - - // TODO These checks can be expanded to encompas any use with only - // either the true of false edge as a predecessor. For now, we fall - // back to checking the single predecessor is given by the true/fasle - // destination, thereby ensuring that only that edge can reach the - // op. 
- if (condbr.getTrueDest()->getSinglePredecessor()) { - for (OpOperand &use : - llvm::make_early_inc_range(condbr.getCondition().getUses())) { - if (use.getOwner()->getBlock() == condbr.getTrueDest()) { - replaced = true; - - if (!constantTrue) - constantTrue = rewriter.create( - condbr.getLoc(), ty, rewriter.getBoolAttr(true)); - - rewriter.updateRootInPlace(use.getOwner(), - [&] { use.set(constantTrue); }); - } - } - } - if (condbr.getFalseDest()->getSinglePredecessor()) { - for (OpOperand &use : - llvm::make_early_inc_range(condbr.getCondition().getUses())) { - if (use.getOwner()->getBlock() == condbr.getFalseDest()) { - replaced = true; - - if (!constantFalse) - constantFalse = rewriter.create( - condbr.getLoc(), ty, rewriter.getBoolAttr(false)); - - rewriter.updateRootInPlace(use.getOwner(), - [&] { use.set(constantFalse); }); - } - } - } - return success(replaced); - } -}; -} // namespace - -void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -Optional -CondBranchOp::getMutableSuccessorOperands(unsigned index) { - assert(index < getNumSuccessors() && "invalid successor index"); - return index == trueIndex ? getTrueDestOperandsMutable() - : getFalseDestOperandsMutable(); -} - -Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { - if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) - return condAttr.getValue().isOneValue() ? 
getTrueDest() : getFalseDest(); - return nullptr; -} - //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -621,439 +242,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// SwitchOp -//===----------------------------------------------------------------------===// - -void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, - Block *defaultDestination, ValueRange defaultOperands, - DenseIntElementsAttr caseValues, - BlockRange caseDestinations, - ArrayRef caseOperands) { - build(builder, result, value, defaultOperands, caseOperands, caseValues, - defaultDestination, caseDestinations); -} - -void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, - Block *defaultDestination, ValueRange defaultOperands, - ArrayRef caseValues, BlockRange caseDestinations, - ArrayRef caseOperands) { - DenseIntElementsAttr caseValuesAttr; - if (!caseValues.empty()) { - ShapedType caseValueType = VectorType::get( - static_cast(caseValues.size()), value.getType()); - caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues); - } - build(builder, result, value, defaultDestination, defaultOperands, - caseValuesAttr, caseDestinations, caseOperands); -} - -/// ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? -/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? 
)* -static ParseResult parseSwitchOpCases( - OpAsmParser &parser, Type &flagType, Block *&defaultDestination, - SmallVectorImpl &defaultOperands, - SmallVectorImpl &defaultOperandTypes, - DenseIntElementsAttr &caseValues, - SmallVectorImpl &caseDestinations, - SmallVectorImpl> &caseOperands, - SmallVectorImpl> &caseOperandTypes) { - if (parser.parseKeyword("default") || parser.parseColon() || - parser.parseSuccessor(defaultDestination)) - return failure(); - if (succeeded(parser.parseOptionalLParen())) { - if (parser.parseRegionArgumentList(defaultOperands) || - parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) - return failure(); - } - - SmallVector values; - unsigned bitWidth = flagType.getIntOrFloatBitWidth(); - while (succeeded(parser.parseOptionalComma())) { - int64_t value = 0; - if (failed(parser.parseInteger(value))) - return failure(); - values.push_back(APInt(bitWidth, value)); - - Block *destination; - SmallVector operands; - SmallVector operandTypes; - if (failed(parser.parseColon()) || - failed(parser.parseSuccessor(destination))) - return failure(); - if (succeeded(parser.parseOptionalLParen())) { - if (failed(parser.parseRegionArgumentList(operands)) || - failed(parser.parseColonTypeList(operandTypes)) || - failed(parser.parseRParen())) - return failure(); - } - caseDestinations.push_back(destination); - caseOperands.emplace_back(operands); - caseOperandTypes.emplace_back(operandTypes); - } - - if (!values.empty()) { - ShapedType caseValueType = - VectorType::get(static_cast(values.size()), flagType); - caseValues = DenseIntElementsAttr::get(caseValueType, values); - } - return success(); -} - -static void printSwitchOpCases( - OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, - OperandRange defaultOperands, TypeRange defaultOperandTypes, - DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, - OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) { - p << " default: "; - 
p.printSuccessorAndUseList(defaultDestination, defaultOperands); - - if (!caseValues) - return; - - for (const auto &it : llvm::enumerate(caseValues.getValues())) { - p << ','; - p.printNewline(); - p << " "; - p << it.value().getLimitedValue(); - p << ": "; - p.printSuccessorAndUseList(caseDestinations[it.index()], - caseOperands[it.index()]); - } - p.printNewline(); -} - -LogicalResult SwitchOp::verify() { - auto caseValues = getCaseValues(); - auto caseDestinations = getCaseDestinations(); - - if (!caseValues && caseDestinations.empty()) - return success(); - - Type flagType = getFlag().getType(); - Type caseValueType = caseValues->getType().getElementType(); - if (caseValueType != flagType) - return emitOpError() << "'flag' type (" << flagType - << ") should match case value type (" << caseValueType - << ")"; - - if (caseValues && - caseValues->size() != static_cast(caseDestinations.size())) - return emitOpError() << "number of case values (" << caseValues->size() - << ") should match number of " - "case destinations (" - << caseDestinations.size() << ")"; - return success(); -} - -Optional -SwitchOp::getMutableSuccessorOperands(unsigned index) { - assert(index < getNumSuccessors() && "invalid successor index"); - return index == 0 ? 
getDefaultOperandsMutable() - : getCaseOperandsMutable(index - 1); -} - -Block *SwitchOp::getSuccessorForOperands(ArrayRef operands) { - Optional caseValues = getCaseValues(); - - if (!caseValues) - return getDefaultDestination(); - - SuccessorRange caseDests = getCaseDestinations(); - if (auto value = operands.front().dyn_cast_or_null()) { - for (const auto &it : llvm::enumerate(caseValues->getValues())) - if (it.value() == value.getValue()) - return caseDests[it.index()]; - return getDefaultDestination(); - } - return nullptr; -} - -/// switch %flag : i32, [ -/// default: ^bb1 -/// ] -/// -> br ^bb1 -static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op, - PatternRewriter &rewriter) { - if (!op.getCaseDestinations().empty()) - return failure(); - - rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), - op.getDefaultOperands()); - return success(); -} - -/// switch %flag : i32, [ -/// default: ^bb1, -/// 42: ^bb1, -/// 43: ^bb2 -/// ] -/// -> -/// switch %flag : i32, [ -/// default: ^bb1, -/// 43: ^bb2 -/// ] -static LogicalResult -dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) { - SmallVector newCaseDestinations; - SmallVector newCaseOperands; - SmallVector newCaseValues; - bool requiresChange = false; - auto caseValues = op.getCaseValues(); - auto caseDests = op.getCaseDestinations(); - - for (const auto &it : llvm::enumerate(caseValues->getValues())) { - if (caseDests[it.index()] == op.getDefaultDestination() && - op.getCaseOperands(it.index()) == op.getDefaultOperands()) { - requiresChange = true; - continue; - } - newCaseDestinations.push_back(caseDests[it.index()]); - newCaseOperands.push_back(op.getCaseOperands(it.index())); - newCaseValues.push_back(it.value()); - } - - if (!requiresChange) - return failure(); - - rewriter.replaceOpWithNewOp( - op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), - newCaseValues, newCaseDestinations, newCaseOperands); - return success(); -} - -/// Helper for 
folding a switch with a constant value. -/// switch %c_42 : i32, [ -/// default: ^bb1 , -/// 42: ^bb2, -/// 43: ^bb3 -/// ] -/// -> br ^bb2 -static void foldSwitch(SwitchOp op, PatternRewriter &rewriter, - const APInt &caseValue) { - auto caseValues = op.getCaseValues(); - for (const auto &it : llvm::enumerate(caseValues->getValues())) { - if (it.value() == caseValue) { - rewriter.replaceOpWithNewOp( - op, op.getCaseDestinations()[it.index()], - op.getCaseOperands(it.index())); - return; - } - } - rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), - op.getDefaultOperands()); -} - -/// switch %c_42 : i32, [ -/// default: ^bb1, -/// 42: ^bb2, -/// 43: ^bb3 -/// ] -/// -> br ^bb2 -static LogicalResult simplifyConstSwitchValue(SwitchOp op, - PatternRewriter &rewriter) { - APInt caseValue; - if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue))) - return failure(); - - foldSwitch(op, rewriter, caseValue); - return success(); -} - -/// switch %c_42 : i32, [ -/// default: ^bb1, -/// 42: ^bb2, -/// ] -/// ^bb2: -/// br ^bb3 -/// -> -/// switch %c_42 : i32, [ -/// default: ^bb1, -/// 42: ^bb3, -/// ] -static LogicalResult simplifyPassThroughSwitch(SwitchOp op, - PatternRewriter &rewriter) { - SmallVector newCaseDests; - SmallVector newCaseOperands; - SmallVector> argStorage; - auto caseValues = op.getCaseValues(); - auto caseDests = op.getCaseDestinations(); - bool requiresChange = false; - for (int64_t i = 0, size = caseValues->size(); i < size; ++i) { - Block *caseDest = caseDests[i]; - ValueRange caseOperands = op.getCaseOperands(i); - argStorage.emplace_back(); - if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back()))) - requiresChange = true; - - newCaseDests.push_back(caseDest); - newCaseOperands.push_back(caseOperands); - } - - Block *defaultDest = op.getDefaultDestination(); - ValueRange defaultOperands = op.getDefaultOperands(); - argStorage.emplace_back(); - - if (succeeded( - collapseBranch(defaultDest, defaultOperands, 
argStorage.back()))) - requiresChange = true; - - if (!requiresChange) - return failure(); - - rewriter.replaceOpWithNewOp(op, op.getFlag(), defaultDest, - defaultOperands, caseValues.getValue(), - newCaseDests, newCaseOperands); - return success(); -} - -/// switch %flag : i32, [ -/// default: ^bb1, -/// 42: ^bb2, -/// ] -/// ^bb2: -/// switch %flag : i32, [ -/// default: ^bb3, -/// 42: ^bb4 -/// ] -/// -> -/// switch %flag : i32, [ -/// default: ^bb1, -/// 42: ^bb2, -/// ] -/// ^bb2: -/// br ^bb4 -/// -/// and -/// -/// switch %flag : i32, [ -/// default: ^bb1, -/// 42: ^bb2, -/// ] -/// ^bb2: -/// switch %flag : i32, [ -/// default: ^bb3, -/// 43: ^bb4 -/// ] -/// -> -/// switch %flag : i32, [ -/// default: ^bb1, -/// 42: ^bb2, -/// ] -/// ^bb2: -/// br ^bb3 -static LogicalResult -simplifySwitchFromSwitchOnSameCondition(SwitchOp op, - PatternRewriter &rewriter) { - // Check that we have a single distinct predecessor. - Block *currentBlock = op->getBlock(); - Block *predecessor = currentBlock->getSinglePredecessor(); - if (!predecessor) - return failure(); - - // Check that the predecessor terminates with a switch branch to this block - // and that it branches on the same condition and that this branch isn't the - // default destination. - auto predSwitch = dyn_cast(predecessor->getTerminator()); - if (!predSwitch || op.getFlag() != predSwitch.getFlag() || - predSwitch.getDefaultDestination() == currentBlock) - return failure(); - - // Fold this switch to an unconditional branch. 
- SuccessorRange predDests = predSwitch.getCaseDestinations(); - auto it = llvm::find(predDests, currentBlock); - if (it != predDests.end()) { - Optional predCaseValues = predSwitch.getCaseValues(); - foldSwitch(op, rewriter, - predCaseValues->getValues()[it - predDests.begin()]); - } else { - rewriter.replaceOpWithNewOp(op, op.getDefaultDestination(), - op.getDefaultOperands()); - } - return success(); -} - -/// switch %flag : i32, [ -/// default: ^bb1, -/// 42: ^bb2 -/// ] -/// ^bb1: -/// switch %flag : i32, [ -/// default: ^bb3, -/// 42: ^bb4, -/// 43: ^bb5 -/// ] -/// -> -/// switch %flag : i32, [ -/// default: ^bb1, -/// 42: ^bb2, -/// ] -/// ^bb1: -/// switch %flag : i32, [ -/// default: ^bb3, -/// 43: ^bb5 -/// ] -static LogicalResult -simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op, - PatternRewriter &rewriter) { - // Check that we have a single distinct predecessor. - Block *currentBlock = op->getBlock(); - Block *predecessor = currentBlock->getSinglePredecessor(); - if (!predecessor) - return failure(); - - // Check that the predecessor terminates with a switch branch to this block - // and that it branches on the same condition and that this branch is the - // default destination. - auto predSwitch = dyn_cast(predecessor->getTerminator()); - if (!predSwitch || op.getFlag() != predSwitch.getFlag() || - predSwitch.getDefaultDestination() != currentBlock) - return failure(); - - // Delete case values that are not possible here. 
- DenseSet caseValuesToRemove; - auto predDests = predSwitch.getCaseDestinations(); - auto predCaseValues = predSwitch.getCaseValues(); - for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i) - if (currentBlock != predDests[i]) - caseValuesToRemove.insert(predCaseValues->getValues()[i]); - - SmallVector newCaseDestinations; - SmallVector newCaseOperands; - SmallVector newCaseValues; - bool requiresChange = false; - - auto caseValues = op.getCaseValues(); - auto caseDests = op.getCaseDestinations(); - for (const auto &it : llvm::enumerate(caseValues->getValues())) { - if (caseValuesToRemove.contains(it.value())) { - requiresChange = true; - continue; - } - newCaseDestinations.push_back(caseDests[it.index()]); - newCaseOperands.push_back(op.getCaseOperands(it.index())); - newCaseValues.push_back(it.value()); - } - - if (!requiresChange) - return failure(); - - rewriter.replaceOpWithNewOp( - op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(), - newCaseValues, newCaseDestinations, newCaseOperands); - return success(); -} - -void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(&simplifySwitchWithOnlyDefault) - .add(&dropSwitchCasesThatMatchDefault) - .add(&simplifyConstSwitchValue) - .add(&simplifyPassThroughSwitch) - .add(&simplifySwitchFromSwitchOnSameCondition) - .add(&simplifySwitchFromDefaultSwitchOnSameCondition); -} - //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include 
"mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" @@ -41,6 +42,7 @@ [](DialectRegistry ®istry) { // clang-format off registry.insert - +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -23,6 +22,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" +#include #define DEBUG_TYPE "translate-to-cpp" @@ -237,7 +237,8 @@ return printConstantOp(emitter, operation, value); } -static LogicalResult printOperation(CppEmitter &emitter, BranchOp branchOp) { +static LogicalResult printOperation(CppEmitter &emitter, + cf::BranchOp branchOp) { raw_ostream &os = emitter.ostream(); Block &successor = *branchOp.getSuccessor(); @@ -257,7 +258,7 @@ } static LogicalResult printOperation(CppEmitter &emitter, - CondBranchOp condBranchOp) { + cf::CondBranchOp condBranchOp) { raw_indented_ostream &os = emitter.ostream(); Block &trueSuccessor = *condBranchOp.getTrueDest(); Block &falseSuccessor = *condBranchOp.getFalseDest(); @@ -637,11 +638,12 @@ return failure(); } for (Operation &op : block.getOperations()) { - // When generating code for an scf.if or std.cond_br op no semicolon needs + // When generating code for an scf.if or cf.cond_br op no semicolon needs // to be printed after the closing brace. // When generating code for an scf.for op, printing a trailing semicolon // is handled within the printOperation function. - bool trailingSemicolon = !isa(op); + bool trailingSemicolon = + !isa(op); if (failed(emitter.emitOperation( op, /*trailingSemicolon=*/trailingSemicolon))) @@ -907,8 +909,8 @@ .Case( [&](auto op) { return printOperation(*this, op); }) // Standard ops. - .Case( + .Case( [&](auto op) { return printOperation(*this, op); }) // Arithmetic ops. .Case(