diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1373,6 +1373,10 @@
     /// Return the number of leading operands before the `offsets`, `sizes` and
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
+
+    /// Return the dimensions of the source type that are dropped when
+    /// the result is rank-reduced.
+    llvm::SmallDenseSet<unsigned> getDroppedDims();
   }];

   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -66,7 +66,7 @@
   let cppNamespace = "::mlir";

   let methods = [
-    InterfaceMethod<
+    StaticInterfaceMethod<
      /*desc=*/[{
        Return the number of leading operands before the `offsets`, `sizes` and
        and `strides` operands.
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1272,12 +1272,8 @@
                                                  extracted);
     targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

-    auto shape = viewMemRefType.getShape();
-    auto inferredShape = inferredType.getShape();
-    size_t inferredShapeRank = inferredShape.size();
-    size_t resultShapeRank = shape.size();
-    llvm::SmallDenseSet<unsigned> unusedDims =
-        computeRankReductionMask(inferredShape, shape).getValue();
+    size_t inferredShapeRank = inferredType.getRank();
+    size_t resultShapeRank = viewMemRefType.getRank();

     // Extract strides needed to compute offset.
     SmallVector<Value> strideValues;
@@ -1315,6 +1311,7 @@
     SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
     assert(mixedSizes.size() == mixedStrides.size() &&
            "expected sizes and strides of equal length");
+    llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
     for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
          i >= 0 && j >= 0; --i) {
       if (unusedDims.contains(i))
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -690,6 +690,91 @@
   return success();
 }

+static llvm::DenseMap<int64_t, unsigned>
+getNumOccurences(ArrayRef<int64_t> vals) {
+  llvm::DenseMap<int64_t, unsigned> numOccurences;
+  for (auto val : vals) {
+    numOccurences[val]++;
+  }
+  return numOccurences;
+}
+
+/// Given the type of the non-rank-reduced subview result and the rank-reduced
+/// result type, computes the dropped dimensions. This accounts for cases where
+/// there are multiple unit-dims but only a subset of them are dropped. For
+/// MemRefTypes these can be disambiguated using the strides: if a dimension is
+/// dropped, its stride must be dropped too.
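+///
+/// Illustrative example (shapes chosen for exposition, not taken from the
+/// tests below): reducing `memref<1x1x10xf32>` (strides `[10, 10, 1]`) to a
+/// rank-reduced view with strides `[10, 1]` drops exactly one of the two unit
+/// dimensions. Only one occurrence of stride `10` disappears, so one unit dim
+/// with stride `10` is counted as dropped and the other is kept.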
+static llvm::Optional<llvm::SmallDenseSet<unsigned>>
+computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
+                               ArrayAttr staticSizes) {
+  llvm::SmallDenseSet<unsigned> unusedDims;
+  if (originalType.getRank() == reducedType.getRank())
+    return unusedDims;
+
+  for (auto dim : llvm::enumerate(staticSizes))
+    if (dim.value().cast<IntegerAttr>().getInt() == 1)
+      unusedDims.insert(dim.index());
+  SmallVector<int64_t> originalStrides, candidateStrides;
+  int64_t originalOffset, candidateOffset;
+  if (failed(
+          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
+      failed(
+          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
+    return llvm::None;
+
+  // For memrefs, a dimension is truly dropped if its corresponding stride is
+  // also dropped. This is particularly important when more than one of the
+  // dims is 1. Track the number of occurrences of the strides in the original
+  // type and the candidate type. For each dropped dim, one occurrence of its
+  // stride must disappear in the candidate type. Note that there could be
+  // multiple dimensions that have the same stride; we do not need to figure
+  // out exactly which dim corresponds to which stride, only that for each
+  // stride:
+  //
+  //   number of repetitions of the stride in the original ==
+  //       number of repetitions of the stride in the candidate +
+  //       number of dropped dims with that stride.
+  llvm::DenseMap<int64_t, unsigned> currUnaccountedStrides =
+      getNumOccurences(originalStrides);
+  llvm::DenseMap<int64_t, unsigned> candidateStridesNumOccurences =
+      getNumOccurences(candidateStrides);
+  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
+  for (unsigned dim : unusedDims) {
+    int64_t originalStride = originalStrides[dim];
+    if (currUnaccountedStrides[originalStride] >
+        candidateStridesNumOccurences[originalStride]) {
+      // This dim can be treated as dropped.
+      currUnaccountedStrides[originalStride]--;
+      continue;
+    }
+    if (currUnaccountedStrides[originalStride] ==
+        candidateStridesNumOccurences[originalStride]) {
+      // The stride for this dim is not dropped. Keep it as is.
+      prunedUnusedDims.insert(dim);
+      continue;
+    }
+    if (currUnaccountedStrides[originalStride] <
+        candidateStridesNumOccurences[originalStride]) {
+      // This should never happen. The reduced-rank type cannot have a stride
+      // that was not in the original one.
+      return llvm::None;
+    }
+  }
+
+  for (auto prunedDim : prunedUnusedDims)
+    unusedDims.erase(prunedDim);
+  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
+    return llvm::None;
+  return unusedDims;
+}
+
+llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
+  MemRefType sourceType = getSourceType();
+  MemRefType resultType = getType();
+  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
+      computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
+  assert(unusedDims && "unable to find unused dims of subview");
+  return *unusedDims;
+}
+
 OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
   // All forms of folding require a known index.
   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
@@ -725,6 +810,25 @@
     return *(view.getDynamicSizes().begin() +
              memrefType.getDynamicDimIndex(unsignedIndex));

+  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
+    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
+    unsigned resultIndex = 0;
+    unsigned sourceRank = subview.getSourceType().getRank();
+    unsigned sourceIndex = 0;
+    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
+      if (unusedDims.count(i))
+        continue;
+      if (resultIndex == unsignedIndex) {
+        sourceIndex = i;
+        break;
+      }
+      resultIndex++;
+    }
+    assert(subview.isDynamicSize(sourceIndex) &&
+           "expected dynamic subview size");
+    return subview.getDynamicSize(sourceIndex);
+  }
+
   if (auto sizeInterface =
           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
@@ -1887,7 +1991,7 @@
 /// not matching dimension must be 1.
 static SubViewVerificationResult
 isRankReducedType(Type originalType, Type candidateReducedType,
-                  std::string *errMsg = nullptr) {
+                  ArrayAttr staticSizes, std::string *errMsg = nullptr) {
   if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<ShapedType>())
@@ -1908,8 +2012,11 @@
   if (candidateReducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;

+  MemRefType original = originalType.cast<MemRefType>();
+  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
+
   auto optionalUnusedDimsMask =
-      computeRankReductionMask(originalShape, candidateReducedShape);
+      computeMemRefRankReductionMask(original, candidateReduced, staticSizes);

   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalUnusedDimsMask.hasValue())
@@ -1920,42 +2027,8 @@
     return SubViewVerificationResult::ElemTypeMismatch;

   // Strided layout logic is relevant for MemRefType only.
-  MemRefType original = originalType.cast<MemRefType>();
-  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
     return SubViewVerificationResult::MemSpaceMismatch;
-
-  llvm::SmallDenseSet<unsigned> unusedDims = optionalUnusedDimsMask.getValue();
-  auto inferredType =
-      getProjectedMap(getStridedLinearLayoutMap(original), unusedDims);
-  AffineMap candidateLayout;
-  if (candidateReduced.getAffineMaps().empty())
-    candidateLayout = getStridedLinearLayoutMap(candidateReduced);
-  else
-    candidateLayout = candidateReduced.getAffineMaps().front();
-  assert(inferredType.getNumResults() == 1 &&
-         candidateLayout.getNumResults() == 1);
-  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
-      inferredType.getNumDims() != candidateLayout.getNumDims()) {
-    if (errMsg) {
-      llvm::raw_string_ostream os(*errMsg);
-      os << "inferred type: " << inferredType;
-    }
-    return SubViewVerificationResult::AffineMapMismatch;
-  }
-  // Check that the difference of the affine maps simplifies to 0.
-  AffineExpr diffExpr =
-      inferredType.getResult(0) - candidateLayout.getResult(0);
-  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
-                                inferredType.getNumSymbols());
-  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
-  if (!(cst && cst.getValue() == 0)) {
-    if (errMsg) {
-      llvm::raw_string_ostream os(*errMsg);
-      os << "inferred type: " << inferredType;
-    }
-    return SubViewVerificationResult::AffineMapMismatch;
-  }
   return SubViewVerificationResult::Success;
 }

@@ -2012,7 +2085,8 @@
                               extractFromI64ArrayAttr(op.static_strides()));

   std::string errMsg;
-  auto result = isRankReducedType(expectedType, subViewType, &errMsg);
+  auto result =
+      isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
 }

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
@@ -49,18 +49,13 @@
   SmallVector<Value> useIndices;
   // Check if this is rank-reducing case. Then for every unit-dim size add a
   // zero to the indices.
-  ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
   unsigned resultDim = 0;
-  for (auto size : llvm::enumerate(mixedSizes)) {
-    auto attr = size.value().dyn_cast<Attribute>();
-    // Check if this dimension has been dropped, i.e. the size is 1, but the
-    // associated dimension is not 1.
-    if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
-        (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
+  llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
+  for (auto dim : llvm::seq<int64_t>(0, subViewOp.getSourceType().getRank())) {
+    if (unusedDims.count(dim))
       useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
-    else if (resultDim < resultShape.size()) {
+    else
       useIndices.push_back(indices[resultDim++]);
-    }
   }
   if (useIndices.size() != mixedOffsets.size())
     return failure();
@@ -104,6 +99,25 @@
   return op.source();
 }

+/// Given the permutation map of the original
+/// `vector.transfer_read`/`vector.transfer_write` operations, compute the
+/// permutation map to use after the subview is folded with it.
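+///
+/// For example, in the `fold_vector_transfer_read_with_rank_reduced_subview`
+/// test below, dim 0 of the rank-3 source is dropped, so the map built here is
+/// `(d0, d1, d2) -> (d0, d1)`; composing the minor-identity read map
+/// `(d0, d1) -> (d1)` with it produces `(d0, d1, d2) -> (d1)`, which is the
+/// `#MAP2` checked in that test.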
+static AffineMap getPermutationMap(MLIRContext *context, + memref::SubViewOp subViewOp, + AffineMap currPermutationMap) { + llvm::SmallDenseSet unusedDims = subViewOp.getDroppedDims(); + SmallVector exprs; + unsigned resultIdx = 0; + int64_t sourceRank = subViewOp.getSourceType().getRank(); + for (auto dim : llvm::seq(0, sourceRank)) { + if (unusedDims.count(dim)) + continue; + exprs.push_back(getAffineDimExpr(resultIdx++, context)); + } + auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); + return currPermutationMap.compose(resultDimToSourceDimMap); +} + //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// @@ -153,7 +167,9 @@ ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, - loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr()); + getPermutationMap(rewriter.getContext(), subViewOp, + loadOp.permutation_map()), + loadOp.padding(), loadOp.in_boundsAttr()); } template <> @@ -170,7 +186,9 @@ ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( transferWriteOp, transferWriteOp.vector(), subViewOp.source(), - sourceIndices, transferWriteOp.permutation_map(), + sourceIndices, + getPermutationMap(rewriter.getContext(), subViewOp, + transferWriteOp.permutation_map()), transferWriteOp.in_boundsAttr()); } } // namespace diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -1418,3 +1418,28 @@ // CHECKPARALLEL: %[[inc:.*]] = mulf %[[vb]], %[[va]] : f32 // CHECKPARALLEL: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 // CHECKPARALLEL: store %[[res]], %[[arg2]][%[[arg3]], %[[arg4]], %[[arg5]]] : memref + +// ----- + +func @lower_to_loops_with_rank_reducing_subviews( + %arg0 : memref, %arg1 : memref, %arg2 : index, + %arg3 : index, %arg4 : index) { + %0 = memref.subview %arg0[%arg2] [%arg3] [1] + : memref to memref + %1 = memref.subview %arg1[0, %arg4] [1, %arg3] [1, 1] + : memref to memref + linalg.copy(%0, %1) + : memref, memref + return +} +// CHECK-LABEL: func @lower_to_loops_with_rank_reducing_subviews +// CHECK: scf.for %[[IV:.+]] = %{{.+}} to %{{.+}} step %{{.+}} { +// CHECK: %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]] +// CHECK: memref.store %[[VAL]], %{{.+}}[%[[IV]]] +// CHECK: } + +// CHECKPARALLEL-LABEL: func @lower_to_loops_with_rank_reducing_subviews +// CHECKPARALLEL: scf.parallel (%[[IV:.+]]) = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) { +// CHECKPARALLEL: %[[VAL:.+]] = memref.load %{{.+}}[%[[IV]]] +// CHECKPARALLEL: memref.store %[[VAL]], %{{.+}}[%[[IV]]] +// CHECKPARALLEL: } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -159,6 +159,63 @@ // CHECK: %[[RESULT:.+]] = memref.cast %[[SUBVIEW]] // CHECK: return %[[RESULT]] +// ----- + +func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>, + %arg1 : index, %arg2 : index, %arg3 : index) -> memref +{ + %c1 = constant 1 : index + %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref + %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref to memref + return %1 : memref +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + 
s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0] -> (d0 * 384 + s0 + d1)> +// CHECK: func @multiple_reducing_dims +// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : memref<1x384x384xf32> to memref<1x?xf32, #[[MAP1]]> +// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] +// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref + +// ----- + +func @multiple_reducing_dims_dynamic(%arg0 : memref, + %arg1 : index, %arg2 : index, %arg3 : index) -> memref +{ + %c1 = constant 1 : index + %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref to memref + %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref to memref + return %1 : memref +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK: func @multiple_reducing_dims_dynamic +// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : memref to memref<1x?xf32, #[[MAP1]]> +// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] +// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref + +// ----- + +func @multiple_reducing_dims_all_dynamic(%arg0 : memref, + %arg1 : index, %arg2 : index, %arg3 : index) -> memref +{ + %c1 = constant 1 : index + %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] + : memref to memref + %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref to memref + return %1 : memref +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> +// CHECK: func @multiple_reducing_dims_all_dynamic +// CHECK: %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1] +// CHECK-SAME: : memref to memref<1x?xf32, #[[MAP1]]> +// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1] +// CHECK-SAME: : memref<1x?xf32, #[[MAP1]]> to memref + + // ----- // CHECK-LABEL: @clone_before_dealloc @@ -567,4 +624,3 @@ %collapsed = memref.collapse_shape %dynamic [[0], [1, 2, 3]] : memref into memref return %collapsed : memref } - diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-subview-ops.mlir @@ -160,3 +160,66 @@ // CHECK-DAG: %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]] // CHECK-DAG: %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]] // CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]] + +// ----- + +func @fold_vector_transfer_read_with_rank_reduced_subview( + %arg0 : memref, + %arg1: index, %arg2 : index, %arg3 : index, %arg4: index, %arg5 : index, + %arg6 : index) -> vector<4xf32> { + %cst = constant 0.0 : f32 + %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %arg3, %arg4] [1, 1, 1] + : memref to + memref + %1 = vector.transfer_read %0[%arg5, %arg6], %cst {in_bounds = [true]} + : memref, vector<4xf32> + return %1 : vector<4xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = 
affine_map<(d0, d1, d2) -> (d1)> +// CHECK: func @fold_vector_transfer_read_with_rank_reduced_subview +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]](%[[ARG5]])[%[[ARG1]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG6]])[%[[ARG2]]] +// CHECK: vector.transfer_read %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] +// CHECK-SAME: permutation_map = #[[MAP2]] + +// ----- + +func @fold_vector_transfer_write_with_rank_reduced_subview( + %arg0 : memref, + %arg1 : vector<4xf32>, %arg2: index, %arg3 : index, %arg4 : index, + %arg5: index, %arg6 : index, %arg7 : index) { + %cst = constant 0.0 : f32 + %0 = memref.subview %arg0[0, %arg2, %arg3] [1, %arg4, %arg5] [1, 1, 1] + : memref to + memref + vector.transfer_write %arg1, %0[%arg6, %arg7] {in_bounds = [true]} + : vector<4xf32>, memref + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1)> +// CHECK: func @fold_vector_transfer_write_with_rank_reduced_subview +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: vector<4xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP1]](%[[ARG6]])[%[[ARG2]]] +// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP1]](%[[ARG7]])[%[[ARG3]]] +// CHECK-DAG: vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[IDX0]], %[[IDX1]]] +// CHECK-SAME: permutation_map = #[[MAP2]] diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -353,3 +353,12 @@ : memref into memref return %0 : memref } + +// ----- + +func @static_stride_to_dynamic_stride(%arg0 : memref, %arg1 : index, + %arg2 : index) -> memref { + // expected-error @+1 {{expected result type to be 'memref<1x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}} + %0 = memref.subview %arg0[0, 0, 0] [1, %arg1, %arg2] [1, 1, 1] : memref to memref + return %0 : memref +} \ No newline at end of file diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -960,17 +960,6 @@ // ----- -func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = memref.alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. 
(mismatch of result affine map)}} - %1 = memref.subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2] - : memref<8x16x4xf32> to - memref - return -} - -// ----- - func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = memref.alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected result element type to be 'f32'}} @@ -1014,22 +1003,13 @@ // ----- func @invalid_rank_reducing_subview(%arg0 : memref, %arg1 : index, %arg2 : index) { - // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map)}} + // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result sizes)}} %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref to memref return } // ----- -// The affine map affine_map<(d0)[s0, s1, s2] -> (d0 * s1 + s0)> has an extra unused symbol. -func @invalid_rank_reducing_subview(%arg0 : memref, %arg1 : index, %arg2 : index) { - // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result affine map) inferred type: (d0)[s0, s1] -> (d0 * s1 + s0)}} - %0 = memref.subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref to memref (d0 * s1 + s0)>> - return -} - -// ----- - func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} %0 = memref.cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>