diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -245,9 +245,9 @@
 /// Normalizes `memrefType` so that the affine layout map of the memref is
 /// transformed to an identity map with a new shape being computed for the
-/// normalized memref type and returns it. The old memref type is simplify
+/// normalized memref type and returns it. The old memref type is simply
 /// returned if the normalization failed.
-MemRefType normalizeMemRefType(MemRefType memrefType,
+MemRefType normalizeMemRefType(OpBuilder &builder, MemRefType memrefType,
                                unsigned numSymbolicOperands);
 
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1716,7 +1716,7 @@
   // Fetch a new memref type after normalizing the old memref to have an
   // identity map layout.
   MemRefType newMemRefType =
-      normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size());
+      normalizeMemRefType(b, memrefType, allocOp->getSymbolOperands().size());
   if (newMemRefType == memrefType)
     // Either memrefType already had an identity map or the map couldn't be
     // transformed to an identity map.
@@ -1728,7 +1728,9 @@
   AffineMap layoutMap = memrefType.getLayout().getAffineMap();
   memref::AllocOp newAlloc;
   // Check if `layoutMap` is a tiled layout. Only single layout map is
-  // supported for normalizing dynamic memrefs.
+  // supported for normalizing dynamic memrefs. While normalizeMemRefType
+  // can handle normalization of memrefs with dynamic shapes, the computation
+  // of new allocation sizes is not yet supported.
   SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
   (void)getTileSizePos(layoutMap, tileSizePos);
   if (newMemRefType.getNumDynamicDims() > 0 && !tileSizePos.empty()) {
@@ -1740,6 +1742,10 @@
     newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                          newDynamicSizes, allocOp->getAlignmentAttr());
+  } else if (newMemRefType.getNumDynamicDims() > 0 && tileSizePos.empty()) {
+    // As discussed above, handling the new size calculation for dynamically
+    // shaped memrefs during normalization is currently unsupported.
+    return failure();
   } else {
     newAlloc = b.create<memref::AllocOp>(allocOp->getLoc(), newMemRefType,
                                          allocOp->getAlignmentAttr());
   }
@@ -1767,7 +1773,7 @@
   return success();
 }
 
-MemRefType mlir::normalizeMemRefType(MemRefType memrefType,
+MemRefType mlir::normalizeMemRefType(OpBuilder &builder, MemRefType memrefType,
                                      unsigned numSymbolicOperands) {
   unsigned rank = memrefType.getRank();
   if (rank == 0)
@@ -1783,14 +1789,6 @@
   // We don't do any checks for one-to-one'ness; we assume that it is
   // one-to-one.
 
-  // Normalize only static memrefs and dynamic memrefs with a tiled-layout map
-  // for now.
-  // TODO: Normalize the other types of dynamic memrefs.
-  SmallVector<std::tuple<AffineExpr, unsigned, unsigned>> tileSizePos;
-  (void)getTileSizePos(layoutMap, tileSizePos);
-  if (memrefType.getNumDynamicDims() > 0 && tileSizePos.empty())
-    return memrefType;
-
   // We have a single map that is not an identity map. Create a new memref
   // with the right shape and an identity layout map.
   ArrayRef<int64_t> shape = memrefType.getShape();
@@ -1798,12 +1796,22 @@
   FlatAffineValueConstraints fac(rank, numSymbolicOperands);
   SmallVector<unsigned, 4> memrefTypeDynDims;
   for (unsigned d = 0; d < rank; ++d) {
-    // Use constraint system only in static dimensions.
+    // The lower bound on a memref shape is always 0.
+    fac.addBound(IntegerPolyhedron::LB, d, 0);
     if (shape[d] > 0) {
-      fac.addBound(IntegerPolyhedron::LB, d, 0);
+      // If the size of this dimension is statically known, add the constant
+      // upper bound.
       fac.addBound(IntegerPolyhedron::UB, d, shape[d] - 1);
     } else {
+      // Otherwise, construct an affine map to represent the dynamic upper
+      // bound of this dimension; the constructed map simply accesses the
+      // d'th dimension of the input shape.
       memrefTypeDynDims.emplace_back(d);
+      AffineExpr ub = builder.getAffineDimExpr(d);
+      auto map = AffineMap::get(fac.getNumDimVars(), fac.getNumSymbolVars(), ub);
+      if (failed(fac.addBound(IntegerPolyhedron::UB, d, map))) {
+        return memrefType;
+      }
     }
   }
   // We compose this map with the original index (logical) space to derive
diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp
@@ -366,8 +366,9 @@
     }
     // Fetch a new memref type after normalizing the old memref to have an
     // identity map layout.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
+    MemRefType newMemRefType = normalizeMemRefType(b, memrefType,
                                                    /*numSymbolicOperands=*/0);
+
     if (newMemRefType == memrefType || funcOp.isExternal()) {
       // Either memrefType already had an identity map or the map couldn't be
       // transformed to an identity map.
@@ -474,7 +475,7 @@
     }
     // Computing a new memref type after normalizing the old memref to have an
     // identity map layout.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
+    MemRefType newMemRefType = normalizeMemRefType(b, memrefType,
                                                    /*numSymbolicOperands=*/0);
     resultTypes.push_back(newMemRefType);
   }
