diff --git a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ValueBoundsOpInterfaceImpl.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/ValueBoundsOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -366,6 +367,21 @@ } // namespace } // namespace scf + +namespace linalg { +namespace { + +/// Helper structure that iterates over all LinalgOps in `Ops` and registers +/// the `ValueBoundsOpInterface` with each of them. +template +struct LinalgValueBoundsOpInterfaceHelper { + static void registerOpInterface(MLIRContext *ctx) { + (Ops::template attachInterface>(*ctx), ...); + } +}; + +} // namespace +} // namespace linalg } // namespace mlir void mlir::linalg::registerValueBoundsOpInterfaceExternalModels( @@ -411,4 +427,12 @@ registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { scf::ForOp::attachInterface(*ctx); }); + + registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) { + // Register all Linalg structured ops. 
+ LinalgValueBoundsOpInterfaceHelper< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >::registerOpInterface(ctx); + }); } diff --git a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir --- a/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/Linalg/value-bounds-op-interface-impl.mlir @@ -337,3 +337,16 @@ } return } + +// ----- + +// CHECK-LABEL: func @linalg_fill( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c0]] +// CHECK: return %[[dim]] +func.func @linalg_fill(%t: tensor, %f: f32) -> index { + %0 = linalg.fill ins(%f : f32) outs(%t : tensor) -> tensor + %1 = "test.reify_bound"(%0) {dim = 0} : (tensor) -> (index) + return %1 : index +} diff --git a/mlir/test/Dialect/Linalg/value-bounds-reification.mlir b/mlir/test/Dialect/Linalg/value-bounds-reification.mlir --- a/mlir/test/Dialect/Linalg/value-bounds-reification.mlir +++ b/mlir/test/Dialect/Linalg/value-bounds-reification.mlir @@ -20,3 +20,82 @@ return %4, %5, %6 : index, index, index } + +// ----- + +// CHECK-LABEL: func @reify_slice_bound( +// CHECK: %[[c5:.*]] = arith.constant 5 : index +// CHECK: "test.some_use"(%[[c5]]) +func.func @reify_slice_bound(%t: tensor, %idx: index, %ub: index, %f: f32) { + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + scf.for %iv = %c0 to %ub step %c4 { + %sz = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 4)>(%iv)[%ub] + %slice = tensor.extract_slice %t[%idx, %iv] [1, %sz] [1, 1] : tensor to tensor<1x?xi32> + %filled = linalg.fill ins(%f : f32) outs(%slice : tensor<1x?xi32>) -> tensor<1x?xi32> + %bound = "test.reify_bound"(%filled) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index) + "test.some_use"(%bound) : (index) -> () + } + return +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 - s1 + 1)> +// 
CHECK-LABEL: func @scf_for( +// CHECK-SAME: %[[lb:.*]]: index, %[[ub:.*]]: index, %[[step:.*]]: index +// CHECK: %[[bound:.*]] = affine.apply #[[$map]]()[%[[ub]], %[[lb]]] +// CHECK: "test.some_use"(%[[bound]]) +func.func @scf_for(%lb: index, %ub: index, %step: index) { + scf.for %iv = %lb to %ub step %step { + %0 = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv)[%ub] + %bound = "test.reify_bound"(%0) {type = "UB"} : (index) -> (index) + "test.some_use"(%bound) : (index) -> () + } + return +} + +// ----- + +// CHECK-LABEL: func @reify_slice_bound2( +func.func @reify_slice_bound2(%lb0: index, %ub0: index, %step0: index, + %ub2: index, %t1: tensor<1x?xi8>, + %t2: tensor, %t3: tensor<1x?xi32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + scf.for %iv0 = %lb0 to %ub0 step %step0 { + // CHECK: %[[c129:.*]] = arith.constant 129 : index + // CHECK: "test.some_use"(%[[c129]]) + %ub1 = affine.min affine_map<(d0)[s0] -> (-d0 + s0, 128)>(%iv0)[%ub0] + %ub1_ub = "test.reify_bound"(%ub1) {type = "UB"} : (index) -> (index) + "test.some_use"(%ub1_ub) : (index) -> () + + // CHECK: %[[c129:.*]] = arith.constant 129 : index + // CHECK: "test.some_use"(%[[c129]]) + %lb1 = affine.apply affine_map<()[s0] -> ((s0 floordiv 32) * 32)>()[%ub1] + %lb1_ub = "test.reify_bound"(%lb1) {type = "UB"} : (index) -> (index) + "test.some_use"(%lb1_ub) : (index) -> () + + scf.for %iv1 = %lb1 to %ub1 step %c32 { + // CHECK: %[[c32:.*]] = arith.constant 32 : index + // CHECK: "test.some_use"(%[[c32]]) + %sz = affine.apply affine_map<(d0)[s0] -> (-d0 + s0)>(%iv1)[%ub1] + %sz_ub = "test.reify_bound"(%sz) {type = "UB"} : (index) -> (index) + "test.some_use"(%sz_ub) : (index) -> () + + scf.for %iv2 = %c0 to %ub2 step %c1 { + %slice1 = tensor.extract_slice %t1[0, %iv2] [1, 1] [1, 1] : tensor<1x?xi8> to tensor<1x1xi8> + %slice2 = tensor.extract_slice %t2[%iv2, 0] [1, %sz] [1, 1] : tensor to tensor<1x?xi8> + %slice3 = tensor.extract_slice %t3[0, 
0] [1, %sz] [1, 1] : tensor<1x?xi32> to tensor<1x?xi32> + %matmul = linalg.matmul ins(%slice1, %slice2 : tensor<1x1xi8>, tensor<1x?xi8>) outs(%slice3 : tensor<1x?xi32>) -> tensor<1x?xi32> + + // CHECK: %[[c32:.*]] = arith.constant 32 : index + // CHECK: "test.some_use"(%[[c32]]) + %matmul_ub = "test.reify_bound"(%matmul) {dim = 1, type = "UB"} : (tensor<1x?xi32>) -> (index) + "test.some_use"(%matmul_ub) : (index) -> () + } + } + } + return +}