diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -37,6 +37,25 @@ }]; } +def FuseOp : Op]> { + let description = [{ + Tiles the operations pointed to by the target handle and fuses their + producers greedily using the options provided as attributes. + }]; + + let arguments = + (ins PDL_Operation:$target, + DefaultValuedAttr:$tile_sizes, + DefaultValuedAttr:$tile_interchange); + let results = (outs PDL_Operation:$transformed, + Variadic:$loops); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + def GeneralizeOp : Op { @@ -136,7 +155,7 @@ def TileOp : Op, - DeclareOpInterfaceMethods]> { + FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { let description = [{ Indicates that the given `target` op should be tiled with the options provided as attributes. This transform generates a loop nest with a smaller diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -451,16 +451,15 @@ StringRef getName() override { return "transform.payload_ir"; } }; -/// Trait implementing the MemoryEffectOpInterface for single-operand zero- or -/// single-result operations that "consume" their operand and produce a new -/// result. +/// Trait implementing the MemoryEffectOpInterface for single-operand operations +/// that "consume" their operand and produce a new result. template class FunctionalStyleTransformOpTrait : public OpTrait::TraitBase { public: /// This op "consumes" the operand by reading and freeing it, "produces" the - /// result by allocating and writing it and reads/writes the payload IR in the - /// process. + /// results by allocating and writing it and reads/writes the payload IR in + /// the process. void getEffects(SmallVectorImpl &effects) { effects.emplace_back(MemoryEffects::Read::get(), this->getOperation()->getOperand(0), @@ -468,12 +467,10 @@ effects.emplace_back(MemoryEffects::Free::get(), this->getOperation()->getOperand(0), TransformMappingResource::get()); - if (this->getOperation()->getNumResults() == 1) { - effects.emplace_back(MemoryEffects::Allocate::get(), - this->getOperation()->getResult(0), + for (Value result : this->getOperation()->getResults()) { + effects.emplace_back(MemoryEffects::Allocate::get(), result, TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), - this->getOperation()->getResult(0), + effects.emplace_back(MemoryEffects::Write::get(), result, TransformMappingResource::get()); } effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); @@ -484,9 +481,6 @@ static LogicalResult verifyTrait(Operation *op) { static_assert(OpTy::template hasTrait(), "expected single-operand op"); - static_assert(OpTy::template hasTrait() || - OpTy::template hasTrait(), - "expected zero- or single-result op"); if (!op->getName().getInterface()) { op->emitError() << "FunctionalStyleTransformOpTrait should only be attached to ops " diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -92,6 +92,130 @@ return reportUnknownTransformError(target); } +//===----------------------------------------------------------------------===// +// FuseOp +//===----------------------------------------------------------------------===// + +/// Apply a tiling transformation to all payload ops and store both the +/// tiled operation as well as the created tile loops. +static LogicalResult +applyTilingToAll(Operation *transformOp, Value target, + ArrayRef tileSizes, + transform::TransformResults &transformResults, + transform::TransformState &state, + function_ref(LinalgOp)> applyFn) { + // Number of loops: Number of tiles sizes that are not zero. + size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); + // All payload ops. These should all be LinalgOps for now. + ArrayRef payloadOps = state.getPayloadOps(target); + + SmallVector tiledLinalgOps; + SmallVector> loopOps(numLoops); + for (unsigned int i = 0; i < numLoops; ++i) + loopOps[i].reserve(payloadOps.size()); + + for (Operation *target : payloadOps) { + auto linalgOp = dyn_cast(target); + if (!linalgOp) + return transformOp->emitError("only LinalgOps are supported"); + + FailureOr tiled = applyFn(linalgOp); + if (failed(tiled)) + return failure(); + + tiledLinalgOps.push_back(tiled->op); + if (tiled->loops.size() != numLoops) + // Not enough loops were generated. This usually means that the input size + // was smaller than the tiling size. + // TODO: LinalgTilingPattern should return failure(). + return failure(); + for (unsigned int i = 0; i < numLoops; ++i) + loopOps[i].push_back(tiled->loops[i]); + } + + transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); + for (unsigned int i = 0; i < numLoops; ++i) + transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); + return success(); +} + +/// Parse a tiling-like operation that returns the tiled op as well as the +/// created tile loops. The function counts the non-zero tile sizes to compute +/// the number of results. +static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, + StringRef sizesAttrName) { + OpAsmParser::UnresolvedOperand targetOperand; + SMLoc opLoc = parser.getCurrentLocation(); + if (parser.parseOperand(targetOperand) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + Attribute sizesAttr = result.attributes.get(sizesAttrName); + if (!sizesAttr) + return parser.emitError(opLoc) + << "expected '" << sizesAttrName << "' attribute"; + auto sizesArrayAttr = sizesAttr.dyn_cast(); + if (!sizesArrayAttr) + return parser.emitError(opLoc) + << "'" << sizesAttrName << "' attribute must be an array"; + Type pdlOpType = parser.getBuilder().getType(); + size_t numExpectedLoops = + sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); + result.addTypes(SmallVector(numExpectedLoops + 1, pdlOpType)); + if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) + return failure(); + return success(); +} + +LogicalResult +transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, + mlir::transform::TransformState &state) { + LinalgTilingAndFusionOptions fusionOptions; + fusionOptions.tileSizes = extractI64Array(getTileSizes()); + fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); + + return applyTilingToAll( + getOperation(), getTarget(), fusionOptions.tileSizes, transformResults, + state, [&](LinalgOp linalgOp) -> FailureOr { + LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(linalgOp); + FailureOr tileLoopNest = + pattern.returningMatchAndRewrite(linalgOp, rewriter); + if (failed(tileLoopNest)) + return failure(); + + TiledLinalgOp tiledLinalgOp; + tiledLinalgOp.op = tileLoopNest->getRootOp(); + tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(), + tileLoopNest->getLoopOps().end()}; + return tiledLinalgOp; + }); +} + +ParseResult transform::FuseOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseTileLikeOp( + parser, result, + transform::FuseOp::getTileSizesAttrName(result.name).getValue()); +} + +void transform::FuseOp::print(OpAsmPrinter &p) { + p << ' '; + p << getTarget(); + p.printOptionalAttrDict((*this)->getAttrs()); +} + +LogicalResult transform::FuseOp::verify() { + SmallVector permutation = extractI64Array(getTileInterchange()); + auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); + if (!std::is_permutation(sequence.begin(), sequence.end(), + permutation.begin(), permutation.end())) { + return emitOpError() << "expects interchange to be a permutation, found " + << getTileInterchange(); + } + return success(); +} + //===----------------------------------------------------------------------===// // GeneralizeOp //===----------------------------------------------------------------------===// @@ -274,49 +398,6 @@ // TileOp //===----------------------------------------------------------------------===// -/// Apply a tiling transformation to all payload ops and store both the -/// tiled operation as well as the created tile loops. -static LogicalResult -applyTilingToAll(Operation *transformOp, Value target, - ArrayRef tileSizes, - transform::TransformResults &transformResults, - transform::TransformState &state, - function_ref(LinalgOp)> applyFn) { - // Number of loops: Number of tiles sizes that are not zero. - size_t numLoops = tileSizes.size() - llvm::count(tileSizes, 0); - // All payload ops. These should all be LinalgOps for now. - ArrayRef payloadOps = state.getPayloadOps(target); - - SmallVector tiledLinalgOps; - SmallVector> loopOps(numLoops); - for (unsigned int i = 0; i < numLoops; ++i) - loopOps[i].reserve(payloadOps.size()); - - for (Operation *target : payloadOps) { - auto linalgOp = dyn_cast(target); - if (!linalgOp) - return transformOp->emitError("only LinalgOps are supported"); - - FailureOr tiled = applyFn(linalgOp); - if (failed(tiled)) - return failure(); - - tiledLinalgOps.push_back(tiled->op); - if (tiled->loops.size() != numLoops) - // Not enough loops were generated. This usually means that the input size - // was smaller than the tiling size. - // TODO: LinalgTilingPattern should return failure(). - return failure(); - for (unsigned int i = 0; i < numLoops; ++i) - loopOps[i].push_back(tiled->loops[i]); - } - - transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); - for (unsigned int i = 0; i < numLoops; ++i) - transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); - return success(); -} - LogicalResult transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { LinalgTilingOptions tilingOptions; @@ -337,27 +418,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { - StringRef sizesAttrName = TileOp::getSizesAttrName(result.name).getValue(); - OpAsmParser::UnresolvedOperand targetOperand; - SMLoc opLoc = parser.getCurrentLocation(); - if (parser.parseOperand(targetOperand) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - Attribute sizesAttr = result.attributes.get(sizesAttrName); - if (!sizesAttr) - return parser.emitError(opLoc) - << "expected '" << sizesAttrName << "' attribute"; - auto sizesArrayAttr = sizesAttr.dyn_cast(); - if (!sizesArrayAttr) - return parser.emitError(opLoc) - << "'" << sizesAttrName << "' attribute must be an array"; - Type pdlOpType = parser.getBuilder().getType(); - size_t numExpectedLoops = - sizesArrayAttr.size() - llvm::count(extractI64Array(sizesArrayAttr), 0); - result.addTypes(SmallVector(numExpectedLoops + 1, pdlOpType)); - if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) - return failure(); - return success(); + return parseTileLikeOp(parser, result, + TileOp::getSizesAttrName(result.name).getValue()); } void TileOp::print(OpAsmPrinter &p) { @@ -366,26 +428,6 @@ p.printOptionalAttrDict((*this)->getAttrs()); } -void TileOp::getEffects( - SmallVectorImpl> - &effects) { - // `target` arg is consumed and can no longer be used. - effects.emplace_back(MemoryEffects::Read::get(), getTarget(), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Free::get(), getTarget(), - TransformMappingResource::get()); - - for (Value r : getResults()) { - effects.emplace_back(MemoryEffects::Write::get(), r, - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Allocate::get(), r, - TransformMappingResource::get()); - } - - effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); -} - //===----------------------------------------------------------------------===// // VectorizeOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @fuse_unary +func.func @fuse_unary(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unary +func.func @fuse_unary(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg1: tensor) -> tensor + %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %1 : tensor +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @pdl_target : benefit(1) { + %args = operands + %results = types + %0 = pdl.operation "linalg.elemwise_binary"(%args : !pdl.range) -> (%results : !pdl.range) + // TODO: we don't want this, but it is the required terminator for pdl.pattern + rewrite %0 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @pdl_target in %arg1 + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]} + transform.loop.peel %loops#0 + } +}