diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -194,16 +194,17 @@ const LinalgDependenceGraph &dependenceGraph, const LinalgTilingOptions &tilingOptions); -/// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. -/// This is an in-place transformation controlled by `interchangeVector`. -/// An empty vector is interpreted as the identity permutation and the -/// transformation returns early. +/// Interchanges the `iterator_types` and `iterator_maps` dimensions and adapts +/// the index accesses of `op`. This is an in-place transformation controlled by +/// `interchangeVector`. An empty vector is interpreted as the identity +/// permutation and the transformation returns early. /// /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be /// integers, in the range 0..`op.rank` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). -LinalgOp interchange(LinalgOp op, ArrayRef interchangeVector); +void interchange(PatternRewriter &rewriter, LinalgOp op, + ArrayRef interchangeVector); /// Callback function type used to perform the allocation for the promoted /// `subView`. In `boundingSubViewsize` a best attempt is made to find the diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -34,17 +34,13 @@ LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( Operation *op, ArrayRef interchangeVector) { - if (interchangeVector.empty()) - return failure(); // Transformation applies to generic ops only. if (!isa(op)) return failure(); - LinalgOp linOp = cast(op); - // Transformation applies to buffers only. - if (!linOp.hasBufferSemantics()) - return failure(); - // Permutation must be applicable. - if (linOp.getIndexingMap(0).getNumInputs() != interchangeVector.size()) + LinalgOp linalgOp = cast(op); + // Interchange vector must be non-empty and match the number of loops. + if (interchangeVector.empty() || + linalgOp.getNumLoops() != interchangeVector.size()) return failure(); // Permutation map must be invertible. if (!inversePermutation( @@ -53,33 +49,56 @@ return success(); } -LinalgOp mlir::linalg::interchange(LinalgOp op, - ArrayRef interchangeVector) { - if (interchangeVector.empty()) - return op; - +void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op, + ArrayRef interchangeVector) { + // 1. Compute the inverse permutation map. MLIRContext *context = op.getContext(); - auto permutationMap = inversePermutation( + AffineMap permutationMap = inversePermutation( AffineMap::getPermutationMap(interchangeVector, context)); assert(permutationMap && "expected permutation to be invertible"); + assert(interchangeVector.size() == op.getNumLoops() && + "expected interchange vector to have entry for every loop"); + + // 2. Compute the interchanged indexing maps. SmallVector newIndexingMaps; - auto indexingMaps = op.indexing_maps().getValue(); + ArrayRef indexingMaps = op.indexing_maps().getValue(); for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) { AffineMap m = indexingMaps[i].cast().getValue(); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(AffineMapAttr::get(m)); } - auto itTypes = op.iterator_types().getValue(); - SmallVector itTypesVector; - for (unsigned i = 0, e = itTypes.size(); i != e; ++i) - itTypesVector.push_back(itTypes[i]); - applyPermutationToVector(itTypesVector, interchangeVector); - op->setAttr(getIndexingMapsAttrName(), ArrayAttr::get(context, newIndexingMaps)); + + // 3. Compute the interchanged iterator types. + ArrayRef itTypes = op.iterator_types().getValue(); + SmallVector itTypesVector; + llvm::append_range(itTypesVector, itTypes); + applyPermutationToVector(itTypesVector, interchangeVector); op->setAttr(getIteratorTypesAttrName(), ArrayAttr::get(context, itTypesVector)); - return op; + // 4. Transform the index operations by applying the permutation map. + if (op.hasIndexSemantics()) { + // TODO: Remove the assertion and add a getBody() method to LinalgOp + // interface once every LinalgOp has a body. + assert(op->getNumRegions() == 1 && + op->getRegion(0).getBlocks().size() == 1 && + "expected generic operation to have one block."); + Block &block = op->getRegion(0).front(); + OpBuilder::InsertionGuard guard(rewriter); + for (IndexOp indexOp : + llvm::make_early_inc_range(block.getOps())) { + rewriter.setInsertionPoint(indexOp); + SmallVector allIndices; + allIndices.reserve(op.getNumLoops()); + llvm::transform(llvm::seq(0, op.getNumLoops()), + std::back_inserter(allIndices), [&](int64_t dim) { + return rewriter.create(indexOp->getLoc(), dim); + }); + rewriter.replaceOpWithNewOp( + indexOp, permutationMap.getSubMap(indexOp.dim()), allIndices); + } + } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -404,8 +404,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { LinalgOp linalgOp = dyn_cast(op); - // TODO: remove hasIndexSemantics check once index ops are supported. - if (!linalgOp || linalgOp.hasIndexSemantics()) + if (!linalgOp) return failure(); if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); @@ -415,7 +414,7 @@ // TODO: figure out how this interplays with named ops. In particular this // should break the named op property. rewriter.updateRootInPlace(op, [&]() { - interchange(linalgOp, interchangeVector); + interchange(rewriter, linalgOp, interchangeVector); // New filter if specified. filter.replaceLinalgTransformationFilter(rewriter, op); }); diff --git a/mlir/test/Dialect/Linalg/interchange.mlir b/mlir/test/Dialect/Linalg/interchange.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/interchange.mlir @@ -0,0 +1,51 @@ +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-interchange-pattern=4,0,3,1,2 -test-linalg-transform-patterns=test-interchange-pattern=1,3,4,2,0 | FileCheck --check-prefix=CANCEL-OUT %s + +#map0 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> + +func @interchange_generic_op(%arg0 : memref<1x2x3x4x5xindex>, %arg1 : memref<1x2x4xindex>) { + linalg.generic { + indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"]} + ins(%arg0 : memref<1x2x3x4x5xindex>) + outs(%arg1 : memref<1x2x4xindex>) { + ^bb0(%arg2 : index, %arg3 : index) : + %0 = linalg.index 0 : index + %1 = linalg.index 1 : index + %2 = linalg.index 4 : index + %3 = subi %0, %1 : index + %4 = addi %3, %2 : index + %5 = addi %4, %arg2 : index + linalg.yield %5 : index + } + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4, d2, d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d2)> +// CHECK: func @interchange_generic_op +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "parallel", "reduction"] +// CHECK-DAG: %[[IDX0:.+]] = linalg.index 1 : index +// CHECK-DAG: %[[IDX1:.+]] = linalg.index 3 : index +// CHECK-DAG: %[[IDX4:.+]] = linalg.index 0 : index +// CHECK: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index +// CHECK: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index +// CHECK: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index + +// CANCEL-OUT-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CANCEL-OUT-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> +// CANCEL-OUT: func @interchange_generic_op +// CANCEL-OUT: linalg.generic +// CANCEL-OUT-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CANCEL-OUT-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "reduction"] +// CANCEL-OUT-DAG: %[[IDX0:.+]] = linalg.index 0 : index +// CANCEL-OUT-DAG: %[[IDX1:.+]] = linalg.index 1 : index +// CANCEL-OUT-DAG: %[[IDX4:.+]] = linalg.index 4 : index +// CANCEL-OUT: %[[T0:.+]] = subi %[[IDX0]], %[[IDX1]] : index +// CANCEL-OUT: %[[T1:.+]] = addi %[[T0]], %[[IDX4]] : index +// CANCEL-OUT: %[[T2:.+]] = addi %[[T1]], %{{.*}} : index + + diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -91,6 +91,9 @@ *this, "tile-sizes-for-padding", llvm::cl::desc("Linalg tile sizes when tile+pad"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + ListOption testInterchangePattern{ + *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Test the interchange pattern.")}; }; } // end anonymous namespace @@ -540,6 +543,17 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); } +static void applyInterchangePattern(FuncOp funcOp, + ArrayRef interchangeVector) { + MLIRContext *context = funcOp.getContext(); + RewritePatternSet interchangePattern(context); + interchangePattern.add>( + context, interchangeVector, + LinalgTransformationFilter(ArrayRef{}, + Identifier::get("interchange", context))); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern)); +} + /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnFunction() { auto lambda = [&](void *) { @@ -580,6 +594,8 @@ (void)linalg::hoistPaddingOnTensors(padTensorOp, testHoistPadding); }); } + if (testInterchangePattern.hasValue()) + return applyInterchangePattern(getFunction(), testInterchangePattern); } namespace mlir {