Fold affine loads from splat constant global memrefs.

The fold logic touched in AffineOps.cpp (the enclosing function header is
outside this hunk) first checks that the memref.global has a constant dense
initializer. If that initializer is a splat, the splat value is returned
without looking at the indices. Otherwise the fold still requires the affine
map to be constant and indexes into the dense values as before.

The test change adds a splat global and @fold_const_splat_global, which
expects a load inside a two-level affine.for nest to be replaced by an
arith.constant, and collapses the CHECK-NEXT/CHECK-SAME return check of
@fold_const_init_global_memref into a single CHECK-NEXT line.

NOTE(review): line breaks and indentation were reconstructed to agree with
the hunk header counts; the blank context line that opens the first
canonicalize.mlir hunk is an assumption -- confirm against the tree before
applying.

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2410,17 +2410,22 @@
       SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.nameAttr()));
   if (!global)
     return {};
-  if (auto cstAttr =
-          global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>()) {
-    // We can fold only if we know the indices.
-    if (!getAffineMap().isConstant())
-      return {};
-    auto indices = llvm::to_vector<4>(
-        llvm::map_range(getAffineMap().getConstantResults(),
-                        [](int64_t v) -> uint64_t { return v; }));
-    return cstAttr.getValues<Attribute>()[indices];
-  }
-  return {};
+
+  // Check if the global memref is a constant.
+  auto cstAttr =
+      global.getConstantInitValue().dyn_cast_or_null<DenseElementsAttr>();
+  if (!cstAttr)
+    return {};
+  // If it's a splat constant, we can fold irrespective of indices.
+  if (auto splatAttr = cstAttr.dyn_cast<SplatElementsAttr>())
+    return splatAttr.getSplatValue<Attribute>();
+  // Otherwise, we can fold only if we know the indices.
+  if (!getAffineMap().isConstant())
+    return {};
+  auto indices = llvm::to_vector<4>(
+      llvm::map_range(getAffineMap().getConstantResults(),
+                      [](int64_t v) -> uint64_t { return v; }));
+  return cstAttr.getValues<Attribute>()[indices];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1101,6 +1101,7 @@
 
 module {
   memref.global "private" constant @__constant_1x5x1xf32 : memref<1x5x1xf32> = dense<[[[6.250000e-02], [2.500000e-01], [3.750000e-01], [2.500000e-01], [6.250000e-02]]]>
+  memref.global "private" constant @__constant_32x64xf32 : memref<32x64xf32> = dense<0.000000e+00>
   // CHECK-LABEL: func @fold_const_init_global_memref
   func @fold_const_init_global_memref() -> (f32, f32) {
     %m = memref.get_global @__constant_1x5x1xf32 : memref<1x5x1xf32>
@@ -1109,8 +1110,21 @@
     return %v0, %v1 : f32, f32
     // CHECK-DAG: %[[C0:.*]] = arith.constant 6.250000e-02 : f32
     // CHECK-DAG: %[[C1:.*]] = arith.constant 2.500000e-01 : f32
-    // CHECK-NEXT: return
-    // CHECK-SAME: %[[C0]]
-    // CHECK-SAME: %[[C1]]
+    // CHECK-NEXT: return %[[C0]], %[[C1]]
+  }
+
+  // CHECK-LABEL: func @fold_const_splat_global
+  func @fold_const_splat_global() -> memref<32x64xf32> {
+    // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+    %m = memref.get_global @__constant_32x64xf32 : memref<32x64xf32>
+    %s = memref.alloc() : memref<32x64xf32>
+    affine.for %i = 0 to 32 {
+      affine.for %j = 0 to 64 {
+        %v = affine.load %m[%i, %j] : memref<32x64xf32>
+        affine.store %v, %s[%i, %j] : memref<32x64xf32>
+        // CHECK: affine.store %[[CST]], %{{.*}}
+      }
+    }
+    return %s: memref<32x64xf32>
   }
 }