diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2841,6 +2841,20 @@
       "ArrayRef<NamedAttribute> attrs = {}">,
     // Build a SubViewOp with all dynamic entries.
     OpBuilder<
+      "OpBuilder &b, OperationState &result, Value source, "
+      "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a SubViewOp with mixed static and dynamic entries
+    // and custom result type.
+    OpBuilder<
+      "OpBuilder &b, OperationState &result, MemRefType resultType, "
+      "Value source, ArrayRef<int64_t> staticOffsets, "
+      "ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides, "
+      "ValueRange offsets, ValueRange sizes, "
+      "ValueRange strides, ArrayRef<NamedAttribute> attrs = {}">,
+    // Build a SubViewOp with all dynamic entries and custom result type.
+    OpBuilder<
+      "OpBuilder &b, OperationState &result, MemRefType resultType, "
       "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, "
       "ArrayRef<NamedAttribute> attrs = {}">
   ];
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2728,15 +2728,47 @@
         staticStridesVector, offsets, sizes, strides, attrs);
 }
 
+/// Build a SubViewOp as above but with custom result type.
+void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
+                            MemRefType resultType, Value source,
+                            ArrayRef<int64_t> staticOffsets,
+                            ArrayRef<int64_t> staticSizes,
+                            ArrayRef<int64_t> staticStrides, ValueRange offsets,
+                            ValueRange sizes, ValueRange strides,
+                            ArrayRef<NamedAttribute> attrs) {
+  build(b, result, resultType, source, offsets, sizes, strides,
+        b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes),
+        b.getI64ArrayAttr(staticStrides));
+  result.addAttributes(attrs);
+}
+
+/// Build a SubViewOp as above but with custom result type.
+void mlir::SubViewOp::build(OpBuilder &b, OperationState &result,
+                            MemRefType resultType, Value source,
+                            ValueRange offsets, ValueRange sizes,
+                            ValueRange strides,
+                            ArrayRef<NamedAttribute> attrs) {
+  auto sourceMemRefType = source.getType().cast<MemRefType>();
+  unsigned rank = sourceMemRefType.getRank();
+  SmallVector<int64_t, 4> staticOffsetsVector;
+  staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+  SmallVector<int64_t, 4> staticSizesVector;
+  staticSizesVector.assign(rank, ShapedType::kDynamicSize);
+  SmallVector<int64_t, 4> staticStridesVector;
+  staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset);
+  build(b, result, resultType, source, staticOffsetsVector, staticSizesVector,
+        staticStridesVector, offsets, sizes, strides, attrs);
+}
+
 /// Verify that a particular offset/size/stride static attribute is well-formed.
 static LogicalResult
 verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName,
                     ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic,
                     ValueRange values) {
   /// Check static and dynamic offsets/sizes/strides breakdown.
-  if (attr.size() != op.getRank())
-    return op.emitError("expected ")
-           << op.getRank() << " " << name << " values";
+  size_t inputRank = op.source().getType().cast<MemRefType>().getRank();
+  if (attr.size() != inputRank)
+    return op.emitError("expected ") << inputRank << " " << name << " values";
   unsigned expectedNumDynamicEntries =
       llvm::count_if(attr.getValue(), [&](Attribute attr) {
         return isDynamic(attr.cast<IntegerAttr>().getInt());
@@ -2755,6 +2787,62 @@
   }));
 }
 
+/// Checks if `original` MemRef type can be rank reduced to `reduced` type.
+/// This function is a slight variant of the `is subsequence` algorithm where
+/// a non-matching dimension must have size 1.
+static bool isRankReducedType(Type originalType, Type reducedType) {
+  if (originalType == reducedType)
+    return true;
+
+  MemRefType original = originalType.cast<MemRefType>();
+  MemRefType reduced = reducedType.cast<MemRefType>();
+  ArrayRef<int64_t> originalShape = original.getShape();
+  ArrayRef<int64_t> reducedShape = reduced.getShape();
+  unsigned originalRank = originalShape.size(),
+           reducedRank = reducedShape.size();
+  if (reducedRank > originalRank)
+    return false;
+
+  unsigned reducedIdx = 0;
+  SmallVector<bool, 4> keepMask(originalRank);
+  for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
+    // -2 is never used as a dim size so it will never match.
+    int reducedVal = reducedIdx < reducedRank ? reducedShape[reducedIdx] : -2;
+    // Skip matching dims greedily.
+    if ((keepMask[originalIdx] = originalShape[originalIdx] == reducedVal))
+      reducedIdx++;
+    // 1 is the only non-matching size allowed.
+    else if (originalShape[originalIdx] != 1)
+      return false;
+  }
+  // Must match the reduced rank.
+  if (reducedIdx != reducedRank)
+    return false;
+
+  MLIRContext *c = original.getContext();
+  int64_t originalOffset, symCounter = 0, dimCounter = 0;
+  SmallVector<int64_t, 4> originalStrides;
+  getStridesAndOffset(original, originalStrides, originalOffset);
+  auto getSymbolOrConstant = [&](int64_t offset) {
+    return offset == ShapedType::kDynamicStrideOrOffset
+               ? getAffineSymbolExpr(symCounter++, c)
+               : getAffineConstantExpr(offset, c);
+  };
+
+  AffineExpr expr = getSymbolOrConstant(originalOffset);
+  for (unsigned i = 0, e = originalStrides.size(); i < e; i++) {
+    if (keepMask[i])
+      expr = expr + getSymbolOrConstant(originalStrides[i]) *
+                        getAffineDimExpr(dimCounter++, c);
+  }
+
+  auto reducedMap = AffineMap::get(dimCounter, symCounter, expr, c);
+  return original.getElementType() == reduced.getElementType() &&
+         original.getMemorySpace() == reduced.getMemorySpace() &&
+         (reduced.getAffineMaps().empty() ||
+          reducedMap == reduced.getAffineMaps().front());
+}
+
 /// Verifier for SubViewOp.
 static LogicalResult verify(SubViewOp op) {
   auto baseType = op.getBaseMemRefType().cast<MemRefType>();
@@ -2790,8 +2878,9 @@
       op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
       extractFromI64ArrayAttr(op.static_sizes()),
       extractFromI64ArrayAttr(op.static_strides()));
-  if (op.getType() != expectedType)
-    return op.emitError("expected result type to be ") << expectedType;
+  if (!isRankReducedType(expectedType, subViewType))
+    return op.emitError("expected result type to be ")
+           << expectedType << " or a rank-reduced version.";
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2107,9 +2107,6 @@
   // TODO: expand support to these 2 cases.
   if (!xferOp.permutation_map().isMinorIdentity())
     return failure();
-  // TODO: relax this precondition. This will require rank-reducing subviews.
-  if (xferOp.getMemRefType().getRank() != xferOp.getTransferRank())
-    return failure();
   // Must have some masked dimension to be a candidate for splitting.
   if (!xferOp.hasMaskedDim())
     return failure();
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -19,6 +19,8 @@
 // CHECK-DAG: #[[$SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>
 // CHECK-DAG: #[[$SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)>
+// CHECK-DAG: #[[$SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 36 + d1 * 36 + d2 * 4 + d3 * 4 + d4)>
+// CHECK-DAG: #[[$SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)>
 
 // CHECK-LABEL: func @func_with_ops
 // CHECK-SAME: %[[ARG:.*]]: f32
@@ -797,6 +799,33 @@
   %11 = subview %9[%arg1, %arg2][4, 4][2, 2]
     : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[8, 2]>
 
+  %12 = alloc() : memref<1x9x1x4x1xf32, affine_map<(d0, d1, d2, d3, d4) -> (36 * d0 + 36 * d1 + 4 * d2 + 4 * d3 + d4)>>
+  // CHECK: subview %12[%arg1, %arg1, %arg1, %arg1, %arg1]
+  // CHECK-SAME: [1, 9, 1, 4, 1] [%arg2, %arg2, %arg2, %arg2, %arg2] :
+  // CHECK-SAME: memref<1x9x1x4x1xf32, #[[$SUBVIEW_MAP6]]> to memref<9x4xf32, #[[$SUBVIEW_MAP2]]>
+  %13 = subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, offset: 0, strides: [36, 36, 4, 4, 1]> to memref<9x4xf32, offset: ?, strides: [?, ?]>
+  // CHECK: subview %12[%arg1, %arg1, %arg1, %arg1, %arg1]
+  // CHECK-SAME: [1, 9, 1, 4, 1] [%arg2, %arg2, %arg2, %arg2, %arg2] :
+  // CHECK-SAME: memref<1x9x1x4x1xf32, #[[$SUBVIEW_MAP6]]> to memref<1x9x4xf32, #[[$BASE_MAP3]]>
+  %14 = subview %12[%arg1, %arg1, %arg1, %arg1, %arg1][1, 9, 1, 4, 1][%arg2, %arg2, %arg2, %arg2, %arg2] : memref<1x9x1x4x1xf32, offset: 0, strides: [36, 36, 4, 4, 1]> to memref<1x9x4xf32, offset: ?, strides: [?, ?, ?]>
+
+  %15 = alloc(%arg1, %arg2)[%c0, %c1, %arg1, %arg0, %arg0, %arg2, %arg2] : memref<1x?x5x1x?x1xf32, affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * d0 + s2 * d1 + s3 * d2 + s4 * d3 + s5 * d4 + s6 * d5)>>
+  // CHECK: subview %15[0, 0, 0, 0, 0, 0] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] :
+  // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?xf32, #[[$BASE_MAP3]]>
+  %16 = subview %15[0, 0, 0, 0, 0, 0][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?xf32, offset: ?, strides: [?, ?, ?]>
+  // CHECK: subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1] [1, %arg1, 5, 1, %arg2, 1] [1, 1, 1, 1, 1, 1] :
+  // CHECK-SAME: memref<1x?x5x1x?x1xf32, #[[$SUBVIEW_MAP7]]> to memref<?x5x?xf32, #[[$BASE_MAP3]]>
+  %17 = subview %15[%arg1, %arg1, %arg1, %arg1, %arg1, %arg1][1, %arg1, 5, 1, %arg2, 1][1, 1, 1, 1, 1, 1] : memref<1x?x5x1x?x1xf32, offset: ?, strides: [?, ?, ?, ?, ?, ?]> to memref<?x5x?xf32, offset: ?, strides: [?, ?, ?]>
+
+  %18 = alloc() : memref<1x8xf32>
+  // CHECK: subview %18[0, 0] [1, 8] [1, 1] : memref<1x8xf32> to memref<8xf32>
+  %19 = subview %18[0, 0][1, 8][1, 1] : memref<1x8xf32> to memref<8xf32>
+
+  %20 = alloc() : memref<8x16x4xf32>
+  // CHECK: subview %20[0, 0, 0] [1, 16, 4] [1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32>
+  %21 = subview %20[0, 0, 0][1, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32>
+
+  %22 = subview %20[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]>
   return
 }
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1020,6 +1020,16 @@
 
 // -----
 
+func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<8x16x4xf32>
+  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>'}}
+  %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1]
+    : memref<8x16x4xf32> to memref<16x4xf32>
+  return
+}
+
+// -----
+
 func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
   // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}}
   %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>
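
Illustrative note (not part of the patch): a rank-reduced result type accepted by the new isRankReducedType check is obtained by dropping unit dimensions from the full-rank type that the verifier infers, keeping the offset and the strides of the surviving dimensions. The minimal sketch below mirrors the %22 test case added above; the function name is made up, and the layout is worked out by hand from source strides [64, 4, 1] and offset 3*64 + 4*4 + 2 = 210.

func @rank_reducing_subview_sketch() {
  %src = alloc() : memref<8x16x4xf32>
  // Full-rank result: exactly the type the verifier infers from
  // offsets [3, 4, 2], sizes [1, 6, 3], strides [1, 1, 1].
  %full = subview %src[3, 4, 2][1, 6, 3][1, 1, 1]
    : memref<8x16x4xf32> to memref<1x6x3xf32, offset: 210, strides: [64, 4, 1]>
  // Rank-reduced result: the unit dimension and its stride (64) are dropped;
  // the remaining shape is a subsequence of the full-rank shape.
  %reduced = subview %src[3, 4, 2][1, 6, 3][1, 1, 1]
    : memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]>
  return
}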