diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2354,6 +2354,38 @@ /// load(memrefcast) -> load if (succeeded(foldMemRefCast(*this))) return getResult(); + + // Fold load from a global constant memref if load indices are all constant. + if (getAffineMap().isConstant()) { + if (auto getGlobalOp = memref().getDefiningOp()) { + // Get to the memref.global defining the symbol. + if (auto *symbolTableOp = + getGlobalOp->getParentWithTrait()) { + if (auto global = + dyn_cast_or_null(SymbolTable::lookupSymbolIn( + symbolTableOp, getGlobalOp.nameAttr()))) { + + // Return the constant for a memref.global if constant or null + // otherwise. + auto getGlobalConstantValue = + [&](memref::GlobalOp globalOp) -> DenseElementsAttr { + auto initVal = globalOp.initial_value(); + if (!globalOp.constant() || !initVal.hasValue()) + return {}; + return initVal.getValue().dyn_cast(); + }; + + // Check for constant global memref and access value at known indices. 
+ if (auto cstAttr = getGlobalConstantValue(global)) { + auto indices = llvm::to_vector<4>( + llvm::map_range(getAffineMap().getConstantResults(), + [](int64_t v) -> uint64_t { return v; })); + return cstAttr.getValues()[indices]; + } + } + } + } + } return OpFoldResult(); } diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -974,3 +974,21 @@ return %0, %1: index, index } + +// ----- + +module { + memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]> + // CHECK-LABEL: func @fold_const_init_global_memref + func @fold_const_init_global_memref() -> (f32, f32) { + %m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32> + %v0 = affine.load %m[0, 0, 0] : memref<1x5x1xf32> + %v1 = affine.load %m[0, 1, 0] : memref<1x5x1xf32> + // %v0 reads element [0, 0, 0] = 6.25e-02 and %v1 reads element [0, 1, 0] = 2.5e-01; the return must keep that order. + // CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32 + // CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32 + // CHECK-NEXT: return + // CHECK-SAME: %[[C0]], %[[C1]] : f32, f32 + return %v0, %v1 : f32, f32 + } +}