@@ -513,7 +514,7 @@
       continue;
     }
     // Fetch a new memref type after normalizing the old memref.
-    MemRefType newMemRefType = normalizeMemRefType(memrefType,
+    MemRefType newMemRefType = normalizeMemRefType(b, memrefType,
                                                    /*numSymbolicOperands=*/0);
     if (newMemRefType == memrefType) {
       // Either memrefType already had an identity map or the map couldn't
diff --git a/mlir/test/Transforms/normalize-memrefs-ops-dynamic.mlir b/mlir/test/Transforms/normalize-memrefs-ops-dynamic.mlir
--- a/mlir/test/Transforms/normalize-memrefs-ops-dynamic.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops-dynamic.mlir
@@ -105,7 +105,7 @@
 // CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2)>
 
 // CHECK-LABEL: func @test_norm_dynamic_not_tiled0
-// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP6]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x?xf32>) {
 func.func @test_norm_dynamic_not_tiled0(%arg0 : memref<1x?x?x14xf32, #map_not_tiled0>) -> () {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -118,10 +118,10 @@
   // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
   // CHECK-NOT: separator of consecutive DAGs
-  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
+  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x?xf32>
+  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x?xf32>
   // CHECK: [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP6]]>
-  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP6]]>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
+  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x?xf32>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
   // CHECK: memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
   // CHECK: return
 }
@@ -133,10 +133,10 @@
 
 #map_not_tiled1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2, d2 mod 32, d3 mod 64)>
 
 // CHECK-DAG: #[[$MAP6:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 - d2, d2 mod 32, d3 mod 64)>
 
 // CHECK-LABEL: func @test_norm_dynamic_not_tiled1
-// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP6]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x?x?x64xf32>) {
 func.func @test_norm_dynamic_not_tiled1(%arg0 : memref<1x?x?x14xf32, #map_not_tiled1>) -> () {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -149,10 +149,10 @@
   // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
   // CHECK-NOT: separator of consecutive DAGs
-  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
-  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
+  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x?x?x64xf32>
+  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x?x?x64xf32>
   // CHECK: [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP6]]>
-  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP6]]>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
+  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x?x?x64xf32>, memref<1x?x?x14xf32, #[[$MAP6]]>) -> ()
   // CHECK: memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP6]]>
   // CHECK: return
 }
@@ -164,10 +164,10 @@
 
 #map_not_tiled2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 floordiv 64, d2 mod 32, d3 mod 32)>
 
 // CHECK-DAG: #[[$MAP7:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 - d1, d3 floordiv 64, d2 mod 32, d3 mod 32)>
 
 // CHECK-LABEL: func @test_norm_dynamic_not_tiled2
-// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP7]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x1x?x32xf32>) {
 func.func @test_norm_dynamic_not_tiled2(%arg0 : memref<1x?x?x14xf32, #map_not_tiled2>) -> () {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -180,10 +180,10 @@
   // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
   // CHECK-NOT: separator of consecutive DAGs
-  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
-  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
+  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x1x?x32xf32>
+  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x1x?x32xf32>
   // CHECK: [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP7]]>
-  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP7]]>, memref<1x?x?x14xf32, #[[$MAP7]]>) -> ()
+  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x1x?x32xf32>, memref<1x?x?x14xf32, #[[$MAP7]]>) -> ()
   // CHECK: memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP7]]>
   // CHECK: return
 }
@@ -195,10 +195,10 @@
 
 #map_not_tiled3 = affine_map<(d0, d1, d2, d3) -> (d0, d1 floordiv 32, d2, d3, d1 mod 32, d1 mod 32)>
 
 // CHECK-DAG: #[[$MAP8:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 floordiv 32, d2, d3, d1 mod 32, d1 mod 32)>
 
 // CHECK-LABEL: func @test_norm_dynamic_not_tiled3
-// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP8]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x14x?x?xf32>) {
 func.func @test_norm_dynamic_not_tiled3(%arg0 : memref<1x?x?x14xf32, #map_not_tiled3>) -> () {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -211,10 +211,10 @@
   // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
   // CHECK-NOT: separator of consecutive DAGs
-  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
-  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
+  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14x?x?xf32>
+  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14x?x?xf32>
   // CHECK: [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP8]]>
-  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP8]]>, memref<1x?x?x14xf32, #[[$MAP8]]>) -> ()
+  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14x?x?xf32>, memref<1x?x?x14xf32, #[[$MAP8]]>) -> ()
   // CHECK: memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP8]]>
   // CHECK: return
 }
