diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -100,6 +100,23 @@
     static StringRef getAlignmentAttrStrName() { return "alignment"; }
 
     MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
+
+    /// Return the size of every dimension of the allocated memref: an index
+    /// attribute for each static dimension and the matching dynamic size
+    /// operand (in operand order) for each dynamic dimension.
+    SmallVector<OpFoldResult> getMixedSizes() {
+      SmallVector<OpFoldResult> result;
+      unsigned ctr = 0;
+      Builder b(getContext());
+      for (int64_t i = 0, e = getType().getRank(); i < e; ++i) {
+        if (getType().isDynamicDim(i)) {
+          result.push_back(getDynamicSizes()[ctr++]);
+        } else {
+          result.push_back(b.getIndexAttr(getType().getShape()[i]));
+        }
+      }
+      return result;
+    }
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_VALUEBOUNDSOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -46,6 +46,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
@@ -139,6 +140,7 @@
   linalg::registerTilingInterfaceExternalModels(registry);
   memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+  memref::registerValueBoundsOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
   sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
   MemRefOps.cpp
+  ValueBoundsOpInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
@@ -21,5 +22,6 @@
   MLIRIR
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
+  MLIRValueBoundsOpInterface
   MLIRViewLikeInterface
 )
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -0,0 +1,141 @@
+//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+/// memref.alloc / memref.alloca: every dimension of the result equals the
+/// corresponding (static or dynamic) allocation size.
+template <typename OpTy>
+struct AllocOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<AllocOpInterface<OpTy>,
+                                                   OpTy> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto allocOp = cast<OpTy>(op);
+    assert(value == allocOp.getResult() && "invalid value");
+
+    cstr.bound(value)[dim] == allocOp.getMixedSizes()[dim];
+  }
+};
+
+/// memref.cast: each result dimension equals the same dimension of the
+/// source. Casts involving unranked memrefs contribute no bound.
+struct CastOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<CastOpInterface, CastOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto castOp = cast<CastOp>(op);
+    assert(value == castOp.getResult() && "invalid value");
+
+    if (castOp.getResult().getType().isa<MemRefType>() &&
+        castOp.getSource().getType().isa<MemRefType>()) {
+      cstr.bound(value)[dim] == cstr.getExpr(castOp.getSource(), dim);
+    }
+  }
+};
+
+/// memref.dim: with a constant index, the result equals the size of that
+/// dimension of the source. A non-constant index contributes no bound.
+struct DimOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<DimOpInterface, DimOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto dimOp = cast<DimOp>(op);
+    assert(value == dimOp.getResult() && "invalid value");
+
+    auto constIndex = dimOp.getConstantIndex();
+    if (!constIndex.has_value())
+      return;
+    cstr.bound(value) == cstr.getExpr(dimOp.getSource(), *constIndex);
+  }
+};
+
+/// memref.get_global: the result type is expected to have a static shape
+/// (asserted below), so every dimension equals its constant size.
+struct GetGlobalOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
+                                                   GetGlobalOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto getGlobalOp = cast<GetGlobalOp>(op);
+    assert(value == getGlobalOp.getResult() && "invalid value");
+
+    auto type = getGlobalOp.getType();
+    assert(!type.isDynamicDim(dim) && "expected static dim");
+    cstr.bound(value)[dim] == type.getDimSize(dim);
+  }
+};
+
+/// memref.rank: for a ranked memref operand the result equals its rank. An
+/// unranked operand contributes no bound.
+struct RankOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<RankOpInterface, RankOp> {
+  void populateBoundsForIndexValue(Operation *op, Value value,
+                                   ValueBoundsConstraintSet &cstr) const {
+    auto rankOp = cast<RankOp>(op);
+    assert(value == rankOp.getResult() && "invalid value");
+
+    auto memrefType = rankOp.getMemref().getType().dyn_cast<MemRefType>();
+    if (!memrefType)
+      return;
+    cstr.bound(value) == memrefType.getRank();
+  }
+};
+
+/// memref.subview: result dimension `dim` equals the subview size at the
+/// `dim`-th (0-based) dimension that is not dropped by rank reduction.
+struct SubViewOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
+                                                   SubViewOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto subViewOp = cast<SubViewOp>(op);
+    assert(value == subViewOp.getResult() && "invalid value");
+
+    llvm::SmallBitVector dropped = subViewOp.getDroppedDims();
+    int64_t ctr = -1;
+    for (int64_t i = 0, e = subViewOp.getMixedSizes().size(); i < e; ++i) {
+      // Skip over rank-reduced dimensions.
+      if (!dropped.test(i))
+        ++ctr;
+      if (ctr == dim) {
+        cstr.bound(value)[dim] == subViewOp.getMixedSizes()[i];
+        return;
+      }
+    }
+    llvm_unreachable("could not find non-rank-reduced dim");
+  }
+};
+
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    memref::AllocOp::attachInterface<memref::AllocOpInterface<memref::AllocOp>>(
+        *ctx);
+    memref::AllocaOp::attachInterface<
+        memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
+    memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
+    memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
+    memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
+    memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \
+// RUN:     -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @memref_alloc(
+//  CHECK-SAME:     %[[sz:.*]]: index
+//       CHECK:   %[[c6:.*]] = arith.constant 6 : index
+//       CHECK:   return %[[c6]], %[[sz]]
+func.func @memref_alloc(%sz: index) -> (index, index) {
+  %0 = memref.alloc(%sz) : memref<6x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_alloca(
+//  CHECK-SAME:     %[[sz:.*]]: index
+//       CHECK:   %[[c6:.*]] = arith.constant 6 : index
+//       CHECK:   return %[[c6]], %[[sz]]
+func.func @memref_alloca(%sz: index) -> (index, index) {
+  %0 = memref.alloca(%sz) : memref<6x?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<6x?xf32>) -> (index)
+  %2 = "test.reify_bound"(%0) {dim = 1} : (memref<6x?xf32>) -> (index)
+  return %1, %2 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_cast(
+//       CHECK:   %[[c10:.*]] = arith.constant 10 : index
+//       CHECK:   return %[[c10]]
+func.func @memref_cast(%m: memref<10xf32>) -> index {
+  %0 = memref.cast %m : memref<10xf32> to memref<?xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_dim(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>
+//       CHECK:   %[[dim:.*]] = memref.dim %[[m]]
+//       CHECK:   %[[dim:.*]] = memref.dim %[[m]]
+//       CHECK:   return %[[dim]]
+func.func @memref_dim(%m: memref<?xf32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<?xf32>
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_get_global(
+//       CHECK:   %[[c4:.*]] = arith.constant 4 : index
+//       CHECK:   return %[[c4]]
+memref.global "private" @gv0 : memref<4xf32> = dense<[0.0, 1.0, 2.0, 3.0]>
+func.func @memref_get_global() -> index {
+  %0 = memref.get_global @gv0 : memref<4xf32>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<4xf32>) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_rank(
+//  CHECK-SAME:     %[[m:.*]]: memref<5xf32>
+//       CHECK:   %[[c1:.*]] = arith.constant 1 : index
+//       CHECK:   return %[[c1]]
+func.func @memref_rank(%m: memref<5xf32>) -> index {
+  %0 = memref.rank %m : memref<5xf32>
+  %1 = "test.reify_bound"(%0) : (index) -> (index)
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_subview(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>, %[[sz:.*]]: index
+//       CHECK:   return %[[sz]]
+func.func @memref_subview(%m: memref<?xf32>, %sz: index) -> index {
+  %0 = memref.subview %m[2][%sz][1] : memref<?xf32> to memref<?xf32, strided<[1], offset: 2>>
+  %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?xf32, strided<[1], offset: 2>>) -> (index)
+  return %1 : index
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10257,6 +10257,7 @@
     ),
     hdrs = [
         "include/mlir/Dialect/MemRef/IR/MemRef.h",
+        "include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h",
         "include/mlir/Dialect/MemRef/Utils/MemRefUtils.h",
     ],
     includes = ["include"],
@@ -10271,6 +10272,7 @@
         ":MemRefBaseIncGen",
         ":MemRefOpsIncGen",
         ":ShapedOpInterfaces",
+        ":ValueBoundsOpInterface",
         ":ViewLikeInterface",
         "//llvm:Support",
         "//llvm:TargetParser",