diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -19,6 +19,7 @@ std::unique_ptr> createConvertElementwiseToLinalgPass(); std::unique_ptr> createLinalgFoldUnitExtentDimsPass(); +std::unique_ptr> createLinalgInterchangePass(); std::unique_ptr createLinalgFusionOfTensorOpsPass(); std::unique_ptr createFoldReshapeOpsByLinearizationPass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -36,6 +36,21 @@ ]; } +def LinalgInterchange : FunctionPass<"linalg-interchange"> { + let summary = "Interchange the iteration domain basis of linalg operations"; + let constructor = "mlir::createLinalgInterchangePass()"; + let options = [ + ListOption<"interchangeVector", "interchange-vector", "unsigned", + "Permute the dimensions of the iteration domain following the " + "given interchange vector", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> + ]; + let dependentDialects = [ + "linalg::LinalgDialect", + "AffineDialect" + ]; +} + def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> { let summary = "Fuse operations on RankedTensorType in linalg dialect"; let constructor = "mlir::createLinalgFusionOfTensorOpsPass()"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -195,15 +195,17 @@ const LinalgTilingOptions &tilingOptions); /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. -/// This is an in-place transformation controlled by `interchangeVector`. 
-/// An empty vector is interpreted as the identity permutation and the -/// transformation returns early. +/// Additionally, it updates the dimensions of all index op accesses in the +/// body of transformed operation. This is an in-place transformation controlled +/// by `interchangeVector`. An empty vector is interpreted as the identity +/// permutation and the transformation returns early. /// /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be /// integers, in the range 0..`op.rank` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). -LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); +void interchange(PatternRewriter &rewriter, LinalgOp op, + ArrayRef interchangeVector); /// Callback function type used to perform the allocation for the promoted /// `subView`. In `boundingSubViewsize` a best attempt is made to find the diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -10,8 +10,10 @@ // //===----------------------------------------------------------------------===// +#include "PassDetail.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" @@ -23,6 +25,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include @@ -34,17 +37,13 @@ LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( Operation *op, ArrayRef interchangeVector) { - 
if (interchangeVector.empty()) - return failure(); // Transformation applies to generic ops only. if (!isa(op)) return failure(); - LinalgOp linOp = cast(op); - // Transformation applies to buffers only. - if (!linOp.hasBufferSemantics()) - return failure(); - // Permutation must be applicable. - if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size()) + LinalgOp linalgOp = cast(op); + // Interchange vector must be non-empty and match the number of loops. + if (interchangeVector.empty() || + linalgOp.getNumLoops() != interchangeVector.size()) return failure(); // Permutation map must be invertible. if (!inversePermutation( @@ -53,33 +52,83 @@ return success(); } -LinalgOp mlir::linalg::interchange(LinalgOp op, - ArrayRef interchangeVector) { - if (interchangeVector.empty()) - return op; - +void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op, + ArrayRef interchangeVector) { + // 1. Compute the inverse permutation map. MLIRContext *context = op.getContext(); - auto permutationMap = inversePermutation( + AffineMap permutationMap = inversePermutation( AffineMap::getPermutationMap(interchangeVector, context)); assert(permutationMap && "expected permutation to be invertible"); + assert(interchangeVector.size() == op.getNumLoops() && + "expected interchange vector to have entry for every loop"); + + // 2. Compute the interchanged indexing maps. 
SmallVector newIndexingMaps; - auto indexingMaps = op.indexing_maps().getValue(); + ArrayRef indexingMaps = op.indexing_maps().getValue(); for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) { AffineMap m = indexingMaps[i].cast().getValue(); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(AffineMapAttr::get(m)); } - auto itTypes = op.iterator_types().getValue(); - SmallVector itTypesVector; - for (unsigned i = 0, e = itTypes.size(); i != e; ++i) - itTypesVector.push_back(itTypes[i]); - applyPermutationToVector(itTypesVector, interchangeVector); - op->setAttr(getIndexingMapsAttrName(), ArrayAttr::get(context, newIndexingMaps)); + + // 3. Compute the interchanged iterator types. + ArrayRef itTypes = op.iterator_types().getValue(); + SmallVector itTypesVector; + llvm::append_range(itTypesVector, itTypes); + applyPermutationToVector(itTypesVector, interchangeVector); op->setAttr(getIteratorTypesAttrName(), ArrayAttr::get(context, itTypesVector)); - return op; + // 4. Transform the index operations by applying the permutation map. + if (op.hasIndexSemantics()) { + assert(op->getNumRegions() == 1 && + op->getRegion(0).getBlocks().size() == 1 && + "expected generic operation to have one block."); + Block &block = op->getRegion(0).front(); + for (IndexOp indexOp : + llvm::make_early_inc_range(block.getOps())) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(indexOp); + SmallVector allIndices; + allIndices.reserve(op.getNumLoops()); + llvm::transform(llvm::seq(0, op.getNumLoops()), + std::back_inserter(allIndices), [&](int64_t dim) { + return rewriter.create(indexOp.getLoc(), dim); + }); + rewriter.replaceOpWithNewOp( + indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices); + } + } +} + +namespace { + +/// Pass that interchanges the iteration domain basis of linalg operations. 
+struct LinalgInterchangePass + : public LinalgInterchangeBase { + void runOnFunction() override { + FuncOp funcOp = getFunction(); + MLIRContext *context = funcOp.getContext(); + + // Apply the interchange pattern to all generic operations. + RewritePatternSet patterns(context); + patterns.add>( + context, interchangeVector, + LinalgTransformationFilter(ArrayRef{}, + Identifier::get("interchange", context))); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + + // Cleanup all transformation markers. + funcOp.walk([](LinalgOp op) { + op->removeAttr(LinalgTransforms::kLinalgTransformMarker); + }); + } +}; +} // namespace + +std::unique_ptr> mlir::createLinalgInterchangePass() { + return std::make_unique(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -408,8 +408,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { LinalgOp linalgOp = dyn_cast(op); - // TODO: remove hasIndexSemantics check once index ops are supported. - if (!linalgOp || linalgOp.hasIndexSemantics()) + if (!linalgOp) return failure(); if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); @@ -419,7 +418,7 @@ // TODO: figure out how this interplays with named ops. In particular this // should break the named op property. rewriter.updateRootInPlace(op, [&]() { - interchange(linalgOp, interchangeVector); + interchange(rewriter, linalgOp, interchangeVector); // New filter if specified. 
filter.replaceLinalgTransformationFilter(rewriter, op); }); diff --git a/mlir/test/Dialect/Linalg/interchange.mlir b/mlir/test/Dialect/Linalg/interchange.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/interchange.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s -linalg-interchange="interchange-vector=4,0,3,1,2" | FileCheck %s +// RUN: mlir-opt %s -linalg-interchange="interchange-vector=4,0,3,1,2" -linalg-interchange="interchange-vector=1,3,4,2,0" | FileCheck --check-prefix=CANCEL-OUT %s + +#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> + +func @interchange_generic_op(%arg0 : memref<1x2x3x4x5xindex>, %arg1 : memref<1x2x4xindex>) { + linalg.generic { + indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]} + ins(%arg0 : memref<1x2x3x4x5xindex>) + outs(%arg1 : memref<1x2x4xindex>) { + ^bb0(%arg2 : index, %arg3 : index) : + %0 = linalg.index 0 : index + %1 = linalg.index 1 : index + %2 = linalg.index 4 : index + %3 = subi %0, %1 : index + %4 = addi %3, %2 : index + %5 = addi %4, %arg2 : index + linalg.yield %5 : index + } + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4, d2, d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d2)> +// CHECK: func @interchange_generic_op +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "parallel", "reduction"] +// CHECK-DAG: %[[IDX0:.+]] = linalg.index 1 : index +// CHECK-DAG: %[[IDX1:.+]] = linalg.index 3 : index +// CHECK-DAG: %[[IDX4:.+]] = linalg.index 0 : index +// CHECK: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index +// CHECK: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index +// CHECK: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index + +// CANCEL-OUT-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, 
d3, d4)> +// CANCEL-OUT-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +// CANCEL-OUT: func @interchange_generic_op +// CANCEL-OUT: linalg.generic +// CANCEL-OUT-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CANCEL-OUT-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"] +// CANCEL-OUT-DAG: %[[IDX0:.+]] = linalg.index 0 : index +// CANCEL-OUT-DAG: %[[IDX1:.+]] = linalg.index 1 : index +// CANCEL-OUT-DAG: %[[IDX4:.+]] = linalg.index 4 : index +// CANCEL-OUT: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index +// CANCEL-OUT: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index +// CANCEL-OUT: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index + +