diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2674,7 +2674,18 @@
     struct Range {
       Value offset, size, stride;
     };
-    SmallVector<Range, 8> getRanges();
+    /// Return the list of Range (i.e. offset, size, stride). Each entry is
+    /// either the dynamic operand or, for a static entry, a constant built
+    /// with `b` at location `loc`.
+    SmallVector<Range, 8> getOrCreateRanges(OpBuilder &b, Location loc);
+
+    /// A subview result type can be fully inferred from the source type and the
+    /// static representation of offsets, sizes and strides. Special sentinels
+    /// encode the dynamic case.
+    static Type inferSubViewResultType(MemRefType sourceMemRefType,
+                                       ArrayRef<int64_t> staticOffsets,
+                                       ArrayRef<int64_t> staticSizes,
+                                       ArrayRef<int64_t> staticStrides);
 
     /// Return the rank of the result MemRefType.
     unsigned getRank() { return getType().getRank(); }
@@ -2699,7 +2710,6 @@
   }];
 
   let hasCanonicalizer = 1;
-  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -184,15 +184,16 @@
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
 
+  OpBuilder b(consumer.getOperation());
+  auto loc = consumer.getLoc();
   // Iterate over dimensions identified by the producer map for `producerIdx`.
   // This defines a subset of the loop ranges that we need to complete later.
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
+    loopRanges[posInProducerLoop] =
+        subView.getOrCreateRanges(b, loc)[en.index()];
   }
 
-  OpBuilder b(consumer.getOperation());
-  auto loc = consumer.getLoc();
   // Iterate over all dimensions. For the dimensions not identified by the
   // producer map for `producerIdx`, we need to explicitly compute the view that
   // defines the loop ranges using the `producer`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -46,11 +46,6 @@
 #define DEBUG_TYPE "linalg-promotion"
 
 namespace {
-
-/// Helper struct that captures the information required to apply the
-/// transformation on each op. This bridges the abstraction gap with the
-/// user-facing API which exposes positional arguments to control which operands
-/// are promoted.
 struct LinalgOpInstancePromotionOptions {
   LinalgOpInstancePromotionOptions(LinalgOp op,
                                    const LinalgPromotionOptions &options);
@@ -88,7 +83,7 @@
 /// Otherwise return size.
 static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
                                                  Value size) {
-  auto affineMinOp = size.getDefiningOp<AffineMinOp>();
+  auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
   if (!affineMinOp)
     return size;
   int64_t minConst = std::numeric_limits<int64_t>::max();
@@ -112,7 +107,7 @@
     alignment_attr =
         IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
   if (!dynamicBuffers)
-    if (auto cst = size.getDefiningOp<ConstantIndexOp>())
+    if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
       return std_alloc(
           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
           ValueRange{}, alignment_attr);
@@ -153,7 +148,7 @@
   SmallVector<Value, 4> fullSizes, partialSizes;
   fullSizes.reserve(rank);
   partialSizes.reserve(rank);
-  for (auto en : llvm::enumerate(subView.getRanges())) {
+  for (auto en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
     auto rank = en.index();
     auto rangeValue = en.value();
     // Try to extract a tight constant.
@@ -169,10 +164,10 @@
                             dynamicBuffers, folder, alignment);
   auto fullLocalView = folded_std_view(
       folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
-      folded_std_constant_index(folder, 0), fullSizes);
+      zero, fullSizes);
   SmallVector<Value, 4> zeros(fullSizes.size(), zero);
   SmallVector<Value, 4> ones(fullSizes.size(), one);
-  auto partialLocalView =
+  Value partialLocalView =
       folded_std_subview(folder, fullLocalView, zeros, partialSizes, ones);
   return PromotionInfo{buffer, fullLocalView, partialLocalView};
 }
@@ -285,11 +280,10 @@
   // Check that at least one of the requested operands is indeed a subview.
   for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) {
     auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
-    if (sv) {
+    if (sv)
       if (!options.operandsToPromote.hasValue() ||
           options.operandsToPromote->count(en.index()))
         return success();
-    }
   }
   // TODO: Check all subviews requested are bound by a static constant.
   // TODO: Check that the total footprint fits within a given size.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2276,10 +2276,10 @@
 /// A subview result type can be fully inferred from the source type and the
 /// static representation of offsets, sizes and strides. Special sentinels
 /// encode the dynamic case.
