diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -466,11 +466,6 @@
 /// Return null if the layout is not compatible with a strided layout.
 AffineMap getStridedLinearLayoutMap(MemRefType t);
 
-/// Helper determining if a memref is static-shape and contiguous-row-major
-/// layout, while still allowing for an arbitrary offset (any static or
-/// dynamic value).
-bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
-
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -961,7 +961,25 @@
     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
 
-    auto isContiguousMemrefType = [](BaseMemRefType type) {
+    auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
+      if (!type.hasStaticShape())
+        return false;
+
+      SmallVector<int64_t> strides;
+      int64_t offset;
+      if (failed(getStridesAndOffset(type, strides, offset)))
+        return false;
+
+      int64_t runningStride = 1;
+      for (unsigned i = strides.size(); i > 0; --i) {
+        if (strides[i - 1] != runningStride)
+          return false;
+        runningStride *= type.getDimSize(i - 1);
+      }
+      return true;
+    };
+
+    auto isContiguousMemrefType = [&](BaseMemRefType type) {
       auto memrefType = type.dyn_cast<MemRefType>();
       // We can use memcpy for memrefs if they have an identity layout or are
       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1761,7 +1761,7 @@
 
 /// Compute the layout map after expanding a given source MemRef type with the
 /// specified reassociation indices.
-static FailureOr<AffineMap>
+static FailureOr<StridedLayoutAttr>
 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                          ArrayRef<ReassociationIndices> reassociation) {
   int64_t srcOffset;
@@ -1798,8 +1798,7 @@
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
   resultStrides.resize(resultShape.size(), 1);
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
 static FailureOr<MemRefType>
@@ -1814,14 +1813,12 @@
   }
 
   // Source may not be contiguous. Compute the layout map.
-  FailureOr<AffineMap> computedLayout =
+  FailureOr<StridedLayoutAttr> computedLayout =
       computeExpandedLayoutMap(srcType, resultShape, reassociation);
   if (failed(computedLayout))
     return failure();
-  auto computedType =
-      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
-                      srcType.getMemorySpaceAsInt());
-  return canonicalizeStridedLayout(computedType);
+  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+                         srcType.getMemorySpace());
 }
 
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
@@ -1855,10 +1852,9 @@
     return emitOpError("invalid source layout map");
 
   // Check actual result type.
-  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
-  if (*expectedResultType != canonicalizedResultType)
+  if (*expectedResultType != resultType)
     return emitOpError("expected expanded type to be ")
-           << *expectedResultType << " but found " << canonicalizedResultType;
+           << *expectedResultType << " but found " << resultType;
 
   return success();
 }
 
@@ -1877,7 +1873,7 @@
 /// not possible to check this by inspecting a MemRefType in the general case.
 /// If non-contiguity cannot be checked statically, the collapse is assumed to
 /// be valid (and thus accepted by this function) unless `strict = true`.
-static FailureOr<AffineMap>
+static FailureOr<StridedLayoutAttr>
 computeCollapsedLayoutMap(MemRefType srcType,
                           ArrayRef<ReassociationIndices> reassociation,
                           bool strict = false) {
@@ -1940,13 +1936,12 @@
       return failure();
     }
   }
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
 bool CollapseShapeOp::isGuaranteedCollapsible(
     MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
-  // MemRefs with standard layout are always collapsible.
+  // MemRefs with identity layout are always collapsible.
   if (srcType.getLayout().isIdentity())
     return true;
 
@@ -1978,14 +1973,12 @@
   // Source may not be fully contiguous. Compute the layout map.
   // Note: Dimensions that are collapsed into a single dim are assumed to be
   // contiguous.
-  FailureOr<AffineMap> computedLayout =
+  FailureOr<StridedLayoutAttr> computedLayout =
       computeCollapsedLayoutMap(srcType, reassociation);
   assert(succeeded(computedLayout) &&
          "invalid source layout map or collapsing non-contiguous dims");
-  auto computedType =
-      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
-                      srcType.getMemorySpaceAsInt());
-  return canonicalizeStridedLayout(computedType);
+  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+                         srcType.getMemorySpace());
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -2021,21 +2014,19 @@
     // Source may not be fully contiguous. Compute the layout map.
    // Note: Dimensions that are collapsed into a single dim are assumed to be
    // contiguous.
-    FailureOr<AffineMap> computedLayout =
+    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
     if (failed(computedLayout))
       return emitOpError(
           "invalid source layout map or collapsing non-contiguous dims");
-    auto computedType =
+    expectedResultType =
         MemRefType::get(resultType.getShape(), srcType.getElementType(),
-                        *computedLayout, srcType.getMemorySpaceAsInt());
-    expectedResultType = canonicalizeStridedLayout(computedType);
+                        *computedLayout, srcType.getMemorySpace());
   }
 
-  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
-  if (expectedResultType != canonicalizedResultType)
+  if (expectedResultType != resultType)
     return emitOpError("expected collapsed type to be ")
-           << expectedResultType << " but found " << canonicalizedResultType;
+           << expectedResultType << " but found " << resultType;
 
   return success();
 }
 
@@ -2709,24 +2700,26 @@
                                            AffineMap permutationMap) {
   auto rank = memRefType.getRank();
   auto originalSizes = memRefType.getShape();
-  // Compute permuted sizes.
-  SmallVector<int64_t> sizes(rank, 0);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
-    sizes[en.index()] =
-        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
-
-  // Compute permuted strides.
   int64_t offset;
-  SmallVector<int64_t> strides;
-  auto res = getStridesAndOffset(memRefType, strides, offset);
-  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
+  SmallVector<int64_t> originalStrides;
+  auto res = getStridesAndOffset(memRefType, originalStrides, offset);
+  assert(succeeded(res) &&
+         originalStrides.size() == static_cast<unsigned>(rank));
   (void)res;
-  auto map =
-      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
-  map = permutationMap ? map.compose(permutationMap) : map;
+
+  // Compute permuted sizes and strides.
+  SmallVector<int64_t> sizes(rank, 0);
+  SmallVector<int64_t> strides(rank, 1);
+  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
+    unsigned position = en.value().cast<AffineDimExpr>().getPosition();
+    sizes[en.index()] = originalSizes[position];
+    strides[en.index()] = originalStrides[position];
+  }
+
   return MemRefType::Builder(memRefType)
       .setShape(sizes)
-      .setLayout(AffineMapAttr::get(map));
+      .setLayout(
+          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -136,11 +136,10 @@
       int64_t offset;
       if (failed(getStridesAndOffset(bufferType, strides, offset)))
        return failure();
-      AffineMap resultLayout =
-          makeStridedLinearLayoutMap({}, offset, op->getContext());
-      resultType =
-          MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
-                          bufferType.getMemorySpaceAsInt());
+      resultType = MemRefType::get(
+          {}, tensorResultType.getElementType(),
+          StridedLayoutAttr::get(op->getContext(), offset, {}),
+          bufferType.getMemorySpace());
     }
 
     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2250,7 +2250,8 @@
           os << 'x';
         }
         printType(memrefTy.getElementType());
-        if (!memrefTy.getLayout().isIdentity()) {
+        MemRefLayoutAttrInterface layout = memrefTy.getLayout();
+        if (!layout.isa<AffineMapAttr>() || !layout.isIdentity()) {
           os << ", ";
           printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
         }
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -1027,40 +1027,3 @@
     return AffineMap();
   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
 }
-
-/// Return the AffineExpr representation of the offset, assuming `memRefType`
-/// is a strided memref.
-static AffineExpr getOffsetExpr(MemRefType memrefType) {
-  SmallVector<AffineExpr> strides;
-  AffineExpr offset;
-  if (failed(getStridesAndOffset(memrefType, strides, offset)))
-    assert(false && "expected strided memref");
-  return offset;
-}
-
-/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
-/// `offset` AffineExpr.
-static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
-                                                   ArrayRef<int64_t> shape,
-                                                   Type elementType,
-                                                   AffineExpr offset) {
-  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
-  AffineExpr contiguousRowMajor = canonical + offset;
-  AffineMap contiguousRowMajorMap =
-      AffineMap::inferFromExprList({contiguousRowMajor})[0];
-  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
-}
-
-/// Helper determining if a memref is static-shape and contiguous-row-major
-/// layout, while still allowing for an arbitrary offset (any static or
-/// dynamic value).
-bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
-  if (!memrefType.hasStaticShape())
-    return false;
-  AffineExpr offset = getOffsetExpr(memrefType);
-  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
-      memrefType.getContext(), memrefType.getShape(),
-      memrefType.getElementType(), offset);
-  return canonicalizeStridedLayout(memrefType) ==
-         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
-}
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -609,7 +609,7 @@
 // CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 func.func @transpose(%arg0: memref>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref> to memref (d2 * s1 + s0 + d0 * s2 + d1)>>
+  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref> to memref>
   return
 }
 
@@ -725,12 +725,12 @@
 // -----
 
 func.func @collapse_shape_dynamic_with_non_identity_layout(
-    %arg0 : memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>>) ->
-    memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> {
+    %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
+    memref<4x?xf32, strided<[?, ?], offset: ?>> {
   %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
-      memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>> into
-      memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
-  return %0 : memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
+      memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
+      memref<4x?xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
 }
 // CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
 // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
@@ -898,12 +898,12 @@
 // -----
 
 func.func @expand_shape_dynamic_with_non_identity_layout(
-    %arg0 : memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) ->
-    memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> {
+    %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
+    memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
   %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
-      memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
-      memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
-  return %0 : memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
+      memref<1x?xf32, strided<[?, ?], offset: ?>> into
+      memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+  return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 }
 // CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout(
 // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
@@ -982,10 +982,10 @@
 // -----
 
 // CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
-func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>>) -> memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> {
+func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
   // CHECK-NOT: memref.collapse_shape
-  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
-  return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
+  return %1 : memref<64xf32, strided<[1], offset: ?>>
 }
 
 // -----
@@ -1069,13 +1069,11 @@
 // -----
 
 // CHECK-LABEL: func @memref_copy_0d_offset
-#map0 = affine_map<(d0) -> (d0 + 1)>
-#map1 = affine_map<() -> (1)>
 func.func @memref_copy_0d_offset(%in: memref<2xi32>) {
   %buf = memref.alloc() : memref<i32>
-  %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, #map0>
-  %scalar = memref.collapse_shape %sub [] : memref<1xi32, #map0> into memref<i32, #map1>
-  memref.copy %scalar, %buf : memref<i32, #map1> to memref<i32>
+  %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
+  %scalar = memref.collapse_shape %sub [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, strided<[], offset: 1>>
+  memref.copy %scalar, %buf : memref<i32, strided<[], offset: 1>> to memref<i32>
   // CHECK: llvm.intr.memcpy
   return
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir
@@ -23,8 +23,8 @@
   %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>
 
   // CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
-  // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
-  // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32>
+  // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32, strided<[1]>>
+  // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32, strided<[1]>>
   %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>
 
   // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -6,8 +6,6 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
-
 func.func @views(%arg0: index) {
   %c0 = arith.constant 0 : index
   %0 = arith.muli %arg0, %arg0 : index
@@ -70,12 +68,12 @@
 // -----
 
 func.func @transpose(%arg0: memref>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref> to memref (d2 * s1 + s0 + d1 * s2 + d0)>>
+  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref> to memref>
   return
 }
 // CHECK-LABEL: func @transpose
 // CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
-// CHECK-SAME: memref> to memref
+// CHECK-SAME: memref> to memref>
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -424,7 +424,7 @@
 func.func @expand_shape_invalid_result_layout(
     %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
-  // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}}
+  // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>' but found 'memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>'}}
   %0 = memref.expand_shape %arg0 [[0, 1], [2]] :
       memref<30x20xf32, strided<[4000, 2], offset: 100>> into
      memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -104,10 +104,10 @@
     %arg1: tensor<3x4x5xf32>,
     %arg2: tensor<3x?x5xf32>,
     %arg3: memref<30x20xf32, strided<[4000, 2], offset: 100>>,
-    %arg4: memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>>,
+    %arg4: memref<1x5xf32, strided<[5, 1], offset: ?>>,
     %arg5: memref,
     %arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>,
-    %arg7: memref<1x2049xi64, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) {
+    %arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>) {
   // Reshapes that collapse and expand back a contiguous buffer.
 // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 // CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
@@ -157,8 +157,8 @@
 // CHECK: memref.expand_shape {{.*}} {{\[}}[0], [1, 2]]
   %r4 = memref.expand_shape %arg4 [[0], [1, 2]] :
-      memref<1x5xf32, affine_map<(d0, d1)[s0] -> (d0 * 5 + s0 + d1)>> into
-      memref<1x1x5xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 5 + s0 + d2 + d1 * 5)>>
+      memref<1x5xf32, strided<[5, 1], offset: ?>> into
+      memref<1x1x5xf32, strided<[5, 5, 1], offset: ?>>
 
   // Note: Only the collapsed two shapes are contiguous in the follow test case.
 // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
@@ -168,8 +168,8 @@
 // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]]
   %r7 = memref.collapse_shape %arg7 [[0, 1]] :
-      memref<1x2049xi64, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
-      memref<2049xi64, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>
+      memref<1x2049xi64, strided<[?, ?], offset: ?>> into
+      memref<2049xi64, strided<[?], offset: ?>>
 
   // Reshapes that expand and collapse back a contiguous buffer with some 1's.
 // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
@@ -241,15 +241,15 @@
       memref>
 
 // CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1]]
-// CHECK-SAME: memref> into memref
+// CHECK-SAME: memref> into memref>
   %3 = memref.collapse_shape %arg3 [[0, 1]] :
       memref> into
-      memref
+      memref>
 
 // CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1]]
-// CHECK-SAME: memref into memref
+// CHECK-SAME: memref> into memref
   %r3 = memref.expand_shape %3 [[0, 1]] :
-      memref into memref
+      memref> into memref
   return
 }
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -372,8 +372,6 @@
 // -----
 
-// CHECK-DAG: #[[$MAP2b:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
-
 // CHECK-LABEL: func @tensor.expand_shape_of_slice(
 // CHECK-SAME: %[[t1:.*]]: tensor
 func.func @tensor.expand_shape_of_slice(
@@ -383,7 +381,7 @@
   %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] : tensor to tensor
   // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
-  // CHECK-SAME: [0, 1], [2, 3]] : memref> into memref
+  // CHECK-SAME: [0, 1], [2, 3]] : memref> into memref>
   %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] : tensor into tensor
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
@@ -393,8 +391,6 @@
 // -----
 
-// CHECK-DAG: #[[$MAP10:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-
 // CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
 // CHECK-SAME: %[[t1:.*]]: tensor
 func.func @tensor.expand_shape_of_scalar_slice(
@@ -402,7 +398,7 @@
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref
   // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref to memref>
   %0 = tensor.extract_slice %t1[%o1][1][1] : tensor to tensor
-  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref into memref<1xf32, #[[$MAP10]]>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref into memref<1xf32, strided<[1], offset: ?>>
   %1 = tensor.expand_shape %0 [] : tensor into tensor<1xf32>
   // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
   // CHECK: return %[[r]]
@@ -442,13 +438,11 @@
 // -----
 
-// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>
-
 // CHECK-LABEL: func @tensor.collapse_shape_of_slice(
 func.func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor {
   // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
   %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
-  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, strided<[1], offset: 1>> into memref
+  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, strided<[1], offset: 1>> into memref>
   %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor
   return %1 : tensor
 }
@@ -474,23 +468,19 @@
 // -----
 
-// CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)>
-
 // CHECK-LABEL: func @tensor.collapse_shape_of_slice3(
 // CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32>
 func.func @tensor.collapse_shape_of_slice3(%t1: tensor<1x2xf32>) -> tensor<1xf32> {
   // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, strided<[2, 1]>>
   %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32>
   // CHECK: memref.collapse_shape %{{.*}} [
-  // CHECK-SAME: [0, 1]] : memref<1x1xf32, strided<[2, 1]>> into memref<1xf32, #[[$MAP6]]>
+  // CHECK-SAME: [0, 1]] : memref<1x1xf32, strided<[2, 1]>> into memref<1xf32, strided<[2]>>
   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32>
   return %1 : tensor<1xf32>
 }
 
 // -----
 
-// CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
-
 // CHECK-LABEL: func @tensor.collapse_shape_of_slice4(
 // CHECK-SAME: %[[t1:.*]]: tensor,
 // CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<8xf32> {
@@ -498,7 +488,7 @@
   // CHECK: memref.subview %{{.*}} : memref to memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>>
   %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor to tensor<4x2x1xf32>
   // CHECK: memref.collapse_shape %{{.*}} [
-  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>> into memref<8xf32, #[[$MAP8]]>
+  // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, strided<[8, 4, 1], offset: ?>> into memref<8xf32, strided<[4], offset: ?>>
   %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
   return %ret: tensor<8xf32>
 }
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -124,8 +124,8 @@
 {
   // CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 128 : i64} : memref
   // CHECK: memref.copy %[[A]], %[[ALLOC]] : memref
-  // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref to memref<4xf32>
-  // CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32>
+  // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref to memref<4xf32, strided<[1]>>
+  // CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32, strided<[1]>>
   %r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor
 
   // CHECK: return %{{.*}} : memref
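Note on the contiguity test that this patch inlines into the memref.copy lowering in MemRefToLLVM.cpp: it accepts any statically shaped memref whose strides are exactly the row-major products of the trailing dimension sizes, while the offset is deliberately ignored, so memref.copy on such buffers can still lower to a single llvm.intr.memcpy even when the view starts at a nonzero offset. A minimal standalone C++ sketch of the same stride walk, assuming the sizes and strides are already available as plain integers (the isContiguousRowMajor name and the std::vector interface are illustrative only, not part of the patch):

#include <cstdint>
#include <vector>

// Walk dimensions from innermost to outermost; each stride must equal the
// product of all inner dimension sizes (1 for the innermost dimension).
static bool isContiguousRowMajor(const std::vector<int64_t> &sizes,
                                 const std::vector<int64_t> &strides) {
  int64_t runningStride = 1;
  for (size_t i = strides.size(); i > 0; --i) {
    if (strides[i - 1] != runningStride)
      return false;
    runningStride *= sizes[i - 1];
  }
  return true;
}

int main() {
  // Like memref<4x8xf32, strided<[8, 1], offset: ?>>: contiguous row-major.
  bool a = isContiguousRowMajor({4, 8}, {8, 1});  // true
  // Like memref<4x8xf32, strided<[16, 1]>>: padded rows, not contiguous.
  bool b = isContiguousRowMajor({4, 8}, {16, 1}); // false
  return (a && !b) ? 0 : 1;
}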