diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -59,6 +59,9 @@
 /// operations.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgGeneralizationPass();
 
+std::unique_ptr<Pass> createLinalgDetensorizePass();
+std::unique_ptr<Pass> createFuncDetensorizePass();
+
 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its
 /// producer (consumer) generic operation by expanding the dimensionality of the
 /// loop in the generic op.
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -136,4 +136,15 @@
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgDetensorize : FunctionPass<"linalg-detensorize"> {
+  let summary = "Detensorize linalg ops";
+  let constructor = "mlir::createLinalgDetensorizePass()";
+  let dependentDialects = [];
+}
+
+def FuncDetensorize : Pass<"func-detensorize", "ModuleOp"> {
+  let summary = "Detensorize func/call/return ops";
+  let constructor = "mlir::createFuncDetensorizePass()";
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h
--- a/mlir/include/mlir/Transforms/Utils.h
+++ b/mlir/include/mlir/Transforms/Utils.h
@@ -18,6 +18,9 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
@@ -142,6 +145,58 @@
 void createAffineComputationSlice(Operation *opInst,
                                   SmallVectorImpl<Operation *> *sliceOps);
 
+/// Converts func, call, return, and branch-like ops in `module` so that their
+/// signatures and operand types are legal according to `TypeConverterTy`. This
+/// factors out the finalization logic shared by func-bufferize and
+/// func-detensorize.
+template <typename TypeConverterTy>
+LogicalResult finalizeTensorRelatedConversion(ModuleOp module,
+                                              MLIRContext *context) {
+  TypeConverterTy typeConverter;
+  OwningRewritePatternList patterns;
+  ConversionTarget target(*context);
+
+  populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    return typeConverter.isSignatureLegal(op.getType()) &&
+           typeConverter.isLegal(&op.getBody());
+  });
+  populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+  target.addDynamicallyLegalOp<CallOp>(
+      [&](CallOp op) { return typeConverter.isLegal(op); });
+
+  populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
+                                                            typeConverter);
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+  target.addDynamicallyLegalOp<ReturnOp>(
+      [&](ReturnOp op) { return typeConverter.isLegal(op); });
+  // Mark terminators as legal if they have the ReturnLike trait or
+  // implement the BranchOpInterface and have valid types. If they do not
+  // implement the trait or interface, mark them as illegal no matter what.
+  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+    // If it is not a terminator, ignore it.
+    if (!op->mightHaveTrait<OpTrait::IsTerminator>())
+      return true;
+    // If it is not the last operation in the block, also ignore it. We do
+    // this to handle unknown operations, as well.
+    Block *block = op->getBlock();
+    if (!block || &block->back() != op)
+      return true;
+    // ReturnLike operations have to be legalized with their parent. For
+    // return this is handled, for other ops they remain as is.
+    if (op->hasTrait<OpTrait::ReturnLike>())
+      return true;
+    // All successor operands of branch like operations must be rewritten.
+    if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
+      for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
+        auto successorOperands = branchOp.getSuccessorOperands(p);
+        if (successorOperands.hasValue() &&
+            !typeConverter.isLegal(successorOperands.getValue().getTypes()))
+          return false;
+      }
+      return true;
+    }
+    return false;
+  });
+
+  return applyFullConversion(module, target, std::move(patterns));
+}
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_UTILS_H
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -397,7 +397,6 @@
 // InitTensorOp
 //===----------------------------------------------------------------------===//
 
-
 static LogicalResult verify(InitTensorOp op) {
   RankedTensorType resultType = op.getType();
   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -12,22 +12,30 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "PassDetail.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
+#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Utils.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include <memory>
 #include <type_traits>
 
 #define DEBUG_TYPE "linalg-transforms"
@@ -650,3 +658,195 @@
   return failure();
 }
+
+namespace {
+//===----------------------------------------------------------------------===//
+// Support for detensoring.
+//===----------------------------------------------------------------------===//
+/// Detensoring is the process through which a tensor value is converted to one
+/// or potentially more primitive value(s). During this process, operations with
+/// such detensored operands are also converted to an equivalent form that works
+/// on primitives.
+///
+/// The detensoring process is driven by linalg-on-tensor ops. In particular, a
+/// linalg-on-tensor op is checked to see whether *all* its operands can be
+/// detensored. If so, those operands are converted to their primitive
+/// counterparts and the linalg op is replaced by an equivalent op that takes
+/// those new primitive values as operands. Therefore, the detensoring process
+/// can be divided into 2 main logical phases:
+///
+/// 1. Detect/match an op that can be detensored.
+/// 2. Detensor the operands of the op and replace it with a primitive
+///    equivalent.
+///
+/// These 2 logical phases are implemented by DetensorizeGenericOp, which is
+/// documented in-place below.
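+///
+/// As a purely illustrative sketch (the SSA names and the addf body are taken
+/// from the tests added in this patch, not mandated by the pattern), a
+/// detensorable op such as
+///
+///   %res = linalg.generic ... ins(%a, %b : tensor<1xf32>, tensor<1xf32>)
+///                             outs(%init : tensor<1xf32>) {...} -> tensor<1xf32>
+///
+/// ends up, after the conversion and its materializations, as
+///
+///   %a_elem = tensor.extract %a[%c0] : tensor<1xf32>
+///   %b_elem = tensor.extract %b[%c0] : tensor<1xf32>
+///   %res_elem = addf %a_elem, %b_elem : f32
+///   %res = tensor.from_elements %res_elem : tensor<1xf32>
+///
+/// where the tensor.extract ops come from the target materialization and the
+/// tensor.from_elements op from the source materialization defined below.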
+
+/// Defines the criteria a TensorType must follow in order to be considered
+/// "detensorable".
+///
+/// NOTE: For now, only 0-D and 1-D tensors are supported.
+///
+/// Returns true if tensorType can be detensored.
+bool canBeDetensored(TensorType tensorType) {
+  return tensorType.hasRank() &&
+         ((tensorType.getRank() == 1 && tensorType.getNumElements() == 1) ||
+          (tensorType.getRank() == 0));
+}
+
+class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
+public:
+  using OpConversionPattern<GenericOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(GenericOp genericOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation &genericOpBody = genericOp.getBody()->front();
+    BlockAndValueMapping tensorToDetensoredOperandMapping;
+
+    tensorToDetensoredOperandMapping.map(
+        genericOpBody.getOperands(),
+        ArrayRef<Value>{operands.begin(), genericOpBody.getNumOperands()});
+
+    OpBuilder::InsertionGuard g(rewriter);
+
+    rewriter.setInsertionPoint(genericOp);
+    Operation *detensoredOp =
+        genericOpBody.clone(tensorToDetensoredOperandMapping);
+    rewriter.insert(detensoredOp);
+    rewriter.replaceOp(genericOp, detensoredOp->getResults());
+
+    return success();
+  }
+};
+
+Value materializeFromElements(OpBuilder &builder, Type type, ValueRange inputs,
+                              Location loc) {
+  auto createNewTensorOp = builder.create<tensor::FromElementsOp>(
+      loc, inputs[0].getType(), inputs[0]);
+
+  if (type.cast<TensorType>().getRank() == 1)
+    return createNewTensorOp;
+
+  return builder.create<linalg::TensorReshapeOp>(
+      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
+}
+
+class DetensorizeTypeConverter : public TypeConverter {
+public:
+  DetensorizeTypeConverter() {
+    addConversion([](Type type) { return type; });
+
+    addConversion([](TensorType tensorType) -> Type {
+      if (canBeDetensored(tensorType)) {
+        return tensorType.getElementType();
+      }
+
+      return tensorType;
+    });
+
+    addTargetMaterialization([](OpBuilder &builder, Type type,
+                                ValueRange inputs, Location loc) -> Value {
+      Type indexTy = IndexType::get(builder.getContext());
+      Attribute zero = IntegerAttr::get(indexTy, 0);
+      Value c0 = builder.create<ConstantOp>(loc, indexTy, zero);
+
+      TensorType tensorType = inputs[0].getType().cast<TensorType>();
+      return builder.create<tensor::ExtractOp>(
+          loc, inputs[0],
+          tensorType.getRank() == 0 ? ValueRange{} : ValueRange{c0});
+    });
+
+    addSourceMaterialization(materializeFromElements);
+    addArgumentMaterialization(materializeFromElements);
+  }
+};
+
+// Canonicalizes the pattern of the form
+//
+// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
+// %reshaped_tensor = linalg.tensor_reshape %tensor [] : tensor<1xi32> into
+//     tensor<i32>
+// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
+//
+// to just %element.
+struct ExtractFromReshapeFromElements
+    : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
+                                PatternRewriter &rewriter) const final {
+    if (extract.indices().size() != 0)
+      return failure();
+
+    auto tensorReshape = extract.tensor().getDefiningOp<TensorReshapeOp>();
+    if (tensorReshape == nullptr)
+      return failure();
+
+    auto tensorFromElements =
+        tensorReshape.getOperand()
+            .getDefiningOp<tensor::FromElementsOp>();
+    if (tensorFromElements == nullptr)
+      return failure();
+
+    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
+    return success();
+  }
+};
+
+struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    DetensorizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.addLegalDialect<linalg::LinalgDialect>();
+    target.addLegalDialect<StandardOpsDialect>();
+    target.addLegalDialect<tensor::TensorDialect>();
+
+    patterns.insert<DetensorizeGenericOp>(typeConverter, context);
+
+    target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
+      // If any of the operands or results cannot be detensored, the op is
+      // considered legal and won't be detensored.
+      return llvm::any_of(
+          op.getShapedOperandTypes(), [](ShapedType shapedType) {
+            assert(shapedType.isa<TensorType>());
+            return !canBeDetensored(shapedType.cast<TensorType>());
+          });
+    });
+
+    if (failed(
+            applyPartialConversion(getFunction(), target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+struct FuncDetensorize : public FuncDetensorizeBase<FuncDetensorize> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    auto *context = &getContext();
+
+    if (failed(finalizeTensorRelatedConversion<DetensorizeTypeConverter>(
+            module, context)))
+      signalPassFailure();
+
+    // For the 0-D case, the op sequence tensor.from_elements ->
+    // linalg.tensor_reshape -> tensor.extract is inserted in some places. Apply
+    // a canonicalization pattern to get rid of such op sequences.
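+    // Illustratively (names are for exposition only), such a sequence looks
+    // like
+    //   %t = tensor.from_elements %elem : tensor<1xi32>
+    //   %r = linalg.tensor_reshape %t [] : tensor<1xi32> into tensor<i32>
+    //   %e = tensor.extract %r[] : tensor<i32>
+    // and folds back to just %elem (see ExtractFromReshapeFromElements above).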
+    OwningRewritePatternList canonPatterns;
+    canonPatterns.insert<ExtractFromReshapeFromElements>(context);
+
+    if (failed(applyPatternsAndFoldGreedily(module, std::move(canonPatterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createLinalgDetensorizePass() {
+  return std::make_unique<LinalgDetensorize>();
+}
+
+std::unique_ptr<Pass> mlir::createFuncDetensorizePass() {
+  return std::make_unique<FuncDetensorize>();
+}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Transforms/Bufferize.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Utils.h"
 
 using namespace mlir;
 
@@ -27,55 +28,8 @@
     auto module = getOperation();
     auto *context = &getContext();
 
-    BufferizeTypeConverter typeConverter;
-    OwningRewritePatternList patterns;
-    ConversionTarget target(*context);
-
-    populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-      return typeConverter.isSignatureLegal(op.getType()) &&
-             typeConverter.isLegal(&op.getBody());
-    });
-    populateCallOpTypeConversionPattern(patterns, context, typeConverter);
-    target.addDynamicallyLegalOp<CallOp>(
-        [&](CallOp op) { return typeConverter.isLegal(op); });
-
-    populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
-                                                              typeConverter);
-    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-    target.addDynamicallyLegalOp<ReturnOp>(
-        [&](ReturnOp op) { return typeConverter.isLegal(op); });
-    // Mark terminators as legal if they have the ReturnLike trait or
-    // implement the BranchOpInterface and have valid types. If they do not
-    // implement the trait or interface, mark them as illegal no matter what.
-    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      // If it is not a terminator, ignore it.
-      if (!op->mightHaveTrait<OpTrait::IsTerminator>())
-        return true;
-      // If it is not the last operation in the block, also ignore it. We do
-      // this to handle unknown operations, as well.
-      Block *block = op->getBlock();
-      if (!block || &block->back() != op)
-        return true;
-      // ReturnLike operations have to be legalized with their parent. For
-      // return this is handled, for other ops they remain as is.
-      if (op->hasTrait<OpTrait::ReturnLike>())
-        return true;
-      // All successor operands of branch like operations must be rewritten.
-      if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
-        for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
-          auto successorOperands = branchOp.getSuccessorOperands(p);
-          if (successorOperands.hasValue() &&
-              !typeConverter.isLegal(successorOperands.getValue().getTypes()))
-            return false;
-        }
-        return true;
-      }
-      return false;
-    });
-
-    if (failed(applyFullConversion(module, target, std::move(patterns))))
+    if (failed(finalizeTensorRelatedConversion<BufferizeTypeConverter>(
+            module, context)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/Linalg/detensorized.mlir b/mlir/test/Dialect/Linalg/detensorized.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -linalg-detensorize -canonicalize | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+
+func @detensor_simple(%arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [1] : tensor<1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%arg1, %arg2 : tensor<1xf32>, tensor<1xf32>)
+    outs(%0 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1xf32>
+  return %1: tensor<1xf32>
+}
+// CHECK-LABEL: func @detensor_simple
+// CHECK-SAME: (%[[arg1:.*]]: tensor<1xf32>, %[[arg2:.*]]: tensor<1xf32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK: return %[[new_tensor_res]]
+
+func @detensor_op_sequence(%arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1xf32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [1] : tensor<1xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%arg1, %arg2 : tensor<1xf32>, tensor<1xf32>)
+    outs(%0 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<1xf32>
+
+  %3 = linalg.init_tensor [1] : tensor<1xf32>
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%arg1, %1 : tensor<1xf32>, tensor<1xf32>)
+    outs(%3 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = mulf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<1xf32>
+
+  %6 = linalg.init_tensor [1] : tensor<1xf32>
+  %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+    ins(%1, %4 : tensor<1xf32>, tensor<1xf32>)
+    outs(%6 : tensor<1xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = divf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<1xf32>
+
+  return %7: tensor<1xf32>
+}
+// CHECK-LABEL: func @detensor_op_sequence
+// CHECK-SAME: (%[[arg1:.*]]: tensor<1xf32>, %[[arg2:.*]]: tensor<1xf32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
+// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
+// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
+// CHECK: return %[[new_tensor_res]]
diff --git a/mlir/test/Dialect/Linalg/detensorized_0d.mlir b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized_0d.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -linalg-detensorize -canonicalize | FileCheck %s
+
+#map = affine_map<() -> ()>
+
+func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+  return %1: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_simple
+// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
+// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK: return %[[reshaped_tensor_res]]
+
+func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
+  %0 = linalg.init_tensor [] : tensor<f32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %arg2 : tensor<f32>, tensor<f32>)
+    outs(%0 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg3, %arg4 : f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+
+  %3 = linalg.init_tensor [] : tensor<f32>
+  %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%arg1, %1 : tensor<f32>, tensor<f32>)
+    outs(%3 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = mulf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<f32>
+
+  %6 = linalg.init_tensor [] : tensor<f32>
+  %7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []}
+    ins(%1, %4 : tensor<f32>, tensor<f32>)
+    outs(%6 : tensor<f32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):  // no predecessors
+    %5 = divf %arg3, %arg4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<f32>
+
+  return %7: tensor<f32>
+}
+// CHECK-LABEL: func @detensor_op_sequence
+// CHECK-SAME: (%[[arg1:.*]]: tensor<f32>, %[[arg2:.*]]: tensor<f32>)
+// CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]]
+// CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]]
+// CHECK: %[[detensored_res:.*]] = addf %[[arg1_val]], %[[arg2_val]]
+// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]]
+// CHECK: %[[detensored_res2:.*]] = mulf %[[arg1_val2]], %[[detensored_res]]
+// CHECK: %[[detensored_res3:.*]] = divf %[[detensored_res]], %[[detensored_res2]]
+// CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
+// CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_reshape %[[new_tensor_res]]
+// CHECK: return %[[reshaped_tensor_res]]
diff --git a/mlir/test/Dialect/Linalg/detensorized_while.mlir b/mlir/test/Dialect/Linalg/detensorized_while.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorized_while.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -linalg-detensorize -func-detensorize -canonicalize | FileCheck %s
+
+func @main() -> tensor<i32> attributes {iree.module.export} {
+  %cst = constant dense<1> : tensor<i32>
+  %cst_0 = constant dense<3> : tensor<i32>
+  br ^bb1(%cst : tensor<i32>)
+^bb1(%0: tensor<i32>):  // 2 preds: ^bb0, ^bb2
+  %1 = linalg.init_tensor [] : tensor<i1>
+  %2 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%0, %cst_0 : tensor<i32>, tensor<i32>) outs(%1 : tensor<i1>) {
+  ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
+    %8 = cmpi slt, %arg0, %arg1 : i32
+    linalg.yield %8 : i1
+  } -> tensor<i1>
+  %3 = tensor.extract %2[] : tensor<i1>
+  cond_br %3, ^bb2(%0 : tensor<i32>), ^bb3(%0 : tensor<i32>)
+^bb2(%4: tensor<i32>):  // pred: ^bb1
+  %5 = linalg.init_tensor [] : tensor<i32>
+  %6 = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%4, %4 : tensor<i32>, tensor<i32>) outs(%5 : tensor<i32>) {
+  ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
+    %8 = addi %arg0, %arg1 : i32
+    linalg.yield %8 : i32
+  } -> tensor<i32>
+  br ^bb1(%6 : tensor<i32>)
+^bb3(%7: tensor<i32>):  // pred: ^bb1
+  return %7 : tensor<i32>
+}
+// CHECK-LABEL: func @main() -> i32
+// CHECK-DAG: %[[c1:.*]] = constant 1
+// CHECK-DAG: %[[c3:.*]] = constant 3
+// CHECK: ^[[bb1:.*]](%[[bb1_arg:.*]]: i32)
+// CHECK: %[[cmp_res:.*]] = cmpi
+// CHECK: cond_br %[[cmp_res]]
+// CHECK: ^[[bb2:.*]](%[[bb2_arg:.*]]: i32)
+// CHECK: %[[add_res:.*]] = addi
+// CHECK: br
+// CHECK: ^[[bb3:.*]](%[[bb3_arg:.*]]: i32)
+// CHECK: return