-static Type inferSubViewResultType(MemRefType sourceMemRefType,
-                                   ArrayRef<int64_t> staticOffsets,
-                                   ArrayRef<int64_t> staticSizes,
-                                   ArrayRef<int64_t> staticStrides) {
+Type SubViewOp::inferSubViewResultType(MemRefType sourceMemRefType,
+                                       ArrayRef<int64_t> staticOffsets,
+                                       ArrayRef<int64_t> staticSizes,
+                                       ArrayRef<int64_t> staticStrides) {
   unsigned rank = sourceMemRefType.getRank();
   (void)rank;
   assert(staticOffsets.size() == rank &&
@@ -2390,8 +2390,8 @@
                             ValueRange sizes, ValueRange strides,
                             ArrayRef<NamedAttribute> attrs) {
   auto sourceMemRefType = source.getType().cast<MemRefType>();
-  auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets,
-                                           staticSizes, staticStrides);
+  auto resultType = SubViewOp::inferSubViewResultType(
+      sourceMemRefType, staticOffsets, staticSizes, staticStrides);
   build(b, result, resultType, source, offsets, sizes, strides,
         b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
         b.getI64ArrayAttr(staticStrides));
@@ -2475,7 +2475,7 @@
     return failure();
 
   // Verify result type against inferred type.
-  auto expectedType = inferSubViewResultType(
+  auto expectedType = SubViewOp::inferSubViewResultType(
       op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
@@ -2490,13 +2490,45 @@
             << range.stride;
 }
 
-SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
+/// Returns the `idx`-th entry of an offset/size/stride list as a Value.
+/// `attr` is the static form of the list and `isDynamic` recognizes its
+/// dynamic sentinel: a static entry is materialized as a constant created with
+/// `b` at `loc`, a dynamic entry is looked up in `values`.
+static Value getOrCreateSubViewPart(OpBuilder &b, Location loc,
+                                    ValueRange values, ArrayAttr attr,
+                                    llvm::function_ref<bool(int64_t)> isDynamic,
+                                    unsigned idx) {
+  auto cst = attr.getValue()[idx].cast<IntegerAttr>().getInt();
+  if (!isDynamic(cst))
+    return b.create<ConstantIndexOp>(loc, cst);
+  // `values` is expected to hold one operand per dynamic entry, in order, so
+  // the operand for `idx` sits at the number of dynamic entries before `idx`.
+  unsigned numDynamicEntriesBeforeIdx = std::count_if(
+      attr.getValue().begin(), attr.getValue().begin() + idx,
+      [&](Attribute entry) {
+        return isDynamic(entry.cast<IntegerAttr>().getInt());
+      });
+  return values[numDynamicEntriesBeforeIdx];
+}
+
+SmallVector<SubViewOp::Range, 8> SubViewOp::getOrCreateRanges(OpBuilder &b,
+                                                              Location loc) {
   SmallVector<Range, 8> res;
   unsigned rank = getType().getRank();
   res.reserve(rank);
-  for (unsigned i = 0; i < rank; ++i)
-    res.emplace_back(Range{*(offsets().begin() + i), *(sizes().begin() + i),
-                           *(strides().begin() + i)});
+  for (unsigned idx = 0; idx < rank; ++idx) {
+    res.emplace_back(Range{
+        getOrCreateSubViewPart(b, loc, offsets(),
+                               static_offsets().cast<ArrayAttr>(),
+                               ShapedType::isDynamicStrideOrOffset, idx),
+        getOrCreateSubViewPart(b, loc, sizes(),
+                               static_sizes().cast<ArrayAttr>(),
+                               ShapedType::isDynamic, idx),
+        getOrCreateSubViewPart(b, loc, strides(),
+                               static_strides().cast<ArrayAttr>(),
+                               ShapedType::isDynamicStrideOrOffset, idx),
+    });
+  }
   return res;
 }
 
@@ -2543,7 +2575,7 @@
 }
 
 /// Pattern to rewrite a subview op with constant arguments.
-class SubViewOpFolder final : public OpRewritePattern<SubViewOp> {
+class SubViewOpConstantFolder final : public OpRewritePattern<SubViewOp> {
 public:
   using OpRewritePattern<SubViewOp>::OpRewritePattern;
 
@@ -2678,27 +2710,54 @@
   return true;
 }
 
-OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
-  auto folds = [](Operation *op) {
-    bool folded = false;
-    for (OpOperand &operand : op->getOpOperands()) {
-      auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
-      if (castOp && canFoldIntoConsumerOp(castOp)) {
-        operand.set(castOp.getOperand());
-        folded = true;
-      }
+/// Pattern to rewrite a subview op with MemRefCast arguments.
+class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
+public:
+  using OpRewritePattern<SubViewOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubViewOp subViewOp,
+                                PatternRewriter &rewriter) const override {
+    // Any constant operand, just return to let SubViewOpConstantFolder kick in.
+    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
+          return matchPattern(operand, m_ConstantIndex());
+        }))
+      return failure();
+
+    auto castOp =
+        dyn_cast_or_null<MemRefCastOp>(subViewOp.source().getDefiningOp());
+    if (!castOp)
+      return failure();
+
+    Type resultType = SubViewOp::inferSubViewResultType(
+        castOp.source().getType().cast<MemRefType>(),
+        extractFromI64ArrayAttr(subViewOp.static_offsets()),
+        extractFromI64ArrayAttr(subViewOp.static_sizes()),
+        extractFromI64ArrayAttr(subViewOp.static_strides()));
+
+    if (resultType == subViewOp.getType()) {
+      rewriter.replaceOpWithNewOp<SubViewOp>(
+          subViewOp, resultType, castOp.source(), subViewOp.offsets(),
+          subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
+          subViewOp.static_sizes(), subViewOp.static_strides());
+      return success();
     }
-    return folded ? success() : failure();
-  };
 
-  if (succeeded(folds(*this)))
-    return getResult();
-  return {};
-}
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    Value newSubView = rewriter.create<SubViewOp>(
+        subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
+        subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
+        subViewOp.static_sizes(), subViewOp.static_strides());
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(subViewOp, subViewOp.getType(),
+                                              newSubView);
+    return success();
+  }
+};
 
 void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
-  results.insert<SubViewOpFolder>(context);
+  results.insert<SubViewOpConstantFolder, SubViewOpMemRefCastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir
@@ -1,7 +1,5 @@
-// TODO: this needs a fix to land before being reactivated.
-// RUN: ls
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
-// R_UN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-1d | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-matmul-to-vector-patterns-tile-2d | FileCheck %s
 
 func @matmul(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
              %B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,