diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2823,19 +2823,31 @@ })); } +/// Outcome of the rank-reduction check performed by isRankReducedType. +enum class SubViewVerificationResult { + Success, + RankTooLarge, + SizeMismatch, + StrideMismatch, + ElemTypeMismatch, + MemSpaceMismatch, + AffineMapMismatch +}; + /// Checks if `original` Type type can be rank reduced to `reduced` type. /// This function is slight variant of `is subsequence` algorithm where /// not matching dimension must be 1. -static bool isRankReducedType(Type originalType, Type reducedType) { +static SubViewVerificationResult isRankReducedType(Type originalType, + Type reducedType) { if (originalType == reducedType) - return true; + return SubViewVerificationResult::Success; if (!originalType.isa() && !originalType.isa()) - return true; + return SubViewVerificationResult::Success; if (originalType.isa() && !reducedType.isa()) - return true; + return SubViewVerificationResult::Success; if (originalType.isa() && !reducedType.isa()) - return true; + return SubViewVerificationResult::Success; ShapedType originalShapedType = originalType.cast(); ShapedType reducedShapedType = reducedType.cast(); @@ -2846,7 +2858,7 @@ unsigned originalRank = originalShape.size(), reducedRank = reducedShape.size(); if (reducedRank > originalRank) - return false; + return SubViewVerificationResult::RankTooLarge; unsigned reducedIdx = 0; SmallVector keepMask(originalRank); @@ -2858,41 +2870,82 @@ reducedIdx++; // 1 is the only non-matching allowed. else if (originalShape[originalIdx] != 1) - return false; + return SubViewVerificationResult::SizeMismatch; } // Must match the reduced rank. if (reducedIdx != reducedRank) - return false; + return SubViewVerificationResult::SizeMismatch; // We are done for the tensor case. if (originalType.isa()) - return true; + return SubViewVerificationResult::Success; // Strided layout logic is relevant for MemRefType only. 
MemRefType original = originalType.cast(); MemRefType reduced = reducedType.cast(); MLIRContext *c = original.getContext(); - int64_t originalOffset, symCounter = 0, dimCounter = 0; - SmallVector originalStrides; + int64_t originalOffset, reducedOffset; + SmallVector originalStrides, reducedStrides, keepStrides; getStridesAndOffset(original, originalStrides, originalOffset); - auto getSymbolOrConstant = [&](int64_t offset) { - return offset == ShapedType::kDynamicStrideOrOffset - ? getAffineSymbolExpr(symCounter++, c) - : getAffineConstantExpr(offset, c); - }; - - AffineExpr expr = getSymbolOrConstant(originalOffset); - for (unsigned i = 0, e = originalStrides.size(); i < e; i++) { - if (keepMask[i]) - expr = expr + getSymbolOrConstant(originalStrides[i]) * - getAffineDimExpr(dimCounter++, c); + getStridesAndOffset(reduced, reducedStrides, reducedOffset); + + // Filter strides based on the mask and check that they are the same + // as reduced ones. + reducedIdx = 0; + for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { + if (keepMask[originalIdx]) { + if (originalStrides[originalIdx] != reducedStrides[reducedIdx++]) + return SubViewVerificationResult::StrideMismatch; + keepStrides.push_back(originalStrides[originalIdx]); + } } - auto reducedMap = AffineMap::get(dimCounter, symCounter, expr, c); - return original.getElementType() == reduced.getElementType() && - original.getMemorySpace() == reduced.getMemorySpace() && - (reduced.getAffineMaps().empty() || - reducedMap == reduced.getAffineMaps().front()); + if (original.getElementType() != reduced.getElementType()) + return SubViewVerificationResult::ElemTypeMismatch; + + if (original.getMemorySpace() != reduced.getMemorySpace()) + return SubViewVerificationResult::MemSpaceMismatch; + + auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c); + if (!reduced.getAffineMaps().empty() && + reducedMap != reduced.getAffineMaps().front()) + return 
SubViewVerificationResult::AffineMapMismatch; + + return SubViewVerificationResult::Success; +} + +/// Maps `result` to a LogicalResult: success() for Success, otherwise an +/// error emitted on `op` describing which property of the result mismatched. +template +static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, + OpTy op, Type expectedType) { + // NOTE(review): confirm this cast is valid for every caller's expectedType. + auto memrefType = expectedType.cast(); + switch (result) { + case SubViewVerificationResult::Success: + return success(); + case SubViewVerificationResult::RankTooLarge: + return op.emitError("expected result rank to be smaller or equal to ") + << "the source rank."; + case SubViewVerificationResult::SizeMismatch: + return op.emitError("expected result type to be ") + << expectedType + << " or a rank-reduced version. (mismatch of result sizes)"; + case SubViewVerificationResult::StrideMismatch: + return op.emitError("expected result type to be ") + << expectedType + << " or a rank-reduced version. (mismatch of result strides)"; + case SubViewVerificationResult::ElemTypeMismatch: + return op.emitError("expected result element type to be ") + << memrefType.getElementType(); + case SubViewVerificationResult::MemSpaceMismatch: + return op.emitError("expected result and source memory spaces to match."); + case SubViewVerificationResult::AffineMapMismatch: + return op.emitError("expected result type to be ") + << expectedType + << " or a rank-reduced version. 
(mismatch of result affine map)"; + } + llvm_unreachable("unexpected subview verification result"); } template @@ -2937,11 +2990,9 @@ baseType, extractFromI64ArrayAttr(op.static_offsets()), extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); - if (!isRankReducedType(expectedType, subViewType)) - return op.emitError("expected result type to be ") - << expectedType << " or a rank-reduced version."; - return success(); + auto result = isRankReducedType(expectedType, subViewType); + return produceSubViewErrorMsg(result, op, expectedType); } raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) { @@ -3352,11 +3403,8 @@ op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); - if (!isRankReducedType(expectedType, op.getType())) - return op.emitError("expected result type to be ") - << expectedType << " or a rank-reduced version."; - - return success(); + auto result = isRankReducedType(expectedType, op.getType()); + return produceSubViewErrorMsg(result, op, expectedType); } void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results, diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -21,6 +21,7 @@ // CHECK-DAG: #[[$SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)> // CHECK-DAG: #[[$SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 36 + d1 * 36 + d2 * 4 + d3 * 4 + d4)> // CHECK-DAG: #[[$SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)> +// CHECK-DAG: #[[$SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> // CHECK-LABEL: func @func_with_ops // CHECK-SAME: %[[ARG:.*]]: f32 @@ -811,11 +812,11 @@ %15 = 
alloc(%arg1, %arg2)[%c0, %c1, %arg1, %arg0, %arg0, %arg2, 
%arg2] : memref<1x?x5x1x?x1xf32, affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * d0 + s2 * d1 + s3 * d2 + s4 * d3 + s5 * d4 + s6 * d5)>> // CHECK: subview %15[0, 0, 0, 0, 0, 0] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] : - // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref - %16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref + // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref + %16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref // CHECK: subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] : - // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref - %17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref + // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref + %17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref %18 = alloc() : memref<1x8xf32> // CHECK: subview %18[0, 0] [1, 8] [1, 1] : memref<1x8xf32> to memref<8xf32> diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1011,7 +1011,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>'}} + // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. 
(mismatch of result strides)}} %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2] : memref<8x16x4xf32> to memref @@ -1020,9 +1020,31 @@ // ----- +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected result element type to be 'f32'}} + %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1] + : memref<8x16x4xf32> to + memref<8x16x4xi32> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected result rank to be smaller or equal to the source rank.}} + %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1] + : memref<8x16x4xf32> to + memref<8x16x4x3xi32> + return +} + +// ----- + func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>'}} + // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}} %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> return @@ -1030,6 +1052,14 @@ // ----- +func @invalid_rank_reducing_subview(%arg0 : memref, %arg1 : index, %arg2 : index) { + // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. 
(mismatch of result strides)}} + %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref to memref + return +} + +// ----- + func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]> @@ -1259,7 +1289,7 @@ // ----- func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) { - // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>'}} + // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}} %0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1] : tensor<8x16x4xf32> to tensor @@ -1269,7 +1299,7 @@ // ----- func @subtensor_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) { - // expected-error @+1 {{expected result type to be 'tensor'}} + // expected-error @+1 {{expected result type to be 'tensor' or a rank-reduced version. (mismatch of result sizes)}} %0 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1] : tensor<8x16x4xf32> to tensor<4x4x4xf32>