@@ -226,10 +226,10 @@
 
 #map_not_tiled4 = affine_map<(d0, d1, d2, d3) -> (d0 floordiv 32, d1 floordiv 32, d0, d3, d0 mod 32, d1 mod 32)>
 
 // CHECK-DAG: #[[$MAP9:.+]] = affine_map<(d0, d1, d2, d3) -> (d0 floordiv 32, d1 floordiv 32, d0, d3, d0 mod 32, d1 mod 32)>
 
 // CHECK-LABEL: func @test_norm_dynamic_not_tiled4
-// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x?x14xf32, #[[$MAP9]]>) {
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<1x?x1x14x32x?xf32>) {
 func.func @test_norm_dynamic_not_tiled4(%arg0 : memref<1x?x?x14xf32, #map_not_tiled4>) -> () {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -242,10 +242,41 @@
   // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
   // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
   // CHECK-NOT: separator of consecutive DAGs
-  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x?x14xf32, #[[$MAP9]]>
-  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x?x14xf32, #[[$MAP9]]>
+  // CHECK-DAG: [[DIM_0_:%.+]] = memref.dim [[ARG_0_]], [[CST_1_]] : memref<1x?x1x14x32x?xf32>
+  // CHECK-DAG: [[DIM_1_:%.+]] = memref.dim [[ARG_0_]], [[CST_2_]] : memref<1x?x1x14x32x?xf32>
   // CHECK: [[RES_:%.+]] = memref.alloc([[DIM_0_]], [[DIM_1_]]) : memref<1x?x?x14xf32, #[[$MAP9]]>
-  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x?x14xf32, #[[$MAP9]]>, memref<1x?x?x14xf32, #[[$MAP9]]>) -> ()
+  // CHECK: "test.op_norm"([[ARG_0_]], [[RES_]]) : (memref<1x?x1x14x32x?xf32>, memref<1x?x?x14xf32, #[[$MAP9]]>) -> ()
   // CHECK: memref.dealloc [[RES_]] : memref<1x?x?x14xf32, #[[$MAP9]]>
   // CHECK: return
 }
+
+// -----
+
+// Test that memrefs with affine maps that aren't tiled can still be normalized when passed as function arguments.
+
+#map_dyn_arg1 = affine_map<(d0, d1, d2) -> (d1 + 3, d2 + 5)>
+#map_dyn_arg2 = affine_map<(d0, d1, d2) -> (0)>
+#map_dyn_arg3 = affine_map<(d0, d1, d2) -> (d0 * 2, d1 + d2)>
+
+// CHECK-LABEL: func @test_func_arg_dynamic_memrefs_normalized
+// CHECK-SAME: ([[ARG_0_:%.+]]: memref<?x?xf64>, [[ARG_1_:%.+]]: memref<1xf64>, [[ARG_2_:%.+]]: memref<?x?xf64>, [[D_0_:%.+]]: index, [[D_1_:%.+]]: index, [[D_2_:%.+]]: index) {
+func.func @test_func_arg_dynamic_memrefs_normalized(%arg0 : memref<?x?x?xf64, #map_dyn_arg1>, %arg1: memref<?x?x?xf64, #map_dyn_arg2>, %arg2 : memref<?x?x?xf64, #map_dyn_arg3>, %d0 : index, %d1 : index, %d2 : index) -> () {
+  %0 = arith.constant 0.0 : f64
+  affine.for %i = 0 to %d0 {
+    affine.for %j = 0 to %d1 {
+      affine.for %k = 0 to %d2 {
+        affine.store %0, %arg0[%i, %j, %k] : memref<?x?x?xf64, #map_dyn_arg1>
+        affine.store %0, %arg1[%i, %j, %k] : memref<?x?x?xf64, #map_dyn_arg2>
+        affine.store %0, %arg2[%i, %j, %k] : memref<?x?x?xf64, #map_dyn_arg3>
+      }
+    }
+  }
+  return
+  // CHECK: [[CST_1_:%.+]] = arith.constant 0.000000e+00 : f64
+  // CHECK: affine.for [[I_:%.+]] = 0 to [[D_0_]] {
+  // CHECK: affine.for [[J_:%.+]] = 0 to [[D_1_]] {
+  // CHECK: affine.for [[K_:%.+]] = 0 to [[D_2_]] {
+  // CHECK: affine.store [[CST_1_]], [[ARG_0_]][[[J_]] + 3, [[K_]] + 5] : memref<?x?xf64>
+  // CHECK: affine.store [[CST_1_]], [[ARG_1_]][0] : memref<1xf64>
+  // CHECK: affine.store [[CST_1_]], [[ARG_2_]][[[I_]] * 2, [[J_]] + [[K_]]] : memref<?x?xf64>
+}