diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -28,8 +28,8 @@ // Create a new call to the type-canonicalized `LinalgOp::getLibraryCallName()` // function. The implementation of the function can be either in the same module // or in an externally linked library. -// This is a generic entry point for all LinalgOp, except for CopyOp and -// IndexedGenericOp, for which more specialized patterns are provided. +// This is a generic entry point for all LinalgOp, except for CopyOp, for which +// more specialized patterns are provided. class LinalgOpToLibraryCallRewrite : public OpInterfaceRewritePattern { public: @@ -58,16 +58,6 @@ PatternRewriter &rewriter) const override; }; -/// Conversion pattern specialization for IndexedGenericOp, has special handling -/// for the extra index operands. -class IndexedGenericOpToLibraryCallRewrite - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(IndexedGenericOp op, - PatternRewriter &rewriter) const override; -}; - /// Populate the given list with patterns that convert from Linalg to Standard. void populateLinalgToStandardConversionPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -26,12 +26,6 @@ static SmallVector extractOperandTypes(Operation *op) { SmallVector result; result.reserve(op->getNumOperands()); - if (auto indexedGenericOp = dyn_cast(op)) { - auto *ctx = op->getContext(); - auto numLoops = indexedGenericOp.getNumLoops(); - result.reserve(op->getNumOperands() + numLoops); - result.assign(numLoops, IndexType::get(ctx)); - } for (auto type : op->getOperandTypes()) { // The underlying descriptor type (e.g. LLVM) does not have layout // information. Canonicalizing the type at the level of std when going into @@ -103,7 +97,11 @@ LogicalResult mlir::linalg::LinalgOpToLibraryCallRewrite::matchAndRewrite( LinalgOp op, PatternRewriter &rewriter) const { // Only LinalgOp for which there is no specialized pattern go through this. - if (isa(op) || isa(op)) + if (isa(op)) + return failure(); + + // Canonicalize indexed generic operations before library call conversion. + if (isa(op)) return failure(); auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); @@ -167,31 +165,6 @@ return success(); } -LogicalResult -mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite( - IndexedGenericOp op, PatternRewriter &rewriter) const { - auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); - if (!libraryCallName) - return failure(); - - // TODO: Use induction variables values instead of zeros, when - // IndexedGenericOp is tiled. - auto zero = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - auto indexedGenericOp = cast(op); - auto numLoops = indexedGenericOp.getNumLoops(); - SmallVector operands; - operands.reserve(numLoops + op.getNumOperands()); - for (unsigned i = 0; i < numLoops; ++i) - operands.push_back(zero); - for (auto operand : op.getOperands()) - operands.push_back(operand); - rewriter.replaceOpWithNewOp( - op, libraryCallName.getValue(), TypeRange(), - createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), operands)); - return success(); -} - /// Populate the given list with patterns that convert from Linalg to Standard. void mlir::linalg::populateLinalgToStandardConversionPatterns( RewritePatternSet &patterns) { @@ -201,7 +174,6 @@ patterns.add< CopyOpToLibraryCallRewrite, CopyTransposeRewrite, - IndexedGenericOpToLibraryCallRewrite, LinalgOpToLibraryCallRewrite>(patterns.getContext()); // clang-format on } diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir --- a/mlir/test/Dialect/Linalg/standard.mlir +++ b/mlir/test/Dialect/Linalg/standard.mlir @@ -95,25 +95,3 @@ } // CHECK-LABEL: func @matmul_vec_impl( // CHECK: call @external_outerproduct_matmul(%{{.*}}) : - -#indexed_matmul_trait = { - iterator_types = ["parallel", "parallel", "reduction"], - indexing_maps = #matmul_accesses, - library_call = "external_indexed_outerproduct_matmul" -} -func @matmul_vec_indexed(%A: !matrix_type_A, - %B: !matrix_type_B, - %C: !matrix_type_C) { - linalg.indexed_generic #indexed_matmul_trait - ins(%A, %B : !matrix_type_A, !matrix_type_B) - outs(%C : !matrix_type_C) { - ^bb0(%i: index, %j: index, %k: index, - %a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): - %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B - linalg.yield %d: !vector_type_C - } - return -} -// CHECK-LABEL: func @matmul_vec_indexed( -// CHECK: %[[ZERO:.*]] = constant 0 : index -// CHECK: call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}})