diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -146,6 +146,9 @@
     InterfaceMethod<[{
       Query whether the op has only MemRef input and outputs.
    }], "bool", "hasBufferSemantics">,
+    InterfaceMethod<[{
+      Query whether the op has only RankedTensor inputs and outputs.
+    }], "bool", "hasTensorSemantics">,
 
    //========================================================================//
    // Other static interface methods.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -221,6 +221,16 @@
         [](Value v) { return v.getType().isa<MemRefType>(); });
   }
 
+  /// Query whether the op has only tensor inputs and outputs.
+  bool hasTensorSemantics() {
+    return llvm::all_of(
+               getInputs(),
+               [](Value v) { return v.getType().isa<RankedTensorType>(); }) &&
+           llvm::all_of(this->getOperation()->getResults(), [](Value v) {
+             return v.getType().isa<RankedTensorType>();
+           });
+  }
+
   //==========================================================================//
   // Other static interface methods.
   //==========================================================================//
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -71,6 +71,12 @@
                                  const LinalgDependenceGraph &graph,
                                  OperationFolder *folder = nullptr);
 
+/// Fuse linalg operations on tensors, where the result of the producer is used
+/// as the operand of the consumer at position `consumerIdx`.
+Optional<LinalgOp> fuseTensorOps(OpBuilder &b, LinalgOp producer,
+                                 LinalgOp consumer, unsigned consumerIdx,
+                                 OperationFolder *folder = nullptr);
+
 /// Returns the linearized list of all view dimensions in a linalgOp. Applying
 /// the inverse, concatenated loopToOperandRangeMaps to this list allows the
 /// derivation of loop ranges for any linalgOp.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
@@ -318,6 +319,160 @@
   return llvm::None;
 }
 
+/// Checks if two Generic ops are fusible, when one is a producer and the other
+/// is a consumer (with the result of the producer being the operand of the
+/// consumer at position `consumerIdx`).
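+/// For example, in the fusion-tensor.mlir tests added below, an elementwise
+/// addf producer (all "parallel" iterator types and a single tensor result)
+/// can fuse into an elementwise mulf consumer, while a producer with a
+/// "reduction" iterator type would be rejected by the checks below.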
+static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
+                                unsigned consumerIdx) {
+  // Verify that the producer and consumer are ops on tensors.
+  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) {
+    return false;
+  }
+  auto producerOp = dyn_cast<GenericOp>(producer.getOperation());
+  auto consumerOp = dyn_cast<GenericOp>(consumer.getOperation());
+  // Verify that the producer and consumer are generic ops, and only handle
+  // cases where the producer has a single return value that is the same as the
+  // argument at `consumerIdx` of the consumer.
+  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
+      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx)) {
+    return false;
+  }
+  // Check that the producer has all "parallel" iterator types.
+  if (producerOp.getNumParallelLoops() != producerOp.getNumLoops()) {
+    return false;
+  }
+  // Get the consumer index map. The number of results of the consumer index
+  // map must match the number of loops of the producer.
+  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
+  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops()) {
+    return false;
+  }
+
+  // For now, only handle ops which have regions and do not use functions to
+  // specify the region.
+  if (producerOp.fun() || consumerOp.fun()) {
+    return false;
+  }
+  return true;
+}
+
+/// Computes the indexing maps for the arguments of a producer generic op when
+/// the result of the producer is fused with the consumer.
+/// - consumerIndexMap is the indexing_map for the argument in the consumer op
+///   that is the result of the producer op.
+/// - invProducerResultIndexMap is the inverse of the indexing_map for the
+///   result in the producer op.
+/// - producerArgIndexMap is the indexing_map of the argument of the producer
+///   op.
+/// The result is the indexing_map to use for the producer argument when the
+/// producer and consumer ops are fused.
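+/// For example, in the transpose_add_mul_fusion test added below, the producer
+/// reads %arg1 through (d0, d1) -> (d1, d0) and writes its result through the
+/// identity map, and the consumer reads that result through the identity map
+/// as well. Both invProducerResultIndexMap and consumerIndexMap are then
+/// identities, and the fused op keeps reading %arg1 through
+/// (d0, d1) -> (d1, d0), as the CHECK lines of that test expect.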
+static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
+                                       AffineMap invProducerResultIndexMap,
+                                       AffineMap producerArgIndexMap) {
+  // t1 is a map from producer result tensor index -> producer arg tensor
+  // index.
+  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
+  // The return value is a map from consumer loop -> producer arg tensor index,
+  // i.e. the indexing_map for the producer argument in the fused operation.
+  return t1.compose(consumerIndexMap);
+}
+
+Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
+                                               LinalgOp consumer,
+                                               unsigned consumerIdx,
+                                               OperationFolder *folder) {
+  if (!areTensorOpsFusible(producer, consumer, consumerIdx)) {
+    return {};
+  }
+  MLIRContext *context = b.getContext();
+  auto producerOp = cast<GenericOp>(producer.getOperation());
+  auto consumerOp = cast<GenericOp>(consumer.getOperation());
+  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
+  AffineMap invProducerResultIndexMap =
+      inversePermutation(producerOp.getOutputIndexingMap(0));
+
+  // Compute the fused op's operand list by replacing the operand that
+  // corresponds to the result of the producer with the operands of the
+  // producer.
+  unsigned fusedArgsIn =
+      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
+  auto fusedArgsOut = consumerOp.getNumOutputs();
+  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
+  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
+  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
+  fusedOperandsList.insert(
+      std::next(fusedOperandsList.begin(), consumerIdx),
+      producerOp.operand_begin(),
+      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));
+
+  // Compute the fused indexing_maps of the operands/results of the fused op.
+  SmallVector<Attribute, 4> fusedIndexingMapAttrs;
+  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
+  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
+                               consumerOp.indexing_maps().end());
+  fusedIndexingMapAttrs.erase(
+      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
+  auto insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
+  for (auto producerArgIndexAttr :
+       llvm::enumerate(producerOp.indexing_maps())) {
+    if (producerArgIndexAttr.index() == producerOp.getNumInputs()) {
+      break;
+    }
+    auto composedIndexMap = computeProducerArgMap(
+        consumerIndexMap, invProducerResultIndexMap,
+        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
+    insertPos = std::next(fusedIndexingMapAttrs.insert(
+        insertPos, AffineMapAttr::get(composedIndexMap)));
+  }
+
+  // Generate the fused op.
+  SmallVector<Type, 4> fusedOpResultTypes;
+  llvm::for_each(consumerOp.getResults(), [&fusedOpResultTypes](Value v) {
+    fusedOpResultTypes.push_back(v.getType());
+  });
+  auto fusedLinalgOp = b.create<GenericOp>(
+      UnknownLoc::get(context), fusedOpResultTypes, fusedOperandsList,
+      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
+      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
+      /*doc=*/nullptr,
+      /*fun=*/nullptr,
+      /*library_call=*/nullptr);
+
+  // Build the region of the fused op.
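+  // For the add_mul_fusion test added below, this produces a block with three
+  // arguments: the producer's two arguments spliced in at `consumerIdx`,
+  // followed by the consumer's remaining argument. The cloned addf then feeds
+  // the cloned mulf directly, with no intermediate linalg.yield.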
+  auto &fusedOpRegion = fusedLinalgOp.region();
+  Block &producerOpBlock = producerOp.region().front();
+  Block &consumerOpBlock = consumerOp.region().front();
+  Block *fusedBlock = new Block();
+  fusedOpRegion.push_back(fusedBlock);
+  BlockAndValueMapping mapper;
+  // Map the arguments for the unmodified args from the consumer.
+  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
+    if (consumerOpArg.index() == consumerIdx) {
+      // Map the arguments for the args from the producer.
+      for (auto producerOpArg : producerOpBlock.getArguments()) {
+        mapper.map(producerOpArg,
+                   fusedBlock->addArgument(producerOpArg.getType()));
+      }
+      continue;
+    }
+    mapper.map(consumerOpArg.value(),
+               fusedBlock->addArgument(consumerOpArg.value().getType()));
+  }
+
+  // Add operations from the producer (except its yield operation) to the
+  // fused op.
+  for (auto &op : producerOpBlock.getOperations()) {
+    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
+      // Lookup the value the yield operation is mapped to.
+      Value yieldVal = yieldOp.getOperand(0);
+      auto clonedVal = mapper.lookup(yieldVal);
+      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
+      continue;
+    }
+    fusedBlock->push_back(op.clone(mapper));
+  }
+  for (auto &op : consumerOpBlock.getOperations()) {
+    fusedBlock->push_back(op.clone(mapper));
+  }
+  return cast<LinalgOp>(fusedLinalgOp.getOperation());
+}
+
 static void fuseLinalgOpsGreedily(FuncOp f) {
   LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
 
@@ -359,6 +514,46 @@
 }
 
 namespace {
+
+/// Pattern to fuse a generic op with the producer of one of its operands.
+struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(GenericOp op,
+                                     PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics()) {
+      return matchFailure();
+    }
+    // Find the first operand that is defined by another generic op on tensors
+    // and can be fused with this op.
+    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
+      auto definingOp =
+          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
+      if (!definingOp || !definingOp.hasTensorSemantics()) {
+        continue;
+      }
+      auto fusedOp =
+          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
+                        cast<LinalgOp>(op.getOperation()), operand.index());
+      if (!fusedOp) {
+        continue;
+      }
+      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+/// Pass that fuses generic ops on tensors. Used only for testing.
+struct FusionOfTensorOpsPass : public OperationPass<FusionOfTensorOpsPass> {
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    Operation *op = getOperation();
+    patterns.insert<FuseGenericTensorOps>(op->getContext());
+    applyPatternsGreedily(op->getRegions(), patterns);
+  }
+};
+
 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
   void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
 };
@@ -370,3 +565,7 @@
 
 static PassRegistration<LinalgFusionPass>
     pass("linalg-fusion", "Fuse operations in the linalg dialect");
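+
+// This pass is exercised by test/Dialect/Linalg/fusion-tensor.mlir below,
+// which runs it via `mlir-opt -linalg-fusion-for-tensor-ops`.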
+static PassRegistration<FusionOfTensorOpsPass>
+    tensorOpsPass("linalg-fusion-for-tensor-ops",
+                  "Fuse operations on RankedTensorType in linalg dialect");
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops -split-input-file | FileCheck %s --dump-input-on-failure
+
+// CHECK-DAG: [[MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: @add_mul_fusion
+func @add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = addf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+  // CHECK-SAME: indexing_maps = {{\[}}[[MAP0]], [[MAP0]], [[MAP0]], [[MAP0]]{{\]}}
+  %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+    // CHECK: ^{{[a-zA-Z0-9_]*}}
+    // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
+    // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
+    // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
+    ^bb0(%arg5: f32, %arg6: f32):       // no predecessors
+      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = addf [[ARG0]], [[ARG1]]
+      // CHECK-NOT: linalg.yield
+      // CHECK: mulf [[T1]], [[ARG2]]
+      // CHECK: linalg.yield
+      %3 = mulf %arg5, %arg6 : f32
+      linalg.yield %3 : f32
+  }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: [[MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: [[MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d1, d0)>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_add_mul_fusion
+func @transpose_add_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = addf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+  // CHECK-SAME: indexing_maps = {{\[}}[[MAP0]], [[MAP1]], [[MAP0]], [[MAP0]]{{\]}}
+  %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+    ^bb0(%arg5: f32, %arg6: f32):       // no predecessors
+      %3 = mulf %arg5, %arg6 : f32
+      linalg.yield %3 : f32
+  }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: [[MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: [[MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d1, d0)>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @add_transpose_mul_fusion
+func @add_transpose_mul_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} %arg0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = addf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+  // CHECK-SAME: indexing_maps = {{\[}}[[MAP1]], [[MAP0]], [[MAP0]], [[MAP0]]{{\]}}
+  %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+    ^bb0(%arg5: f32, %arg6: f32):       // no predecessors
+      %3 = mulf %arg5, %arg6 : f32
+      linalg.yield %3 : f32
+  }: tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-DAG: [[MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: [[MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: [[MAP2:#[a-zA-Z0-9_]*]] = affine_map<(d0) -> (d0)>
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @add_broadcast_mul_fusion
+func @add_broadcast_mul_fusion(%arg0: tensor<?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel"]} %arg0, %arg1 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      %1 = addf %arg3, %arg4 : f32
+      linalg.yield %1 : f32
+  }: tensor<?xf32>, tensor<?xf32> -> tensor<?xf32>
+  // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64
+  // CHECK-SAME: indexing_maps = {{\[}}[[MAP1]], [[MAP1]], [[MAP0]], [[MAP0]]
+  %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %arg2 {
+    ^bb0(%arg5: f32, %arg6: f32):       // no predecessors
+      %3 = mulf %arg5, %arg6 : f32
+      linalg.yield %3 : f32
+  }: tensor<?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}