diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1986,7 +1986,7 @@
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrRanks() {
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned resultRank = getResult().getType().cast<ShapedType>().getRank();
      return {1, resultRank, resultRank};
    }
@@ -2983,7 +2983,7 @@
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrRanks() {
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned rank = getSourceType().getRank();
      return {rank, rank, rank};
    }
@@ -3097,7 +3097,7 @@
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrRanks() {
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned rank = getSourceType().getRank();
      return {rank, rank, rank};
    }
@@ -3184,7 +3184,7 @@
     /// Return the expected rank of each of the`static_offsets`, `static_sizes`
     /// and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrRanks() {
+    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned rank = getType().getRank();
      return {rank, rank, rank};
    }
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -35,7 +35,7 @@
    Common interface for ops that allow specifying mixed dynamic and static
    offsets, sizes and strides variadic operands. Ops that implement this
    interface need to expose the following methods:
-      1. `getArrayAttrRanks` to specify the length of static integer
+      1. `getArrayAttrMaxRanks` to specify the length of static integer
         attributes.
      2. `offsets`, `sizes` and `strides` variadic operands.
      3. `static_offsets`, resp. `static_sizes` and `static_strides` integer
@@ -45,9 +45,9 @@
    The invariants of this interface are:
      1. `static_offsets`, `static_sizes` and `static_strides` have length
-         exactly `getArrayAttrRanks()`[0] (resp. [1], [2]).
+         at most `getArrayAttrMaxRanks()`[0] (resp. [1], [2]).
      2. `offsets`, `sizes` and `strides` have each length at most
-         `getArrayAttrRanks()`[0] (resp. [1], [2]).
+         length `static_offsets` (resp. `static_sizes`, `static_strides`).
      3. if an entry of `static_offsets` (resp. `static_sizes`, `static_strides`)
        is equal to a special sentinel value, namely
        `ShapedType::kDynamicStrideOrOffset` (resp. `ShapedType::kDynamicSize`,
@@ -81,7 +81,7 @@
        and `static_strides` attributes.
      }],
      /*retTy=*/"std::array<unsigned, 3>",
-      /*methodName=*/"getArrayAttrRanks",
+      /*methodName=*/"getArrayAttrMaxRanks",
      /*args=*/(ins)
    >,
    InterfaceMethod<
@@ -166,9 +166,8 @@
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<OpFoldResult, 4> res;
-        std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
        unsigned numDynamic = 0;
-        unsigned count = ranks[getOffsetOperandGroupPosition()];
+        unsigned count = $_op.static_offsets().size();
        for (unsigned idx = 0; idx < count; ++idx) {
          if (isDynamicOffset(idx))
            res.push_back($_op.offsets()[numDynamic++]);
@@ -188,9 +187,8 @@
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<OpFoldResult, 4> res;
-        std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
        unsigned numDynamic = 0;
-        unsigned count = ranks[getSizeOperandGroupPosition()];
+        unsigned count = $_op.static_sizes().size();
        for (unsigned idx = 0; idx < count; ++idx) {
          if (isDynamicSize(idx))
            res.push_back($_op.sizes()[numDynamic++]);
@@ -210,9 +208,8 @@
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<OpFoldResult, 4> res;
-        std::array<unsigned, 3> ranks = $_op.getArrayAttrRanks();
        unsigned numDynamic = 0;
-        unsigned count = ranks[getStrideOperandGroupPosition()];
+        unsigned count = $_op.static_strides().size();
        for (unsigned idx = 0; idx < count; ++idx) {
          if (isDynamicStride(idx))
            res.push_back($_op.strides()[numDynamic++]);
diff --git a/mlir/integration_test/Dialect/Standard/CPU/test_subview.mlir b/mlir/integration_test/Dialect/Standard/CPU/test_subview.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/integration_test/Dialect/Standard/CPU/test_subview.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext | FileCheck %s
+
+global_memref "private" constant @__constant_5x3xf32 : memref<5x3xf32> =
+dense<[[0.0, 1.0, 2.0],
+       [3.0, 4.0, 5.0],
+       [6.0, 7.0, 8.0],
+       [9.0, 10.0, 11.0],
+       [12.0, 13.0, 14.0]]>
+
+func @main() {
+  %0 = get_global_memref @__constant_5x3xf32 : memref<5x3xf32>
+
+  /// Subview with only leading operands.
+  %1 = subview %0[2][3][1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
+  %unranked = memref_cast %1 : memref<3x3xf32, offset: 6, strides: [3, 1]> to memref<*xf32>
+  call @print_memref_f32(%unranked) : (memref<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 2 offset = 6 sizes = [3, 3] strides = [3, 1] data =
+  // CHECK-NEXT: [
+  // CHECK-SAME: [6, 7, 8],
+  // CHECK-NEXT: [9, 10, 11],
+  // CHECK-NEXT: [12, 13, 14]
+  // CHECK-SAME: ]
+
+  /// Regular subview.
+  %2 = subview %0[0, 2][5, 1][1, 1]: memref<5x3xf32> to memref<5x1xf32, offset: 2, strides: [3, 1]>
+  %unranked2 = memref_cast %2 : memref<5x1xf32, offset: 2, strides: [3, 1]> to memref<*xf32>
+  call @print_memref_f32(%unranked2) : (memref<*xf32>) -> ()
+
+  // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}}
+  // CHECK-SAME: rank = 2 offset = 2 sizes = [5, 1] strides = [3, 1] data =
+  // CHECK-NEXT: [
+  // CHECK-SAME: [2],
+  // CHECK-NEXT: [5],
+  // CHECK-NEXT: [8],
+  // CHECK-NEXT: [11],
+  // CHECK-NEXT: [14]
+  // CHECK-SAME: ]
+
+  /// Rank-reducing subview.
+ %3 = subview %0[0, 2][5, 1][1, 1]: memref<5x3xf32> to memref<5xf32, offset: 2, strides: [3]> + %unranked3 = memref_cast %3 : memref<5xf32, offset: 2, strides: [3]> to memref<*xf32> + call @print_memref_f32(%unranked3) : (memref<*xf32>) -> () + + // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} + // CHECK-SAME: rank = 1 offset = 2 sizes = [5] strides = [3] data = + // CHECK-NEXT: [2, 5, 8, 11, 14] + + /// Rank-reducing subview with only leading operands. + %4 = subview %0[1][1][1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]> + %unranked4 = memref_cast %4 : memref<3xf32, offset: 3, strides: [1]> to memref<*xf32> + call @print_memref_f32(%unranked4) : (memref<*xf32>) -> () + // CHECK: Unranked Memref base@ = {{0x[-9a-f]*}} + // CHECK-SAME: rank = 1 offset = 3 sizes = [3] strides = [1] data = + // CHECK-NEXT: [3, 4, 5] + + return +} + +func private @print_memref_f32(%ptr : memref<*xf32>) diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -3336,8 +3336,14 @@ targetMemRef.setConstantOffset(rewriter, loc, offset); } else { Value baseOffset = sourceMemRef.offset(rewriter, loc); - for (unsigned i = 0; i < inferredShapeRank; ++i) { + // `inferredShapeRank` may be larger than the number of offset operands + // because of trailing semantics. In this case, the offset is guaranteed + // to be interpreted as 0 and we can just skip the extra dimensions. + for (unsigned i = 0, e = std::min(inferredShapeRank, + subViewOp.getMixedOffsets().size()); + i < e; ++i) { Value offset = + // TODO: need OpFoldResult ODS adaptor to clean this up. subViewOp.isDynamicOffset(i) ? operands[subViewOp.getIndexOfDynamicOffset(i)] : rewriter.create( @@ -3350,31 +3356,47 @@ } // Update sizes and strides. + SmallVector mixedSizes = subViewOp.getMixedSizes(); + SmallVector mixedStrides = subViewOp.getMixedStrides(); + assert(mixedSizes.size() == mixedStrides.size() && + "expected sizes and strides of equal length"); for (int i = inferredShapeRank - 1, j = resultShapeRank - 1; i >= 0 && j >= 0; --i) { if (!mask[i]) continue; - Value size = - subViewOp.isDynamicSize(i) - ? operands[subViewOp.getIndexOfDynamicSize(i)] - : rewriter.create( - loc, llvmIndexType, - rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); - targetMemRef.setSize(rewriter, loc, j, size); - Value stride; - if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { + // `i` may overflow subViewOp.getMixedSizes because of trailing semantics. + // In this case, the size is guaranteed to be interpreted as Dim and the + // stride as 1. + Value size, stride; + if (static_cast(i) >= mixedSizes.size()) { + size = rewriter.create( + loc, llvmIndexType, + rewriter.create(loc, subViewOp.source(), i)); stride = rewriter.create( - loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); + loc, llvmIndexType, rewriter.getI64IntegerAttr(1)); } else { - stride = - subViewOp.isDynamicStride(i) - ? operands[subViewOp.getIndexOfDynamicStride(i)] + // TODO: need OpFoldResult ODS adaptor to clean this up. + size = + subViewOp.isDynamicSize(i) + ? 
operands[subViewOp.getIndexOfDynamicSize(i)] : rewriter.create( loc, llvmIndexType, - rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i))); - stride = rewriter.create(loc, stride, strideValues[i]); + rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); + if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { + stride = rewriter.create( + loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); + } else { + stride = subViewOp.isDynamicStride(i) + ? operands[subViewOp.getIndexOfDynamicStride(i)] + : rewriter.create( + loc, llvmIndexType, + rewriter.getI64IntegerAttr( + subViewOp.getStaticStride(i))); + stride = rewriter.create(loc, stride, strideValues[i]); + } } + targetMemRef.setSize(rewriter, loc, j, size); targetMemRef.setStride(rewriter, loc, j, stride); j--; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -712,6 +712,10 @@ ShapedType::isDynamic))) return failure(); + if (op.static_sizes().size() != static_cast(resultType.getRank())) + return op->emitError("expected ") + << resultType.getRank() << " sizes values"; + Type expectedType = InitTensorOp::inferResultType(staticSizes, resultType.getElementType()); if (resultType != expectedType) { diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2209,6 +2209,8 @@ return parser.addTypeToList(dstType, result.types); } +// TODO: ponder whether we want to allow missing trailing sizes/strides that are +// completed automatically, like we have for subview and subtensor. static LogicalResult verify(MemRefReinterpretCastOp op) { // The source and result memrefs should be in the same memory space. auto srcType = op.source().getType().cast(); @@ -2833,16 +2835,28 @@ /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. Type SubViewOp::inferResultType(MemRefType sourceMemRefType, - ArrayRef staticOffsets, - ArrayRef staticSizes, - ArrayRef staticStrides) { + ArrayRef leadingStaticOffsets, + ArrayRef leadingStaticSizes, + ArrayRef leadingStaticStrides) { + // A subview may specify only a leading subset of offset/sizes/strides in + // which case we complete with offset=0, sizes from memref type and strides=1. 
  unsigned rank = sourceMemRefType.getRank();
-  (void)rank;
-  assert(staticOffsets.size() == rank &&
-         "unexpected staticOffsets size mismatch");
-  assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
-  assert(staticStrides.size() == rank &&
-         "unexpected staticStrides size mismatch");
+  assert(leadingStaticOffsets.size() <= rank &&
+         "unexpected leadingStaticOffsets overflow");
+  assert(leadingStaticSizes.size() <= rank &&
+         "unexpected leadingStaticSizes overflow");
+  assert(leadingStaticStrides.size() <= rank &&
+         "unexpected leadingStaticStrides overflow");
+  auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
+  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
+  auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
+  unsigned numTrailingOffsets = rank - staticOffsets.size();
+  unsigned numTrailingSizes = rank - staticSizes.size();
+  unsigned numTrailingStrides = rank - staticStrides.size();
+  staticOffsets.append(numTrailingOffsets, 0);
+  llvm::append_range(staticSizes,
+                     sourceMemRefType.getShape().take_back(numTrailingSizes));
+  staticStrides.append(numTrailingStrides, 1);

  // Extract source offset and strides.
  int64_t sourceOffset;
@@ -3197,7 +3211,7 @@
/// with `b` at location `loc`.
SmallVector<Range, 4> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
-  std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
+  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  SmallVector<Range, 4> res;
@@ -3484,16 +3498,18 @@
/// static representation of offsets, sizes and strides. Special sentinels
/// encode the dynamic case.
Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType,
-                                  ArrayRef<int64_t> staticOffsets,
-                                  ArrayRef<int64_t> staticSizes,
-                                  ArrayRef<int64_t> staticStrides) {
+                                  ArrayRef<int64_t> leadingStaticOffsets,
+                                  ArrayRef<int64_t> leadingStaticSizes,
+                                  ArrayRef<int64_t> leadingStaticStrides) {
+  // A subtensor may specify only a leading subset of offset/sizes/strides in
+  // which case we complete with offset=0, sizes from memref type and strides=1.
  unsigned rank = sourceRankedTensorType.getRank();
-  (void)rank;
-  assert(staticOffsets.size() == rank &&
-         "unexpected staticOffsets size mismatch");
-  assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch");
-  assert(staticStrides.size() == rank &&
-         "unexpected staticStrides size mismatch");
+  assert(leadingStaticSizes.size() <= rank &&
+         "unexpected leadingStaticSizes overflow");
+  auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
+  unsigned numTrailingSizes = rank - staticSizes.size();
+  llvm::append_range(staticSizes, sourceRankedTensorType.getShape().take_back(
+                                      numTrailingSizes));
  return RankedTensorType::get(staticSizes,
                               sourceRankedTensorType.getElementType());
}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -18,12 +18,12 @@
#include "mlir/Interfaces/ViewLikeInterface.cpp.inc"

LogicalResult mlir::verifyListOfOperandsOrIntegers(
-    Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
+    Operation *op, StringRef name, unsigned maxNumElements, ArrayAttr attr,
    ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic) {
-  /// Check static and dynamic offsets/sizes/strides breakdown.
-  if (attr.size() != expectedNumElements)
-    return op->emitError("expected ")
-           << expectedNumElements << " " << name << " values";
+  /// Check static and dynamic offsets/sizes/strides does not overflow type.
+  if (attr.size() > maxNumElements)
+    return op->emitError("expected <= ")
+           << maxNumElements << " " << name << " values";
  unsigned expectedNumDynamicEntries =
      llvm::count_if(attr.getValue(), [&](Attribute attr) {
        return isDynamic(attr.cast<IntegerAttr>().getInt());
      });
@@ -35,17 +35,35 @@
}

LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) {
-  std::array<unsigned, 3> ranks = op.getArrayAttrRanks();
+  std::array<unsigned, 3> maxRanks = op.getArrayAttrMaxRanks();
+  // Offsets can come in 2 flavors:
+  //   1. Either single entry (when maxRanks == 1).
+  //   2. Or as an array whose rank must match that of the mixed sizes.
+  // So that the result type is well-formed.
+  if (!(op.getMixedOffsets().size() == 1 && maxRanks[0] == 1) &&
+      op.getMixedOffsets().size() != op.getMixedSizes().size())
+    return op->emitError(
+               "expected mixed offsets rank to match mixed sizes rank (")
+           << op.getMixedOffsets().size() << " vs " << op.getMixedSizes().size()
+           << ") so the rank of the result type is well-formed.";
+  // Ranks of mixed sizes and strides must always match so the result type is
+  // well-formed.
+  if (op.getMixedSizes().size() != op.getMixedStrides().size())
+    return op->emitError(
+               "expected mixed sizes rank to match mixed strides rank (")
+           << op.getMixedSizes().size() << " vs " << op.getMixedStrides().size()
+           << ") so the rank of the result type is well-formed.";
+
  if (failed(verifyListOfOperandsOrIntegers(
-          op, "offset", ranks[0], op.static_offsets(), op.offsets(),
+          op, "offset", maxRanks[0], op.static_offsets(), op.offsets(),
          ShapedType::isDynamicStrideOrOffset)))
    return failure();
-  if (failed(verifyListOfOperandsOrIntegers(op, "size", ranks[1],
+  if (failed(verifyListOfOperandsOrIntegers(op, "size", maxRanks[1],
                                            op.static_sizes(), op.sizes(),
                                            ShapedType::isDynamic)))
    return failure();
  if (failed(verifyListOfOperandsOrIntegers(
-          op, "stride", ranks[2], op.static_strides(), op.strides(),
+          op, "stride", maxRanks[2], op.static_strides(), op.strides(),
          ShapedType::isDynamicStrideOrOffset)))
    return failure();
  return success();
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -919,11 +919,11 @@
  // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
  // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64
  // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64
+  // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64
+  // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] 
: !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -938,11 +938,11 @@ // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : memref<64x4xf32, offset: 0, strides: [4, 1]> @@ -980,11 +980,11 @@ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 + // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -999,11 +999,11 @@ // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 + // 
CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : memref<64x4xf32, offset: 0, strides: [4, 1], 3> @@ -1052,12 +1052,12 @@ // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) - // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG8]], %[[STRIDE1]] : i64 + // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG7]], %[[STRIDE0]] : i64 + // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -1073,12 +1073,12 @@ // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG8]], %[[STRIDE1]] : i32 + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG7]], %[[STRIDE0]] : i32 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = subview %0[%arg0, %arg1][4, 
2][%arg0, %arg1] : memref<64x4xf32, offset: 0, strides: [4, 1]> @@ -1126,11 +1126,11 @@ // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG8]], %[[STRIDE1]] : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG8]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) + // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG8]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG7]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) + // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG7]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -1145,11 +1145,11 @@ // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG8]], %[[STRIDE1]] : i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG8]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG8]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG7]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG7]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] : memref<64x4xf32, offset: 0, strides: [4, 1]> @@ -1174,12 +1174,12 @@ // CHECK32: %[[CST8:.*]] = llvm.mlir.constant(8 : index) // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST3:.*]] = llvm.mlir.constant(3 : i64) - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST1:.*]] = llvm.mlir.constant(1 : i64) + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST62:.*]] = llvm.mlir.constant(62 : i64) 
- // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = subview %0[0, 8][62, 3][1, 1] : memref<64x4xf32, offset: 0, strides: [4, 1]> @@ -1219,12 +1219,12 @@ // CHECK32: %[[OFFA2:.*]] = llvm.add %[[OFFA1]], %[[OFFM2]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFFA2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST1:.*]] = llvm.mlir.constant(1 : i64) : i32 + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST62:.*]] = llvm.mlir.constant(62 : i64) : i32 - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = subview %0[%arg1, 8][62, %arg2][%arg0, 1] : memref<64x4xf32, offset: 0, strides: [4, 1]> @@ -1232,6 +1232,86 @@ return } +// CHECK-LABEL: func @subview_leading_operands( +func @subview_leading_operands(%0 : memref<5x3xf32>, %1: memref<5x?xf32>) { + // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Alloc ptr + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Aligned ptr + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Offset + // CHECK: %[[C6:.*]] = llvm.mlir.constant(6 : index) : i64 + // CHECK: llvm.insertvalue %[[C6:.*]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Sizes and strides @rank 1: both static. + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: llvm.insertvalue %[[C3]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %[[C1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Sizes and strides @rank 0: both static extracted from type. 
+ // CHECK: %[[C3_2:.*]] = llvm.mlir.constant(3 : i64) : i64 + // CHECK: %[[C3_3:.*]] = llvm.mlir.constant(3 : i64) : i64 + // CHECK: llvm.insertvalue %[[C3_2]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %[[C3_3]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %2 = subview %0[2][3][1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]> + + return +} + +// CHECK-LABEL: func @subview_leading_operands_dynamic( +func @subview_leading_operands_dynamic(%0 : memref<5x?xf32>) { + // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Alloc ptr + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Aligned ptr + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Extract strides + // CHECK: %[[ST0:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[ST1:.*]] = llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Compute and insert offset from 2 + dynamic value. + // CHECK: %[[OFF:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64 + // CHECK: %[[MUL:.*]] = llvm.mul %[[C2]], %[[ST0]] : i64 + // CHECK: %[[NEW_OFF:.*]] = llvm.add %[[OFF]], %[[MUL]] : i64 + // CHECK: llvm.insertvalue %[[NEW_OFF]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Sizes and strides @rank 1: static stride 1, dynamic size unchanged from source memref. + // CHECK: %[[SZ1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: llvm.insertvalue %[[SZ1]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %[[C1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Sizes and strides @rank 0: both static. 
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 + // CHECK: %[[C1_2:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: %[[MUL:.*]] = llvm.mul %[[C1_2]], %[[ST0]] : i64 + // CHECK: llvm.insertvalue %[[C3]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: llvm.insertvalue %[[MUL]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + %1 = subview %0[2][3][1]: memref<5x?xf32> to memref<3x?xf32, offset: ?, strides: [?, 1]> + return +} + +// CHECK-LABEL: func @subview_rank_reducing_leading_operands( +func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) { + // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Alloc ptr + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Aligned ptr + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Extract strides + // CHECK: %[[ST0:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[ST1:.*]] = llvm.extractvalue %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // Offset + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 + // CHECK: llvm.insertvalue %[[C3:.*]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // Sizes and strides @rank 0: both static. + // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : index) : i64 + // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 + // CHECK: llvm.insertvalue %[[C3]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK: llvm.insertvalue %[[C1]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %1 = subview %0[1][1][1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]> + + return +} + + // ----- // CHECK-LABEL: func @atomic_rmw @@ -1342,7 +1422,7 @@ return %dim : index } // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(i64, ptr)> -// CHECK-NEXT: llvm.insertvalue +// CHECK-NEXT: llvm.insertvalue // CHECK-NEXT: %[[UNRANKED_DESC:.*]] = llvm.insertvalue // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -38,7 +38,7 @@ // ----- func @memref_reinterpret_cast_too_many_offsets(%in: memref) { - // expected-error @+1 {{expected 1 offset values}} + // expected-error @+1 {{expected <= 1 offset values}} %out = memref_reinterpret_cast %in to offset: [0, 0], sizes: [10, 10], strides: [10, 1] : memref to memref<10x10xf32, offset: 0, strides: [10, 1]> diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -22,6 +22,8 @@ // CHECK-DAG: #[[$SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0 * 36 + d1 * 36 + d2 * 4 + d3 * 4 + d4)> // CHECK-DAG: #[[$SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)> // CHECK-DAG: #[[$SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)> +// CHECK-DAG: #[[$SUBVIEW_MAP9:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 3 + d1 + 6)> +// CHECK-DAG: #[[$SUBVIEW_MAP10:map[0-9]+]] = 
affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func @func_with_ops // CHECK-SAME: %[[ARG:.*]]: f32 @@ -791,6 +793,16 @@ %23 = alloc() : memref %78 = subview %23[] [] [] : memref to memref + + /// Subview with only leading operands. + %24 = alloc() : memref<5x3xf32> + // CHECK: subview %{{.*}}[2] [3] [1] : memref<5x3xf32> to memref<3x3xf32, #[[$SUBVIEW_MAP9]]> + %25 = subview %24[2][3][1]: memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]> + + /// Rank-reducing subview with only leading operands. + // CHECK: subview %{{.*}}[1] [1] [1] : memref<5x3xf32> to memref<3xf32, #[[$SUBVIEW_MAP10]]> + %26 = subview %24[1][1][1]: memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]> + return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -905,6 +905,36 @@ // ----- +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank (2 vs 3) so the rank of the result type is well-formed}} + %1 = subview %0[0, 0][2, 2, 2][1, 1, 1] + : memref<8x16x4xf32> to memref<8x16x4xf32> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected mixed sizes rank to match mixed strides rank (3 vs 2) so the rank of the result type is well-formed}} + %1 = subview %0[0, 0, 0][2, 2, 2][1, 1] + : memref<8x16x4xf32> to memref<8x16x4xf32> + return +} + +// ----- + +func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { + %0 = alloc() : memref<8x16x4xf32> + // expected-error@+1 {{expected mixed sizes rank to match mixed strides rank (3 vs 2) so the rank of the result type is well-formed}} + %1 = memref_reinterpret_cast %0 to offset: [0], sizes: [2, 2, 2], strides:[1, 1] + : memref<8x16x4xf32> to memref<8x16x4xf32> + return +} + +// ----- + func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> // expected-error@+1 {{different memory spaces}} @@ -929,8 +959,8 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected 3 offset values}} - %1 = subview %0[%arg0, %arg1][%arg2][1, 1, 1] + // expected-error@+1 {{expected <= 3 offset values}} + %1 = subview %0[%arg0, %arg1, 0, 0][%arg2, 0, 0, 0][1, 1, 1, 1] : memref<8x16x4xf32> to memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]> return
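
For reference, the semantics this patch relaxes: subview (and subtensor) may now specify offsets, sizes and strides for only a leading prefix of the source rank, and each omitted trailing entry behaves as offset 0, the size of the corresponding source dimension, and stride 1. A minimal sketch, adapted from the new test cases in this patch (the function name is illustrative):

    func @leading_operand_subviews(%src : memref<5x3xf32>) {
      // Only the outer dimension is specified; the trailing dimension keeps
      // offset 0, size 3 (from the source type) and stride 1, so this is
      // equivalent to subview %src[2, 0][3, 3][1, 1].
      %a = subview %src[2][3][1]
          : memref<5x3xf32> to memref<3x3xf32, offset: 6, strides: [3, 1]>
      // Rank-reducing form: the unit-size leading dimension is dropped and
      // only the completed trailing dimension remains in the result type.
      %b = subview %src[1][1][1]
          : memref<5x3xf32> to memref<3xf32, offset: 3, strides: [1]>
      return
    }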