diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2674,8 +2674,18 @@ struct Range { Value offset, size, stride; }; - // TODO: retire `getRanges`. - SmallVector getRanges(); + /// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each + /// Range entry contains either the dynamic value or a ConstantIndexOp + /// constructed with `b` at location `loc`. + SmallVector getOrCreateRanges(OpBuilder &b, Location loc); + + /// A subview result type can be fully inferred from the source type and the + /// static representation of offsets, sizes and strides. Special sentinels + /// encode the dynamic case. + static Type inferSubViewResultType(MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides); /// Return the rank of the result MemRefType. unsigned getRank() { return getType().getRank(); } @@ -2748,7 +2758,6 @@ }]; let hasCanonicalizer = 1; - let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -184,15 +184,18 @@ unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); + OpBuilder b(consumer.getOperation()); + auto loc = consumer.getLoc(); + // Materialize the ranges once: each getOrCreateRanges call creates + // fresh ConstantIndexOps for the static offsets, sizes and strides. + auto subViewRanges = subView.getOrCreateRanges(b, loc); // Iterate over dimensions identified by the producer map for `producerIdx`. // This defines a subset of the loop ranges that we need to complete later.
for (auto en : llvm::enumerate(producerMap.getResults())) { unsigned posInProducerLoop = en.value().cast().getPosition(); - loopRanges[posInProducerLoop] = subView.getRanges()[en.index()]; + loopRanges[posInProducerLoop] = subViewRanges[en.index()]; } - OpBuilder b(consumer.getOperation()); - auto loc = consumer.getLoc(); // Iterate over all dimensions. For the dimensions not identified by the // producer map for `producerIdx`, we need to explicitly compute the view that // defines the loop ranges using the `producer`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -153,7 +153,7 @@ SmallVector fullSizes, partialSizes; fullSizes.reserve(rank); partialSizes.reserve(rank); - for (auto en : llvm::enumerate(subView.getRanges())) { + for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) { auto rank = en.index(); auto rangeValue = en.value(); // Try to extract a tight constant. @@ -169,7 +169,7 @@ dynamicBuffers, folder, alignment); auto fullLocalView = folded_std_view( folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer, - folded_std_constant_index(folder, 0), fullSizes); + zero, fullSizes); SmallVector zeros(fullSizes.size(), zero); SmallVector ones(fullSizes.size(), one); auto partialLocalView = diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2273,10 +2273,10 @@ /// A subview result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case.
-static Type inferSubViewResultType(MemRefType sourceMemRefType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides) { +Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides) { unsigned rank = sourceMemRefType.getRank(); (void)rank; assert(staticOffsets.size() == rank && @@ -2472,7 +2472,7 @@ return failure(); // Verify result type against inferred type. - auto expectedType = inferSubViewResultType( + auto expectedType = SubViewOp::inferSubViewResultType( op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()), extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); @@ -2487,16 +2487,6 @@ << range.stride; } -SmallVector SubViewOp::getRanges() { - SmallVector res; - unsigned rank = getType().getRank(); - res.reserve(rank); - for (unsigned i = 0; i < rank; ++i) - res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i), - *(strides().begin() + i)}); - return res; -} - static unsigned getNumDynamicEntriesUpToIdx( ArrayAttr attr, llvm::function_ref isDynamic, unsigned idx) { return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx, @@ -2538,6 +2528,29 @@ return 1 + offsets().size() + sizes().size() + numDynamic; } +/// Return the list of SubViewOp::Range (i.e. offset, size, stride). Each Range +/// entry contains either the dynamic value or a ConstantIndexOp constructed +/// with `b` at location `loc`. +SmallVector SubViewOp::getOrCreateRanges(OpBuilder &b, + Location loc) { + SmallVector res; + unsigned rank = getType().getRank(); + res.reserve(rank); + for (unsigned idx = 0; idx < rank; ++idx) { + auto offset = isDynamicOffset(idx) + ? getDynamicOffset(idx) + : b.create(loc, getStaticOffset(idx)); + auto size = isDynamicSize(idx) + ? getDynamicSize(idx) + : b.create(loc, getStaticSize(idx)); + auto stride = isDynamicStride(idx) + ?
getDynamicStride(idx) + : b.create(loc, getStaticStride(idx)); + res.emplace_back(Range{offset, size, stride}); + } + return res; +} + LogicalResult SubViewOp::getStaticStrides(SmallVectorImpl &staticStrides) { if (!strides().empty()) @@ -2581,7 +2594,8 @@ } /// Pattern to rewrite a subview op with constant arguments. -class SubViewOpFolder final : public OpRewritePattern { +class SubViewOpConstantArgumentFolder final + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2716,27 +2730,64 @@ return true; } -OpFoldResult SubViewOp::fold(ArrayRef) { - auto folds = [](Operation *op) { - bool folded = false; - for (OpOperand &operand : op->getOpOperands()) { - auto castOp = operand.get().getDefiningOp(); - if (castOp && canFoldIntoConsumerOp(castOp)) { - operand.set(castOp.getOperand()); - folded = true; - } - } - return folded ? success() : failure(); - }; +/// Pattern to rewrite a subview op with MemRefCast arguments. +/// This essentially pushes memref_cast past its consuming subview when +/// `canFoldIntoConsumerOp` is true. +/// +/// Example: +/// ``` +/// %0 = memref_cast %V : memref<16x16xf32> to memref +/// %1 = subview %0[0, 0][3, 4][1, 1] : +/// memref to memref<3x4xf32, offset:?, strides:[?, 1]> +/// ``` +/// is rewritten into: +/// ``` +/// %0 = subview %V[0, 0][3, 4][1, 1] : +/// memref<16x16xf32> to memref<3x4xf32, offset:0, strides:[16, 1]> +/// %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to +/// memref<3x4xf32, offset:?, strides:[?, 1]> +/// ``` +class SubViewOpMemRefCastFolder final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; - if (succeeded(folds(*this))) - return getResult(); - return {}; -} + LogicalResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { + // Defer to SubViewOpConstantArgumentFolder when any operand is constant.
+ if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return failure(); + + auto castOp = subViewOp.source().getDefiningOp(); + if (!castOp) + return failure(); + + if (!canFoldIntoConsumerOp(castOp)) + return failure(); + + // Deduce the resultType of the SubViewOp using `inferSubViewResultType` on + // the cast source operand type and the SubViewOp static information. This + // is the resulting type if the MemRefCastOp were folded. + Type resultType = SubViewOp::inferSubViewResultType( + castOp.source().getType().cast(), + extractFromI64ArrayAttr(subViewOp.static_offsets()), + extractFromI64ArrayAttr(subViewOp.static_sizes()), + extractFromI64ArrayAttr(subViewOp.static_strides())); + Value newSubView = rewriter.create( + subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), + subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), + subViewOp.static_sizes(), subViewOp.static_strides()); + rewriter.replaceOpWithNewOp(subViewOp, subViewOp.getType(), + newSubView); + return success(); + } +}; void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir @@ -1,7 +1,5 @@ -// TODO: this needs a fix to land before being reactivated.
-// RUN: ls -// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s -// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>, diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -941,3 +941,19 @@ return %1: memref } +// ----- + +// CHECK-DAG: #[[map0:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)> +// CHECK-DAG: #[[map1:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + +// CHECK-LABEL: func @memref_cast_folding_subview_static( +func @memref_cast_folding_subview_static(%V: memref<16x16xf32>, %a: index, %b: index) + -> memref<3x4xf32, offset:?, strides:[?, 1]> +{ + %0 = memref_cast %V : memref<16x16xf32> to memref + %1 = subview %0[0, 0][3, 4][1, 1] : memref to memref<3x4xf32, offset:?, strides:[?, 1]> + + // CHECK: subview{{.*}}: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> + // CHECK: memref_cast{{.*}}: memref<3x4xf32, #[[map0]]> to memref<3x4xf32, #[[map1]]> + return %1: memref<3x4xf32, offset:?, strides:[?, 1]> +}