diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -305,14 +305,14 @@
 };

 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
-/// `originalShape` with some `1` entries erased, return the vector of booleans
-/// that specifies which of the entries of `originalShape` are keep to obtain
+/// `originalShape` with some `1` entries erased, return the set of indices
+/// that specifies which of the entries of `originalShape` are dropped to obtain
 /// `reducedShape`. The returned mask can be applied as a projection to
 /// `originalShape` to obtain the `reducedShape`. This mask is useful to track
 /// which dimensions must be kept when e.g. compute MemRef strides under
 /// rank-reducing operations. Return None if reducedShape cannot be obtained
 /// by dropping only `1` entries in `originalShape`.
-llvm::Optional<SmallVector<bool, 4>>
+llvm::Optional<llvm::SmallDenseSet<unsigned>>
 computeRankReductionMask(ArrayRef<int64_t> originalShape,
                          ArrayRef<int64_t> reducedShape);

diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -127,6 +127,12 @@
   AffineExpr replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
                                    ArrayRef<AffineExpr> symReplacements) const;

+  /// Dim-only version of replaceDimsAndSymbols.
+  AffineExpr replaceDims(ArrayRef<AffineExpr> dimReplacements) const;
+
+  /// Symbol-only version of replaceDimsAndSymbols.
+  AffineExpr replaceSymbols(ArrayRef<AffineExpr> symReplacements) const;
+
   /// Sparse replace method. Replace `expr` by `replacement` and return the
   /// modified expression tree.
   AffineExpr replace(AffineExpr expr, AffineExpr replacement) const;
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -18,6 +18,7 @@
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/DenseSet.h"

 namespace mlir {

@@ -311,6 +312,20 @@
 /// Simplifies an affine map by simplifying its underlying AffineExpr results.
 AffineMap simplifyAffineMap(AffineMap map);

+/// Drop the dims that are not used.
+AffineMap compressUnusedDims(AffineMap map);
+
+/// Drop the dims that are listed in `unusedDims`.
+AffineMap compressDims(AffineMap map,
+                       const llvm::SmallDenseSet<unsigned> &unusedDims);
+
+/// Drop the symbols that are not used.
+AffineMap compressUnusedSymbols(AffineMap map);
+
+/// Drop the symbols that are listed in `unusedSymbols`.
+AffineMap compressSymbols(AffineMap map,
+                          const llvm::SmallDenseSet<unsigned> &unusedSymbols);
+
 /// Returns a map with the same dimension and symbol count as `map`, but whose
 /// results are the unique affine expressions of `map`.
 AffineMap removeDuplicateExprs(AffineMap map);
@@ -390,8 +405,11 @@
 /// 3) map                  : affine_map<(d0, d1, d2) -> (d0, d1)>
 ///    projected_dimensions : {1}
 ///    result               : affine_map<(d0, d1) -> (d0, 0)>
-AffineMap getProjectedMap(AffineMap map,
-                          ArrayRef<unsigned> projectedDimensions);
+///
+/// This function also compresses unused symbols away.
+AffineMap
+getProjectedMap(AffineMap map,
+                const llvm::SmallDenseSet<unsigned> &projectedDimensions);

 inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
   map.print(os);
@@ -402,7 +420,8 @@
 namespace llvm {

 // AffineMap hash just like pointers.
-template <> struct DenseMapInfo<mlir::AffineMap> {
+template <>
+struct DenseMapInfo<mlir::AffineMap> {
   static mlir::AffineMap getEmptyKey() {
     auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
     return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
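Illustrative sketch (not part of the patch): how the new AffineMap compression
entry points compose. `compressionSketch` is an invented name; the maps in the
comments follow the documented semantics above, where dropped dims/symbols are
zeroed out and the survivors are renumbered.

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/MLIRContext.h"
  #include <cassert>

  using namespace mlir;

  void compressionSketch() {
    MLIRContext ctx;
    // map: (d0, d1, d2)[s0, s1] -> (d1 + s1); d0, d2 and s0 are unused.
    AffineExpr d1 = getAffineDimExpr(1, &ctx);
    AffineExpr s1 = getAffineSymbolExpr(1, &ctx);
    AffineMap map =
        AffineMap::get(/*dimCount=*/3, /*symbolCount=*/2, {d1 + s1}, &ctx);

    // Drop d0 and d2; the surviving dim is renumbered:
    //   (d0)[s0, s1] -> (d0 + s1)
    AffineMap fewerDims = compressDims(map, /*unusedDims=*/{0, 2});

    // Drop the still-unused s0 and renumber s1 down to s0:
    //   (d0)[s0] -> (d0 + s0)
    AffineMap projected = compressUnusedSymbols(fewerDims);

    // getProjectedMap is exactly the composition of the two steps above.
    assert(projected == getProjectedMap(map, {0, 2}));
  }
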
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -566,6 +566,10 @@
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool isStrided(MemRefType t);

+/// Return the layout map in strided linear layout AffineMap form.
+/// Return null if the layout is not compatible with a strided layout.
+AffineMap getStridedLinearLayoutMap(MemRefType t);
+
 } // end namespace mlir

 #endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -3277,7 +3277,7 @@
   auto inferredShape = inferredType.getShape();
   size_t inferredShapeRank = inferredShape.size();
   size_t resultShapeRank = shape.size();
-  SmallVector<bool, 4> mask =
+  llvm::SmallDenseSet<unsigned> unusedDims =
       computeRankReductionMask(inferredShape, shape).getValue();

   // Extract strides needed to compute offset.
@@ -3318,7 +3318,7 @@
          "expected sizes and strides of equal length");
   for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
        i >= 0 && j >= 0; --i) {
-    if (!mask[i])
+    if (unusedDims.contains(i))
       continue;

     // `i` may overflow subViewOp.getMixedSizes because of trailing semantics.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -536,10 +536,10 @@
 /// Prune all dimensions that are of reduction iterator type from `map`.
 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
                                            AffineMap map) {
-  SmallVector<unsigned, 2> projectedDims;
+  llvm::SmallDenseSet<unsigned> projectedDims;
   for (auto attr : llvm::enumerate(iteratorTypes)) {
     if (!isParallelIterator(attr.value()))
-      projectedDims.push_back(attr.index());
+      projectedDims.insert(attr.index());
   }
   return getProjectedMap(map, projectedDims);
 }
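Illustrative sketch (not part of the patch) of the contract
`computeRankReductionMask` now exposes, as declared in Ops.h above and
implemented in Ops.cpp below: the returned set holds the indices of the
dropped unit dims, and `llvm::None` signals that `reducedShape` cannot be
obtained by dropping only `1` entries.

  #include "mlir/Dialect/StandardOps/IR/Ops.h"
  #include <cassert>

  void rankReductionMaskSketch() {
    // 4x1x1x5 -> 4x5 drops the unit dims at indices 1 and 2.
    auto mask = mlir::computeRankReductionMask({4, 1, 1, 5}, {4, 5});
    assert(mask.hasValue() && mask->contains(1) && mask->contains(2) &&
           mask->size() == 2);

    // 4x1x5 -> 4x3 cannot be formed by dropping only `1` entries.
    assert(!mlir::computeRankReductionMask({4, 1, 5}, {4, 3}).hasValue());
  }
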
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2957,35 +2957,44 @@
 /// For ViewLikeOpInterface.
 Value SubViewOp::getViewSource() { return source(); }

-llvm::Optional<SmallVector<bool, 4>>
+/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
+/// `originalShape` with some `1` entries erased, return the set of indices
+/// that specifies which of the entries of `originalShape` are dropped to
+/// obtain `reducedShape`. The returned mask can be applied as a projection to
+/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
+/// which dimensions must be kept when e.g. compute MemRef strides under
+/// rank-reducing operations. Return None if reducedShape cannot be obtained
+/// by dropping only `1` entries in `originalShape`.
+llvm::Optional<llvm::SmallDenseSet<unsigned>>
 mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
                                ArrayRef<int64_t> reducedShape) {
   size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
-  SmallVector<bool, 4> mask(originalRank);
+  llvm::SmallDenseSet<unsigned> unusedDims;
   unsigned reducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    // Skip matching dims greedily.
-    mask[originalIdx] =
-        (reducedIdx < reducedRank) &&
-        (originalShape[originalIdx] == reducedShape[reducedIdx]);
-    if (mask[originalIdx])
+    // Greedily advance `reducedIdx` on a match; matching dims are kept.
+    if (reducedIdx < reducedRank &&
+        originalShape[originalIdx] == reducedShape[reducedIdx]) {
       reducedIdx++;
-    // 1 is the only non-matching allowed.
-    else if (originalShape[originalIdx] != 1)
-      return {};
-  }
+      continue;
+    }
+    unusedDims.insert(originalIdx);
+    // If there is no match on `originalIdx`, the `originalShape` at this
+    // dimension must be 1, otherwise we bail.
+    if (originalShape[originalIdx] != 1)
+      return llvm::None;
+  }
+  // The whole reducedShape must be scanned, otherwise we bail.
   if (reducedIdx != reducedRank)
-    return {};
-
-  return mask;
+    return llvm::None;
+  return unusedDims;
 }

 enum SubViewVerificationResult {
   Success,
   RankTooLarge,
   SizeMismatch,
-  StrideMismatch,
   ElemTypeMismatch,
   MemSpaceMismatch,
   AffineMapMismatch
@@ -2994,8 +3003,9 @@
 /// Checks if `original` Type type can be rank reduced to `reduced` type.
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
-static SubViewVerificationResult isRankReducedType(Type originalType,
-                                                   Type candidateReducedType) {
+static SubViewVerificationResult
+isRankReducedType(Type originalType, Type candidateReducedType,
+                  std::string *errMsg = nullptr) {
   if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
     return SubViewVerificationResult::Success;
@@ -3019,13 +3029,17 @@
   if (candidateReducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;

-  auto optionalMask =
+  auto optionalUnusedDimsMask =
       computeRankReductionMask(originalShape, candidateReducedShape);

   // Sizes cannot be matched in case empty vector is returned.
-  if (!optionalMask.hasValue())
+  if (!optionalUnusedDimsMask.hasValue())
     return SubViewVerificationResult::SizeMismatch;

+  if (originalShapedType.getElementType() !=
+      candidateReducedShapedType.getElementType())
+    return SubViewVerificationResult::ElemTypeMismatch;
+
   // We are done for the tensor case.
   if (originalType.isa<RankedTensorType>())
     return SubViewVerificationResult::Success;
@@ -3033,74 +3047,54 @@
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
-  MLIRContext *c = original.getContext();
-  int64_t originalOffset, candidateReducedOffset;
-  SmallVector<int64_t, 4> originalStrides, candidateReducedStrides,
-      keepStrides;
-  SmallVector<bool, 4> keepMask = optionalMask.getValue();
-  (void)getStridesAndOffset(original, originalStrides, originalOffset);
-  (void)getStridesAndOffset(candidateReduced, candidateReducedStrides,
-                            candidateReducedOffset);
-
-  // Filter strides based on the mask and check that they are the same
-  // as candidateReduced ones.
-  unsigned candidateReducedIdx = 0;
-  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
-    if (keepMask[originalIdx]) {
-      if (originalStrides[originalIdx] !=
-          candidateReducedStrides[candidateReducedIdx++])
-        return SubViewVerificationResult::StrideMismatch;
-      keepStrides.push_back(originalStrides[originalIdx]);
-    }
-  }
-
-  if (original.getElementType() != candidateReduced.getElementType())
-    return SubViewVerificationResult::ElemTypeMismatch;
-
   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
     return SubViewVerificationResult::MemSpaceMismatch;

-  // reducedMap is obtained by projecting away the dimensions inferred from
-  // matching the 1's positions in candidateReducedType.
-  auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
-
-  MemRefType expectedReducedType = MemRefType::get(
-      candidateReduced.getShape(), candidateReduced.getElementType(),
-      reducedMap, candidateReduced.getMemorySpace());
-  expectedReducedType = canonicalizeStridedLayout(expectedReducedType);
-
-  if (expectedReducedType != canonicalizeStridedLayout(candidateReduced))
+  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
+  auto inferredType =
+      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
+  AffineMap candidateLayout;
+  if (candidateReduced.getAffineMaps().empty())
+    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
+  else
+    candidateLayout = candidateReduced.getAffineMaps().front();
+  if (inferredType != candidateLayout) {
+    if (errMsg) {
+      llvm::raw_string_ostream os(*errMsg);
+      os << "inferred type: " << inferredType;
+    }
     return SubViewVerificationResult::AffineMapMismatch;
-
+  }
   return SubViewVerificationResult::Success;
 }

 template <typename OpTy>
 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
-                                            OpTy op, Type expectedType) {
+                                            OpTy op, Type expectedType,
+                                            StringRef errMsg = "") {
   auto memrefType = expectedType.cast<ShapedType>();
   switch (result) {
   case SubViewVerificationResult::Success:
     return success();
   case SubViewVerificationResult::RankTooLarge:
     return op.emitError("expected result rank to be smaller or equal to ")
-           << "the source rank.";
+           << "the source rank. " << errMsg;
   case SubViewVerificationResult::SizeMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result sizes)";
-  case SubViewVerificationResult::StrideMismatch:
-    return op.emitError("expected result type to be ")
-           << expectedType
-           << " or a rank-reduced version. (mismatch of result strides)";
+           << " or a rank-reduced version. (mismatch of result sizes) "
+           << errMsg;
   case SubViewVerificationResult::ElemTypeMismatch:
     return op.emitError("expected result element type to be ")
-           << memrefType.getElementType();
+           << memrefType.getElementType() << errMsg;
   case SubViewVerificationResult::MemSpaceMismatch:
-    return op.emitError("expected result and source memory spaces to match.");
+    return op.emitError("expected result and source memory spaces to match.")
+           << errMsg;
   case SubViewVerificationResult::AffineMapMismatch:
     return op.emitError("expected result type to be ")
            << expectedType
-           << " or a rank-reduced version. (mismatch of result affine map)";
+           << " or a rank-reduced version. (mismatch of result affine map) "
+           << errMsg;
   }
   llvm_unreachable("unexpected subview verification result");
 }
@@ -3126,8 +3120,9 @@
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));

-  auto result = isRankReducedType(expectedType, subViewType);
-  return produceSubViewErrorMsg(result, op, expectedType);
+  std::string errMsg;
+  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
+  return produceSubViewErrorMsg(result, op, expectedType, errMsg);
 }

 raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) {
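The net effect of the verifier change above: instead of filtering strides
through a boolean keep-mask and rebuilding a MemRefType, the expected layout of
a rank-reduced subview is now inferred entirely at the AffineMap level. A
sketch of that computation in isolation (`inferRankReducedLayout` is a
hypothetical helper name, not part of the patch; the body mirrors the verifier
code above):

  #include "mlir/IR/AffineMap.h"
  #include "mlir/IR/BuiltinTypes.h"

  using namespace mlir;

  // Infer the layout a rank-reduced view of `source` should carry once the
  // dims in `unusedDims` are dropped. Returns null for non-strided layouts.
  static AffineMap
  inferRankReducedLayout(MemRefType source,
                         const llvm::SmallDenseSet<unsigned> &unusedDims) {
    AffineMap layout = getStridedLinearLayoutMap(source);
    if (!layout)
      return AffineMap();
    // Zero out the dropped dims, renumber the survivors, and compress away
    // the symbols (strides) that became unused as a consequence.
    return getProjectedMap(layout, unusedDims);
  }
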
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -92,6 +92,15 @@
   llvm_unreachable("Unknown AffineExpr");
 }

+AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
+  return replaceDimsAndSymbols(dimReplacements, {});
+}
+
+AffineExpr
+AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
+  return replaceDimsAndSymbols({}, symReplacements);
+}
+
 /// Replace dims[0 .. numDims - 1] by dims[shift .. shift + numDims - 1].
 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift) const {
   SmallVector<AffineExpr, 4> dims;
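Illustrative sketch (not part of the patch) of the new dim-only replacement
helper: replacements are positional, so entry `i` substitutes `d<i>`, and
constant folding happens as the expression is rebuilt.

  #include "mlir/IR/AffineExpr.h"
  #include "mlir/IR/MLIRContext.h"
  #include <cassert>

  using namespace mlir;

  void replaceDimsSketch() {
    MLIRContext ctx;
    AffineExpr d0 = getAffineDimExpr(0, &ctx);
    AffineExpr d1 = getAffineDimExpr(1, &ctx);
    AffineExpr expr = d0 + d1;
    // Zero out d1, mirroring what compressDims does for unused dims:
    //   (d0 + d1) -> (d0 + 0) -> d0.
    AffineExpr zeroed = expr.replaceDims({d0, getAffineConstantExpr(0, &ctx)});
    assert(zeroed == d0);
  }
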
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -420,6 +420,71 @@
       llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
 }

+AffineMap mlir::compressDims(AffineMap map,
+                             const llvm::SmallDenseSet<unsigned> &unusedDims) {
+  unsigned numDims = 0;
+  SmallVector<AffineExpr, 4> dimReplacements;
+  dimReplacements.reserve(map.getNumDims());
+  MLIRContext *context = map.getContext();
+  for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
+    if (unusedDims.contains(dim))
+      dimReplacements.push_back(getAffineConstantExpr(0, context));
+    else
+      dimReplacements.push_back(getAffineDimExpr(numDims++, context));
+  }
+  SmallVector<AffineExpr, 4> resultExprs;
+  resultExprs.reserve(map.getNumResults());
+  for (auto e : map.getResults())
+    resultExprs.push_back(e.replaceDims(dimReplacements));
+  return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
+}
+
+AffineMap mlir::compressUnusedDims(AffineMap map) {
+  llvm::SmallDenseSet<unsigned> usedDims;
+  map.walkExprs([&](AffineExpr expr) {
+    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
+      usedDims.insert(dimExpr.getPosition());
+  });
+  llvm::SmallDenseSet<unsigned> unusedDims;
+  for (unsigned d = 0, e = map.getNumDims(); d != e; ++d)
+    if (!usedDims.contains(d))
+      unusedDims.insert(d);
+  return compressDims(map, unusedDims);
+}
+
+AffineMap
+mlir::compressSymbols(AffineMap map,
+                      const llvm::SmallDenseSet<unsigned> &unusedSymbols) {
+  unsigned numSymbols = 0;
+  SmallVector<AffineExpr, 4> symReplacements;
+  symReplacements.reserve(map.getNumSymbols());
+  MLIRContext *context = map.getContext();
+  for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
+    if (unusedSymbols.contains(sym))
+      symReplacements.push_back(getAffineConstantExpr(0, context));
+    else
+      symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
+  }
+  SmallVector<AffineExpr, 4> resultExprs;
+  resultExprs.reserve(map.getNumResults());
+  for (auto e : map.getResults())
+    resultExprs.push_back(e.replaceSymbols(symReplacements));
+  return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
+}
+
+AffineMap mlir::compressUnusedSymbols(AffineMap map) {
+  llvm::SmallDenseSet<unsigned> usedSymbols;
+  map.walkExprs([&](AffineExpr expr) {
+    if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
+      usedSymbols.insert(symExpr.getPosition());
+  });
+  llvm::SmallDenseSet<unsigned> unusedSymbols;
+  for (unsigned d = 0, e = map.getNumSymbols(); d != e; ++d)
+    if (!usedSymbols.contains(d))
+      unusedSymbols.insert(d);
+  return compressSymbols(map, unusedSymbols);
+}
+
 AffineMap mlir::simplifyAffineMap(AffineMap map) {
   SmallVector<AffineExpr, 4> exprs;
   for (auto e : map.getResults()) {
@@ -480,20 +545,10 @@
                        maps.front().getContext());
 }

-AffineMap mlir::getProjectedMap(AffineMap map,
-                                ArrayRef<unsigned> projectedDimensions) {
-  DenseSet<unsigned> projectedDims(projectedDimensions.begin(),
-                                   projectedDimensions.end());
-  MLIRContext *context = map.getContext();
-  SmallVector<AffineExpr, 4> resultExprs;
-  for (auto dim : enumerate(llvm::seq<unsigned>(0, map.getNumDims()))) {
-    if (!projectedDims.count(dim.value()))
-      resultExprs.push_back(getAffineDimExpr(dim.index(), context));
-    else
-      resultExprs.push_back(getAffineConstantExpr(0, context));
-  }
-  return map.compose(AffineMap::get(
-      map.getNumDims() - projectedDimensions.size(), 0, resultExprs, context));
+AffineMap
+mlir::getProjectedMap(AffineMap map,
+                      const llvm::SmallDenseSet<unsigned> &unusedDims) {
+  return compressUnusedSymbols(compressDims(map, unusedDims));
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -829,7 +829,17 @@
 /// Return true if the layout for `t` is compatible with strided semantics.
 bool mlir::isStrided(MemRefType t) {
   int64_t offset;
-  SmallVector<int64_t, 4> stridesAndOffset;
-  auto res = getStridesAndOffset(t, stridesAndOffset, offset);
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(t, strides, offset);
   return succeeded(res);
 }
+
+/// Return the layout map in strided linear layout AffineMap form.
+/// Return null if the layout is not compatible with a strided layout.
+AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(t, strides, offset)))
+    return AffineMap();
+  return makeStridedLinearLayoutMap(strides, offset, t.getContext());
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -812,6 +812,10 @@
   // CHECK: subview %{{.*}}[%{{.*}}, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, #[[$SUBVIEW_MAP11]]>
   %28 = subview %24[%arg0, 1] [1, 1] [1, 1] : memref<5x3xf32> to memref<f32, affine_map<()[s0] -> (s0)>>

+  // CHECK: subview %{{.*}}[0, %{{.*}}] [%{{.*}}, 1] [1, 1] : memref<?x?xf32> to memref<?xf32, #[[$SUBVIEW_MAP12]]>
+  %a30 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %30 = subview %a30[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
+
   return
 }
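Illustrative sketch (not part of the patch) of what `getStridedLinearLayoutMap`
produces for a statically shaped memref with the default layout:

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/MLIRContext.h"

  using namespace mlir;

  void stridedLayoutSketch(MLIRContext *ctx) {
    // memref<4x8xf32> with the default (identity) layout has strides [8, 1]
    // and offset 0.
    MemRefType t = MemRefType::get({4, 8}, FloatType::getF32(ctx));
    // Yields the strided linear form (d0, d1) -> (d0 * 8 + d1); dynamic
    // strides/offsets would show up as symbols instead of constants.
    AffineMap layout = getStridedLinearLayoutMap(t);
    (void)layout;
  }
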
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -970,7 +970,7 @@
 func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
   %0 = alloc() : memref<8x16x4xf32>
-  // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result strides)}}
+  // expected-error@+1 {{expected result type to be 'memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. (mismatch of result affine map)}}
   %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2]
     : memref<8x16x4xf32> to memref<?x?x?xf32>

@@ -1022,13 +1022,22 @@
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
-  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
+  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map)}}
   %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>
   return
 }

 // -----

+// The affine map affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)> has an extra
+// unused symbol.
+func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map) inferred type: (d0)[s0, s1] -> (d0 * s1 + s0)}}
+  %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32, affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)>>
+  return
+}
+
+// -----
+
 func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
   // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
   %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>