Linalg promotion: handle rank-reducing subviews.

While collecting the full/partial sizes of a promoted subview, skip the source
dimensions reported by subView.getDroppedDims() and index the subview result
with a separate counter (resultDimIdx) instead of the source-range index.
Also adds the promote_rank_reducing_subviews test and runs promote.mlir with
-split-input-file.

NOTE(review): this copy of the patch appears to have lost angle-bracket groups
that begin with a letter or '?' (the SmallVector element types, the op type
passed to createOrFold, the dynamic memref types in the test) -- confirm
against the upstream change before applying. Line breaks were restored from
the hunk headers; indentation was re-created by hand and may differ from
upstream.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -219,7 +220,11 @@
   SmallVector partialSizes;
   fullSizes.reserve(rank);
   partialSizes.reserve(rank);
+  llvm::SmallBitVector droppedDims = subView.getDroppedDims();
+  int64_t resultDimIdx = 0;
   for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
+    if (droppedDims[en.index()])
+      continue;
     auto rangeValue = en.value();
     // Try to extract a tight constant.
     LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
@@ -232,7 +237,7 @@
     LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
     fullSizes.push_back(size);
     partialSizes.push_back(
-        b.createOrFold(loc, subView, en.index()));
+        b.createOrFold(loc, subView, resultDimIdx++));
   }
   SmallVector dynSizes(fullSizes.size(), -1);
   // If a callback is not specified, then use the default implementation for
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
-// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" | FileCheck %s --check-prefix=DYNAMIC
-// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" | FileCheck %s --check-prefix=ALLOCA
+// RUN: mlir-opt %s -linalg-promote-subviews -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" -split-input-file | FileCheck %s --check-prefix=DYNAMIC
+// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" -split-input-file | FileCheck %s --check-prefix=ALLOCA
 
 #map1 = affine_map<(d0) -> (d0 + 2)>
 #map2 = affine_map<(d0) -> (d0 + 4)>
@@ -145,3 +145,46 @@
 // CHECK: memref.dealloc %[[tmpA_f64]] : memref<64xi8>
 // CHECK: memref.dealloc %[[tmpB_f64]] : memref<96xi8>
 // CHECK: memref.dealloc %[[tmpC_f64]] : memref<48xi8>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+#map2 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+#map5 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+#map6 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map7 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map8 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: promote_rank_reducing_subviews([[arg0:%.+]]: memref<{{.*}}>, [[arg1:%.+]]: memref<{{.*}}>, [[arg2:%.+]]: memref<{{.*}}>, [[lb1:%.+]]: index, [[lb2:%.+]]: index, [[lb3:%.+]]: index, [[lb4:%.+]]: index, [[lb5:%.+]]: index, [[lb6:%.+]]: index, [[ub1:%.+]]: index, [[ub2:%.+]]: index
+func.func @promote_rank_reducing_subviews(%arg0: memref, %arg1: memref<128x3x3x64xf32, #map0>, %arg2: memref,
+    %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %ub1: index, %ub2: index) {
+  %13 = memref.subview %arg0[%arg3, 0, %arg4, %arg8] [1, 1, %ub1, 32] [1, 1, 1, 1] : memref to memref
+  %14 = memref.subview %arg1[0, %arg6, %arg7, %arg8] [128, 1, 1, 32] [1, 1, 1, 1] : memref<128x3x3x64xf32, #map0> to memref<128x32xf32, #map5>
+  %9 = memref.subview %arg2[%arg3, %arg4, %arg5, 0] [1, 1, %ub2, 128] [1, 1, 1, 1] : memref to memref
+
+  // CHECK: [[a_alloc:%.+]] = memref.alloc
+  // CHECK: [[a_view:%.+]] = memref.view [[a_alloc]]{{.*}}
+  // CHECK: [[a_pro_subview:%.+]] = memref.subview [[a_view]][0, 0] [[[ub1]], {{%.+}}] [1, 1]
+
+  // CHECK: memref.alloc
+  // CHECK: [[b_view:%.+]] = memref.view
+  // CHECK: [[b_pro_subview:%.+]] = memref.subview [[b_view]]
+
+  // CHECK: memref.alloc
+  // CHECK: [[c_view:%.+]] = memref.view
+  // CHECK: [[c_pro_subview:%.+]] = memref.subview [[c_view]]
+
+  // CHECK-COUNT-3: memref.copy
+  // CHECK: linalg.generic
+  // CHECK-SAME: ins([[a_pro_subview]], [[b_pro_subview]]
+  // CHECK-SAME: outs([[c_pro_subview]]
+
+  linalg.generic {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : memref, memref<128x32xf32, #map5>) outs(%9 : memref) {
+  ^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
+    %15 = arith.mulf %arg9, %arg10 : f32
+    %16 = arith.addf %arg11, %15 : f32
+    linalg.yield %16 : f32
+  }
+
+  return
+}