diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -23,6 +23,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/FoldUtils.h"
 
+#include "llvm/ADT/TypeSwitch.h"
+
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
@@ -65,7 +67,7 @@
   assert(op.getOperation()->getNumRegions() == 1 &&
          "Expected single region op");
   auto &b = ScopedContext::getBuilderRef();
-  auto &block = op.region().front();
+  auto &block = op.getOperation()->getRegion(0).front();
   BlockAndValueMapping map;
   map.map(block.getArguments(), indexedValues);
   for (auto &op : block.without_terminator()) {
@@ -102,8 +104,6 @@
           makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
 }
 
-namespace {
-
 /// Emits the MLIR for the scalar part of the generic op by:
 /// 1. Emitting load ops for each input and output view in order. This is
 ///    achieved by applying the appropriate input or output map to the
@@ -134,10 +134,9 @@
 ///      }
 ///    }
 /// ```
-// TODO: need a LinalgStructuredOpInterface.
-template <typename IndexedValueType, typename LinalgStructuredOpType>
-void emitScalarImplementation(ArrayRef<Value> allIvs,
-                              LinalgStructuredOpType linalgOp) {
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs,
+                                     LinalgOp linalgOp) {
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto &b = ScopedContext::getBuilderRef();
@@ -150,7 +149,7 @@
   auto attr = linalgOp.template getAttrOfType<IntegerAttr>("symbol_source");
   auto allIvsPlusDims = SmallVector<Value, 4>(allIvs.begin(), allIvs.end());
   if (attr) {
-    auto operand = linalgOp.getOperand(attr.getInt());
+    auto operand = linalgOp.getOperation()->getOperand(attr.getInt());
     auto shapedType = operand.getType().template cast<ShapedType>();
     allIvsPlusDims.reserve(allIvs.size() + shapedType.getRank());
     for (unsigned idx = 0, e = shapedType.getRank(); idx < e; ++idx)
@@ -190,7 +189,7 @@
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
   assert(copyOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = copyOp.getNumParallelLoops();
@@ -211,7 +210,7 @@
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
   assert(fillOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto nPar = fillOp.getNumParallelLoops();
@@ -224,8 +223,8 @@
 }
 
 template <typename IndexedValueType>
-Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
-                     MutableArrayRef<Value> imIdx) {
+static Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
+                            MutableArrayRef<Value> imIdx) {
   // TODO: add a level of indirection to linalg.generic.
   if (!convOp.padding())
     return im(imIdx);
@@ -311,8 +310,9 @@
   }
 }
 
-template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+template <typename IndexedValueType, typename OpType>
+static void emitPoolingMinMaxScalarImplementation(ArrayRef<Value> allIvs,
+                                                   OpType op) {
   InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
   // Emit scalar form.
   IndexedValueType output(op.output());
@@ -320,30 +320,34 @@
   Value lhs = output(indices.outputs);
   Value rhs = input(indices.inputs);
   using edsc::op::sgt;
-  Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = maxValue;
+  using edsc::op::slt;
+  Value value = std::is_same<OpType, PoolingMinOp>()
+                    ? std_select(slt(lhs, rhs), lhs, rhs)
+                    : std_select(sgt(lhs, rhs), lhs, rhs);
+  output(indices.outputs) = value;
 }
 
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
-  InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op);
-  // Emit scalar form.
-  IndexedValueType output(op.output());
-  IndexedValueType input(op.input());
-  Value lhs = output(indices.outputs);
-  Value rhs = input(indices.inputs);
-  using edsc::op::slt;
-  Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
-  output(indices.outputs) = minValue;
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMaxOp>(allIvs,
+                                                                        op);
 }
+
 template <typename IndexedValueType>
-void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
+  emitPoolingMinMaxScalarImplementation<IndexedValueType, PoolingMinOp>(allIvs,
+                                                                        op);
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
   auto indices = getInputAndOutputIndices(allIvs, op);
   IndexedValueType input(op.input()), output(op.output());
   // Emit scalar form.
   output(indices.outputs) += input(indices.inputs);
 }
+
 /// Emits the MLIR for the scalar part of the indexed generic op by:
 /// 1. Emitting load ops for each input and output view in order. This is
 ///    achieved by applying the appropriate input or output map to the
@@ -422,15 +426,16 @@
                                              indexing, outputBuffers);
 }
 
-template <typename LoopTy, typename ConcreteOpTy>
-Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op, OpBuilder &builder) {
+template <typename LoopTy>
+static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
+                                                 OpBuilder &builder) {
   using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
 
   ScopedContext scope(builder, op->getLoc());
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
-  auto linalgOp = cast<ConcreteOpTy>(op);
+  auto linalgOp = cast<LinalgOp>(op);
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   auto mapsRange =
@@ -447,7 +452,12 @@
       [&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
-        emitScalarImplementation<IndexedValueTy>(allIvs, linalgOp);
+        llvm::TypeSwitch<Operation *>(op)
+            .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+                  PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
+              emitScalarImplementation<IndexedValueTy>(allIvs, op);
+            })
+            .Default([&](Operation *op) { assert(false && "unexpected op"); });
         return scf::ValueVector{};
       });
   // Number of loop ops might be different from the number of ivs since some
@@ -467,32 +477,38 @@
   return loops;
 }
 
-template <typename LoopType, typename ConcreteOp>
+namespace {
+template <typename LoopType>
 class LinalgRewritePattern : public RewritePattern {
 public:
-  explicit LinalgRewritePattern(MLIRContext *context)
-      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
+  LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    if (!linalgOpToLoopsImpl<LoopType, ConcreteOp>(op, rewriter))
+    if (!isa<LinalgOp>(op))
+      return failure();
+    if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
       return failure();
     rewriter.eraseOp(op);
     return success();
   }
 };
 
-template <typename LoopType, typename ConcreteOp>
-void insertOnePattern(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
-}
+struct FoldAffineOp;
 } // namespace
 
-template <typename LoopType, typename... Args>
-void insertPatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  (void)std::initializer_list<int>{
-      0, (insertOnePattern<LoopType, Args>(patterns, ctx), 0)...};
+template <typename LoopType>
+static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
+  OwningRewritePatternList patterns;
+  patterns.insert<LinalgRewritePattern<LoopType>>();
+  DimOp::getCanonicalizationPatterns(patterns, context);
+  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldAffineOp>(context);
+  // Just apply the patterns greedily.
+  applyPatternsAndFoldGreedily(funcOp, patterns);
 }
 
+namespace {
 /// Local folding pattern for AffineApplyOp that we can apply greedily.
 /// This replaces AffineApplyOp by the proper value in cases where the
 /// associated map is trivial.
@@ -529,38 +545,20 @@
     return failure();
   }
 };
 
-} // namespace
-
-template <typename LoopType>
-static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  // Canonicalization and folding patterns applied greedily allow cleaning up
-  // the emitted IR on the fly.
-  // TODO: fold view and subview ops?
-  insertPatterns<LoopType,
-#define GET_OP_LIST
-#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-                 >(patterns, context);
-  DimOp::getCanonicalizationPatterns(patterns, context);
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  patterns.insert<FoldAffineOp>(context);
-  // Just apply the patterns greedily.
-  applyPatternsAndFoldGreedily(funcOp, patterns);
-}
-
-namespace {
 struct LowerToAffineLoops
     : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
   void runOnFunction() override {
     lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
   }
 };
+
 struct LowerToParallelLoops
     : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
   void runOnFunction() override {
@@ -583,60 +581,6 @@
   return std::make_unique<LowerToAffineLoops>();
 }
 
-// TODO: gradually remove this layer as more ops become "named".
-template <typename LoopTy>
-static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
-                                                       OpBuilder &builder) {
-  assert(isa<LinalgOp>(op) && "LinalgOp expected");
-  if (isa<CopyOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, CopyOp>(op, builder);
-  if (isa<FillOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, FillOp>(op, builder);
-  if (isa<ConvOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvOp>(op, builder);
-  if (isa<PoolingMaxOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMaxOp>(op, builder);
-  if (isa<PoolingMinOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingMinOp>(op, builder);
-  if (isa<PoolingSumOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, PoolingSumOp>(op, builder);
-  if (isa<IndexedGenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, IndexedGenericOp>(op, builder);
-
-  // TODO: Cases below are generic and need a LinalgStructuredOpInterface.
-  if (isa<GenericOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, GenericOp>(op, builder);
-  if (isa<MatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatmulOp>(op, builder);
-  if (isa<MatvecOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, MatvecOp>(op, builder);
-  if (isa<VecmatOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, VecmatOp>(op, builder);
-  if (isa<DotOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, DotOp>(op, builder);
-  if (isa<BatchMatmulOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, BatchMatmulOp>(op, builder);
-  if (isa<ConvWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvWOp>(op, builder);
-  if (isa<ConvNWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
-  if (isa<ConvNCWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
-  if (isa<ConvHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvHWOp>(op, builder);
-  if (isa<ConvNHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
-  if (isa<ConvNCHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
-  if (isa<ConvDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvDHWOp>(op, builder);
-  if (isa<ConvNDHWCOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
-  if (isa<ConvNCDHWOp>(op))
-    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
-  llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
-}
-
 SmallVector<Range, 4> mlir::linalg::emitLoopRanges(OpBuilder &b, Location loc,
                                                    AffineMap map,
                                                    ValueRange viewSizes) {
@@ -705,7 +649,7 @@
 template <typename LoopTy>
 Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
                                                          Operation *op) {
-  return linalgOpToLoopsImplSwitch<LoopTy>(op, builder);
+  return linalgOpToLoopsImpl<LoopTy>(op, builder);
 }
 
 template Optional<LinalgLoops>
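Editor's note (not part of the patch): the central idiom introduced above is replacing a hand-written chain of `isa<...>`/`return` dispatches with `llvm::TypeSwitch`. For readers unfamiliar with that utility, the standalone sketch below illustrates the same dispatch pattern outside of MLIR. The `Shape`/`Circle`/`Square` hierarchy is purely hypothetical and exists only to demonstrate the API; the only real dependency assumed is `llvm/ADT/TypeSwitch.h` and LLVM's `classof`-based casting.

```cpp
// Minimal sketch of llvm::TypeSwitch-based dispatch, analogous to the
// .Case<...>/.Default chain the patch adds in linalgOpToLoopsImpl.
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

namespace {
// Hypothetical class hierarchy using LLVM-style RTTI (classof).
struct Shape {
  enum class Kind { Circle, Square } kind;
  explicit Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(Kind::Circle) {}
  static bool classof(const Shape *s) { return s->kind == Kind::Circle; }
};
struct Square : Shape {
  Square() : Shape(Kind::Square) {}
  static bool classof(const Shape *s) { return s->kind == Kind::Square; }
};
} // namespace

// Dispatch on the dynamic type; .Default plays the role of the
// "unexpected op" fallback in the patch.
static llvm::StringRef describe(Shape *s) {
  return llvm::TypeSwitch<Shape *, llvm::StringRef>(s)
      .Case<Circle>([](Circle *) { return "round"; })
      .Case<Square>([](Square *) { return "angular"; })
      .Default([](Shape *) { return "unknown"; });
}

int main() {
  Circle c;
  Square sq;
  llvm::outs() << describe(&c) << " " << describe(&sq) << "\n";
  return 0;
}
```

When several case types share one handler, `TypeSwitch` also accepts a type list, e.g. `.Case<Circle, Square>([](auto *s) { ... })` with a generic lambda; that multi-type form is what the patch uses to route all supported Linalg ops to `emitScalarImplementation`.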