diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -825,7 +825,7 @@
     The `memref.global` operation declares or defines a named global memref
     variable. The backing memory for the variable is allocated statically and is
     described by the type of the variable (which should be a statically shaped
-    memref type). The operation is a declaration if no `inital_value` is
+    memref type). The operation is a declaration if no `initial_value` is
     specified, else it is a definition. The `initial_value` can either be a unit
     attribute to represent a definition of an uninitialized global variable, or
     an elements attribute to represent the definition of a global variable with
@@ -878,6 +878,9 @@
      bool isUninitialized() {
        return !isExternal() && initial_value().getValue().isa<UnitAttr>();
      }
+     /// Returns the constant initial value if the memref.global is a constant,
+     /// or null otherwise.
+     ElementsAttr getConstantInitValue();
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -8,17 +8,13 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
-#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
@@ -2400,7 +2396,32 @@
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
-  return OpFoldResult();
+
+  // Fold load from a global constant memref.
+  auto getGlobalOp = memref().getDefiningOp<memref::GetGlobalOp>();
+  if (!getGlobalOp)
+    return {};
+  // Get to the memref.global defining the symbol.
+  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
+  if (!symbolTableOp)
+    return {};
+  auto global = dyn_cast_or_null<memref::GlobalOp>(
+      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
+  if (!global)
+    return {};
+  if (auto cstAttr =
+          global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
+    // We can fold only if we know the indices.
+    if (!getAffineMap().isConstant())
+      return {};
+    // NOTE(review): the constant indices are assumed to be non-negative and in
+    // bounds for the initializer -- confirm, or guard before indexing.
+    auto indices = llvm::to_vector<4>(
+        llvm::map_range(getAffineMap().getConstantResults(),
+                        [](int64_t v) -> uint64_t { return v; }));
+    return cstAttr.getValues<Attribute>()[indices];
+  }
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1275,6 +1275,15 @@
   return success();
 }
 
+ElementsAttr GlobalOp::getConstantInitValue() {
+  auto initVal = initial_value();
+  // Use dyn_cast rather than cast: an `uninitialized` global carries a unit
+  // attribute, which is not an ElementsAttr, so null is returned for it.
+  if (constant() && initVal.hasValue())
+    return initVal.getValue().dyn_cast<ElementsAttr>();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // GetGlobalOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1080,3 +1080,21 @@
   return %0, %1: index, index
 }
 
+
+// -----
+
+module {
+  memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
+  // CHECK-LABEL: func @fold_const_init_global_memref
+  func @fold_const_init_global_memref() -> (f32, f32) {
+    %m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
+    %v0 = affine.load %m[0, 0, 0] : memref<1x5x1xf32>
+    %v1 = affine.load %m[0, 1, 0] : memref<1x5x1xf32>
+    return %v0, %v1 : f32, f32
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
+    // CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
+    // CHECK-NEXT: return
+    // CHECK-SAME: %[[C0]]
+    // CHECK-SAME: %[[C1]]
+  }
+}