diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -213,8 +213,8 @@ /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be /// integers, in the range 0..`op.rank` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). -void interchange(PatternRewriter &rewriter, LinalgOp op, - ArrayRef interchangeVector); +void interchangeGenericOp(PatternRewriter &rewriter, GenericOp genericOp, + ArrayRef interchangeVector); /// Callback function type used to perform the allocation for the promoted /// `subView`. In `boundingSubViewsize` a best attempt is made to find the @@ -363,11 +363,11 @@ // Preconditions that ensure the corresponding transformation succeeds and can // be applied as a rewrite pattern. //===----------------------------------------------------------------------===// -/// Emits a `generic` or `indexed_generic` operation with the `indexing_maps` -/// and `iterator_types` permutated according to `permutation`. +/// Returns success if `interchangeGenericOp` can be applied to `genericOp`: +/// `interchangeVector` must be a non-empty permutation with one entry per loop. LogicalResult -interchangeGenericLinalgOpPrecondition(Operation *op, - ArrayRef interchangeVector); +interchangeGenericOpPrecondition(GenericOp genericOp, + ArrayRef interchangeVector); /// Promote std.subviews feeding linalg operations. LogicalResult promoteSubviewsPrecondition(Operation *op, @@ -630,18 +630,18 @@ }; /// -/// Linalg interchange patterns. +/// Linalg generic interchange pattern. /// -/// Apply the `interchange` transformation as a pattern. +/// Apply the `interchangeGenericOp` transformation as a pattern. /// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `interchange` for more details. +/// See `interchangeGenericOp` for more details.
-struct LinalgBaseInterchangePattern : public RewritePattern { - LinalgBaseInterchangePattern( - StringRef opName, MLIRContext *context, - ArrayRef interchangeVector, +struct GenericOpInterchangePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + GenericOpInterchangePattern( + MLIRContext *context, ArrayRef interchangeVector, LinalgTransformationFilter filter = LinalgTransformationFilter(), PatternBenefit benefit = 1); - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override; private: @@ -651,16 +651,6 @@ SmallVector interchangeVector; }; -template -struct LinalgInterchangePattern : public LinalgBaseInterchangePattern { - LinalgInterchangePattern( - MLIRContext *context, ArrayRef interchangeVector, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBaseInterchangePattern(OpTy::getOperationName(), context, - interchangeVector, filter, benefit) {} -}; - /// /// Linalg promotion patterns. /// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -32,68 +32,65 @@ using namespace mlir; using namespace mlir::linalg; -LogicalResult mlir::linalg::interchangeGenericLinalgOpPrecondition( - Operation *op, ArrayRef interchangeVector) { - // Transformation applies to generic ops only. - if (!isa(op)) - return failure(); - LinalgOp linalgOp = cast(op); +LogicalResult mlir::linalg::interchangeGenericOpPrecondition( + GenericOp genericOp, ArrayRef interchangeVector) { // Interchange vector must be non-empty and match the number of loops. if (interchangeVector.empty() || - linalgOp.getNumLoops() != interchangeVector.size()) + genericOp.getNumLoops() != interchangeVector.size()) return failure(); // Permutation map must be invertible. 
- if (!inversePermutation( - AffineMap::getPermutationMap(interchangeVector, op->getContext()))) + if (!inversePermutation(AffineMap::getPermutationMap(interchangeVector, + genericOp.getContext()))) return failure(); return success(); } -void mlir::linalg::interchange(PatternRewriter &rewriter, LinalgOp op, - ArrayRef interchangeVector) { +void mlir::linalg::interchangeGenericOp(PatternRewriter &rewriter, + GenericOp genericOp, + ArrayRef interchangeVector) { // 1. Compute the inverse permutation map. - MLIRContext *context = op.getContext(); + MLIRContext *context = genericOp.getContext(); AffineMap permutationMap = inversePermutation( AffineMap::getPermutationMap(interchangeVector, context)); assert(permutationMap && "expected permutation to be invertible"); - assert(interchangeVector.size() == op.getNumLoops() && + assert(interchangeVector.size() == genericOp.getNumLoops() && "expected interchange vector to have entry for every loop"); // 2. Compute the interchanged indexing maps. SmallVector newIndexingMaps; - ArrayRef indexingMaps = op.indexing_maps().getValue(); - for (unsigned i = 0, e = op.getNumShapedOperands(); i != e; ++i) { + ArrayRef indexingMaps = genericOp.indexing_maps().getValue(); + for (unsigned i = 0, e = genericOp.getNumShapedOperands(); i != e; ++i) { AffineMap m = indexingMaps[i].cast().getValue(); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(AffineMapAttr::get(m)); } - op->setAttr(getIndexingMapsAttrName(), - ArrayAttr::get(context, newIndexingMaps)); + genericOp->setAttr(getIndexingMapsAttrName(), + ArrayAttr::get(context, newIndexingMaps)); // 3. Compute the interchanged iterator types. 
- ArrayRef itTypes = op.iterator_types().getValue(); + ArrayRef itTypes = genericOp.iterator_types().getValue(); SmallVector itTypesVector; llvm::append_range(itTypesVector, itTypes); applyPermutationToVector(itTypesVector, interchangeVector); - op->setAttr(getIteratorTypesAttrName(), - ArrayAttr::get(context, itTypesVector)); + genericOp->setAttr(getIteratorTypesAttrName(), + ArrayAttr::get(context, itTypesVector)); // 4. Transform the index operations by applying the permutation map. - if (op.hasIndexSemantics()) { + if (genericOp.hasIndexSemantics()) { // TODO: Remove the assertion and add a getBody() method to LinalgOp // interface once every LinalgOp has a body. - assert(op->getNumRegions() == 1 && - op->getRegion(0).getBlocks().size() == 1 && + assert(genericOp->getNumRegions() == 1 && + genericOp->getRegion(0).getBlocks().size() == 1 && "expected generic operation to have one block."); - Block &block = op->getRegion(0).front(); + Block &block = genericOp->getRegion(0).front(); OpBuilder::InsertionGuard guard(rewriter); for (IndexOp indexOp : llvm::make_early_inc_range(block.getOps())) { rewriter.setInsertionPoint(indexOp); SmallVector allIndices; - allIndices.reserve(op.getNumLoops()); - llvm::transform(llvm::seq(0, op.getNumLoops()), + allIndices.reserve(genericOp.getNumLoops()); + llvm::transform(llvm::seq(0, genericOp.getNumLoops()), std::back_inserter(allIndices), [&](uint64_t dim) { return rewriter.create(indexOp->getLoc(), dim); }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -393,30 +393,26 @@ return success(); } -/// Linalg base interchange pattern. 
-mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( - StringRef opName, MLIRContext *context, - ArrayRef interchangeVector, LinalgTransformationFilter filter, - PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), filter(filter), +/// Linalg generic interchange pattern. +mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( + MLIRContext *context, ArrayRef interchangeVector, + LinalgTransformationFilter filter, PatternBenefit benefit) + : OpRewritePattern(context, benefit), filter(filter), interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} -LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) +LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite( + GenericOp genericOp, PatternRewriter &rewriter) const { + if (failed(filter.checkAndNotify(rewriter, genericOp))) return failure(); - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector))) + if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector))) return failure(); // TODO: figure out how this interplays with named ops. In particular this // should break the named op property. - rewriter.updateRootInPlace(op, [&]() { - interchange(rewriter, linalgOp, interchangeVector); + rewriter.updateRootInPlace(genericOp, [&]() { + interchangeGenericOp(rewriter, genericOp, interchangeVector); // New filter if specified. 
- filter.replaceLinalgTransformationFilter(rewriter, op); + filter.replaceLinalgTransformationFilter(rewriter, genericOp); }); return success(); } diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -125,37 +125,6 @@ // CHECK-SAME: memref // CHECK-SAME: memref -#indexed_matmul_trait = { - args_in = 2, - args_out = 1, - indexing_maps = #matmul_accesses, - library_call = "linalg_matmul_indexed", - iterator_types = ["parallel", "parallel", "reduction"] -} -func @permute_generic_indexed( - %A: memref, - %B: memref, - %C: memref) { - linalg.indexed_generic #indexed_matmul_trait - ins(%A, %B : memref, - memref) - outs(%C : memref) { - ^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32): - %d = mulf %a, %b: f32 - %e = addf %c, %d: f32 - linalg.yield %e: f32 - } - return -} -// CHECK-LABEL: func @permute_generic_indexed -// CHECK: linalg.indexed_generic { -// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]], -// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], -// CHECK-SAME: library_call = "linalg_matmul_indexed"} -// CHECK: memref, -// CHECK-SAME: memref -// CHECK-SAME: memref - func @matvec_perm(%A: memref, %x: memref, %y: memref) { diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -194,14 +194,9 @@ .addOpFilter()); //===--------------------------------------------------------------------===// - // Linalg generic permutation patterns. + // Linalg generic interchange pattern. 
//===--------------------------------------------------------------------===// - patterns.add>( - ctx, - /*interchangeVector=*/ArrayRef{1, 2, 0}, - LinalgTransformationFilter(ArrayRef{}, - Identifier::get("PERMUTED", ctx))); - patterns.add>( + patterns.add( ctx, /*interchangeVector=*/ArrayRef{1, 2, 0}, LinalgTransformationFilter(ArrayRef{}, @@ -551,7 +546,7 @@ ArrayRef interchangeVector) { MLIRContext *context = funcOp.getContext(); RewritePatternSet interchangePattern(context); - interchangePattern.add>( + interchangePattern.add( context, interchangeVector, LinalgTransformationFilter(ArrayRef{}, Identifier::get("interchange", context)));