diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -40,17 +40,39 @@
 using folded_affine_min = folded::ValueBuilder<AffineMinOp>;
 using folded_linalg_range = folded::ValueBuilder<linalg::RangeOp>;
+using folded_std_dim = folded::ValueBuilder<DimOp>;
+using folded_std_subview = folded::ValueBuilder<SubViewOp>;
+using folded_std_view = folded::ValueBuilder<ViewOp>;

 #define DEBUG_TYPE "linalg-promotion"

-static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
+/// If `size` comes from an AffineMinOp and one of the results of its affine
+/// map is a constant, return a new value set to the smallest such constant.
+/// Otherwise return `size`.
+static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
+                                                 Value size) {
+  auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
+  if (!affineMinOp)
+    return size;
+  int64_t minConst = std::numeric_limits<int64_t>::max();
+  for (auto e : affineMinOp.getAffineMap().getResults())
+    if (auto cst = e.dyn_cast<AffineConstantExpr>())
+      minConst = std::min(minConst, cst.getValue());
+  return (minConst != std::numeric_limits<int64_t>::max())
+             ? b.create<ConstantIndexOp>(loc, minConst)
+             : size;
+}
+
+static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
+                         OperationFolder *folder) {
   auto *ctx = size.getContext();
   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
   if (!dynamicBuffers)
     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
       return std_alloc(
           MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
-  Value mul = std_muli(std_constant_index(width), size);
+  Value mul =
+      folded_std_muli(folder, folded_std_constant_index(folder, width), size);
   return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul);
 }

@@ -87,18 +109,22 @@
   for (auto en : llvm::enumerate(subView.getRanges())) {
     auto rank = en.index();
     auto rangeValue = en.value();
-    Value d = rangeValue.size;
-    allocSize = folded_std_muli(folder, allocSize, d).getValue();
-    fullRanges.push_back(d);
-    partialRanges.push_back(
-        folded_linalg_range(folder, zero, std_dim(subView, rank), one));
+    // Try to extract a tight constant bound.
+    Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
+    allocSize = folded_std_muli(folder, allocSize, size).getValue();
+    fullRanges.push_back(size);
+    partialRanges.push_back(folded_std_dim(folder, subView, rank));
   }
   SmallVector<int64_t, 4> dynSizes(fullRanges.size(), -1);
   auto buffer =
-      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers);
-  auto fullLocalView = std_view(
-      MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges);
-  auto partialLocalView = linalg_slice(fullLocalView, partialRanges);
+      allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers, folder);
+  auto fullLocalView = folded_std_view(
+      folder, MemRefType::get(dynSizes, viewType.getElementType()), buffer,
+      fullRanges);
+  SmallVector<Value, 4> zeros(fullRanges.size(), zero);
+  SmallVector<Value, 4> ones(fullRanges.size(), one);
+  auto partialLocalView =
+      folded_std_subview(folder, fullLocalView, zeros, partialRanges, ones);
   return PromotionInfo{buffer, fullLocalView, partialLocalView};
 }
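The key observation behind `extractSmallestConstantBoundingSize` above is that an affine_min is never larger than any single operand of its map, so the smallest constant among the map's results is a safe static upper bound for the otherwise dynamic size. The following dependency-free C++ sketch mirrors that scan outside of MLIR; `Expr` and `smallestConstantBound` are hypothetical stand-ins for `mlir::AffineExpr` and the helper above, not real MLIR APIs.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for one result of an affine_min map: either a known constant
// (e.g. a tile size) or a dynamic expression (std::nullopt).
using Expr = std::optional<int64_t>;

// Mirrors extractSmallestConstantBoundingSize: min(e0, e1, ...) is <= each
// operand, so the smallest constant operand bounds the whole min from above
// and can serve as a static buffer size. Returns std::nullopt when every
// result is dynamic, in which case the caller keeps the dynamic size.
std::optional<int64_t> smallestConstantBound(const std::vector<Expr> &results) {
  int64_t minConst = std::numeric_limits<int64_t>::max();
  for (const Expr &e : results)
    if (e)
      minConst = std::min(minConst, *e);
  if (minConst == std::numeric_limits<int64_t>::max())
    return std::nullopt;
  return minConst;
}

int main() {
  // Models affine_min (d0) -> (2, 8, d0): 2 and 8 are constant, d0 is not.
  std::vector<Expr> results = {2, 8, std::nullopt};
  if (auto bound = smallestConstantBound(results))
    std::cout << "static bound: " << *bound << "\n"; // prints "static bound: 2"
  return 0;
}

In the patch, the constant produced this way flows into `allocBuffer`, which then takes the static `std_alloc` path and emits fixed-size allocations such as the `memref<32xi8>` buffers checked in the test below.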
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -7,7 +7,6 @@
 #map3 = affine_map<(d0) -> (d0 + 3)>

 // CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
-// CHECK-DAG: #[[strided2DnoOffset:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
 // CHECK-DAG: #[[strided2D_dynamic:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>

 func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
@@ -46,28 +45,28 @@
 // CHECK: %[[tmpA:.*]] = alloc() : memref<32xi8>
 // CHECK: %[[fullA:.*]] = std.view %[[tmpA]][][{{.*}}] : memref<32xi8> to memref<?x?xf32>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
-// CHECK: %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
+// CHECK: %[[partialA:.*]] = subview %[[fullA]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
 ///
 // CHECK: %[[tmpB:.*]] = alloc() : memref<48xi8>
 // CHECK: %[[fullB:.*]] = std.view %[[tmpB]][][{{.*}}] : memref<48xi8> to memref<?x?xf32>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
-// CHECK: %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
+// CHECK: %[[partialB:.*]] = subview %[[fullB]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
 ///
 // CHECK: %[[tmpC:.*]] = alloc() : memref<24xi8>
 // CHECK: %[[fullC:.*]] = std.view %[[tmpC]][][{{.*}}] : memref<24xi8> to memref<?x?xf32>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf32>
-// CHECK: %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2DnoOffset]]>
+// CHECK: %[[partialC:.*]] = subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
 // CHECK: linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
 // CHECK: linalg.fill(%[[fullB]], {{.*}}) : memref<?x?xf32>, f32
 // CHECK: linalg.fill(%[[fullC]], {{.*}}) : memref<?x?xf32>, f32
-// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
-// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
-// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
+// CHECK: linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
 //
 // CHECK: linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
 //
-// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2D]]>
+// CHECK: linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D]]>
 //
 // CHECK: dealloc %[[tmpA]] : memref<32xi8>
 // CHECK: dealloc %[[tmpB]] : memref<48xi8>
@@ -111,28 +110,28 @@
 // CHECK: %[[tmpA_f64:.*]] = alloc() : memref<64xi8>
 // CHECK: %[[fullA_f64:.*]] = std.view %[[tmpA_f64]][][{{.*}}] : memref<64xi8> to memref<?x?xf64>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
-// CHECK: %[[partialA_f64:.*]] = linalg.slice %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
+// CHECK: %[[partialA_f64:.*]] = subview %[[fullA_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
 ///
 // CHECK: %[[tmpB_f64:.*]] = alloc() : memref<96xi8>
 // CHECK: %[[fullB_f64:.*]] = std.view %[[tmpB_f64]][][{{.*}}] : memref<96xi8> to memref<?x?xf64>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
-// CHECK: %[[partialB_f64:.*]] = linalg.slice %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
+// CHECK: %[[partialB_f64:.*]] = subview %[[fullB_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
 ///
 // CHECK: %[[tmpC_f64:.*]] = alloc() : memref<48xi8>
 // CHECK: %[[fullC_f64:.*]] = std.view %[[tmpC_f64]][][{{.*}}] : memref<48xi8> to memref<?x?xf64>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xf64>
-// CHECK: %[[partialC_f64:.*]] = linalg.slice %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64>, !linalg.range, !linalg.range, memref<?x?xf64, #[[strided2DnoOffset]]>
+// CHECK: %[[partialC_f64:.*]] = subview %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
 // CHECK: linalg.fill(%[[fullA_f64]], {{.*}}) : memref<?x?xf64>, f64
 // CHECK: linalg.fill(%[[fullB_f64]], {{.*}}) : memref<?x?xf64>, f64
 // CHECK: linalg.fill(%[[fullC_f64]], {{.*}}) : memref<?x?xf64>, f64
-// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
-// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
-// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2DnoOffset]]>
+// CHECK: linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
 //
 // CHECK: linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref<?x?xf64>, memref<?x?xf64>, memref<?x?xf64>
 //
-// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2DnoOffset]]>, memref<?x?xf64, #[[strided2D]]>
+// CHECK: linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D]]>
 //
 // CHECK: dealloc %[[tmpA_f64]] : memref<64xi8>
 // CHECK: dealloc %[[tmpB_f64]] : memref<96xi8>
@@ -176,28 +175,28 @@
 // CHECK: %[[tmpA_i32:.*]] = alloc() : memref<32xi8>
 // CHECK: %[[fullA_i32:.*]] = std.view %[[tmpA_i32]][][{{.*}}] : memref<32xi8> to memref<?x?xi32>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
-// CHECK: %[[partialA_i32:.*]] = linalg.slice %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
+// CHECK: %[[partialA_i32:.*]] = subview %[[fullA_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
 ///
 // CHECK: %[[tmpB_i32:.*]] = alloc() : memref<48xi8>
 // CHECK: %[[fullB_i32:.*]] = std.view %[[tmpB_i32]][][{{.*}}] : memref<48xi8> to memref<?x?xi32>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
-// CHECK: %[[partialB_i32:.*]] = linalg.slice %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
+// CHECK: %[[partialB_i32:.*]] = subview %[[fullB_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
 ///
 // CHECK: %[[tmpC_i32:.*]] = alloc() : memref<24xi8>
 // CHECK: %[[fullC_i32:.*]] = std.view %[[tmpC_i32]][][{{.*}}] : memref<24xi8> to memref<?x?xi32>
 // DYNAMIC: std.view %{{.*}}[][{{.*}}] : memref<?xi8> to memref<?x?xi32>
-// CHECK: %[[partialC_i32:.*]] = linalg.slice %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32>, !linalg.range, !linalg.range, memref<?x?xi32, #[[strided2DnoOffset]]>
+// CHECK: %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
 // CHECK: linalg.fill(%[[fullA_i32]], {{.*}}) : memref<?x?xi32>, i32
 // CHECK: linalg.fill(%[[fullB_i32]], {{.*}}) : memref<?x?xi32>, i32
 // CHECK: linalg.fill(%[[fullC_i32]], {{.*}}) : memref<?x?xi32>, i32
-// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
-// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
-// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2DnoOffset]]>
+// CHECK: linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
+// CHECK: linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
 //
 // CHECK: linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
 //
-// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2DnoOffset]]>, memref<?x?xi32, #[[strided2D]]>
+// CHECK: linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D]]>
 //
 // CHECK: dealloc %[[tmpA_i32]] : memref<32xi8>
 // CHECK: dealloc %[[tmpB_i32]] : memref<48xi8>
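For reference, the static sizes checked above are consistent with a 2x3x4 (M, N, K) tiling of the matmul: the promoted A tile is 2x4, B is 4x3, and C is 2x3, so at 4 bytes per element the f32 and i32 buffers come to 2*4*4 = 32, 4*3*4 = 48, and 2*3*4 = 24 bytes, and the 8-byte f64 variants double to 64, 96, and 48. The tile sizes are inferred here from those allocation sizes; the test's RUN lines are not shown in this excerpt.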