diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp @@ -322,9 +322,8 @@ if (numCommonLoops == 0) { Block *block = srcAccess.opInst->getBlock(); - while (!llvm::isa(block->getParentOp())) { + while (!block->getParentOp()->hasTrait()) block = block->getParentOp()->getBlock(); - } return block; } Value commonForIV = srcDomain.getValue(numCommonLoops - 1); diff --git a/mlir/test/Dialect/Affine/scalrep.mlir b/mlir/test/Dialect/Affine/scalrep.mlir --- a/mlir/test/Dialect/Affine/scalrep.mlir +++ b/mlir/test/Dialect/Affine/scalrep.mlir @@ -701,3 +701,46 @@ // CHECK: } else { // CHECK: scf.yield %[[pi]] : f64 // CHECK: } + +// Check if scalar replacement works correctly when affine memory ops are in the +// body of an scf.for. + +// CHECK-LABEL: func @affine_store_load_in_scope +func.func @affine_store_load_in_scope(%memref: memref<1x4094x510x1xf32>, %memref_2: memref<4x4x1x64xf32>, %memref_0: memref<1x2046x254x1x64xf32>) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %c768 = arith.constant 768 : index + scf.for %i = %c0 to %c768 step %c1 { + %9 = arith.remsi %i, %c64 : index + %10 = arith.divsi %i, %c64 : index + %11 = arith.remsi %10, %c2 : index + %12 = arith.divsi %10, %c2 : index + test.affine_scope { + %14 = arith.muli %12, %c2 : index + %15 = arith.addi %c2, %14 : index + %16 = arith.addi %15, %c0 : index + %18 = arith.muli %11, %c2 : index + %19 = arith.addi %c2, %18 : index + %20 = affine.load %memref[0, symbol(%16), symbol(%19), 0] : memref<1x4094x510x1xf32> + %21 = affine.load %memref_2[0, 0, 0, symbol(%9)] : memref<4x4x1x64xf32> + %24 = affine.load %memref_0[0, symbol(%12), symbol(%11), 0, symbol(%9)] : memref<1x2046x254x1x64xf32> + %25 = arith.mulf %20, %21 : f32 + %26 = arith.addf %24, %25 : 
f32 + // CHECK: %[[A:.*]] = arith.addf + affine.store %26, %memref_0[0, symbol(%12), symbol(%11), 0, symbol(%9)] : memref<1x2046x254x1x64xf32> + %27 = arith.addi %19, %c1 : index + %28 = affine.load %memref[0, symbol(%16), symbol(%27), 0] : memref<1x4094x510x1xf32> + %29 = affine.load %memref_2[0, 1, 0, symbol(%9)] : memref<4x4x1x64xf32> + %30 = affine.load %memref_0[0, symbol(%12), symbol(%11), 0, symbol(%9)] : memref<1x2046x254x1x64xf32> + %31 = arith.mulf %28, %29 : f32 + %32 = arith.addf %30, %31 : f32 + // The addf above uses the value forwarded from the store to %memref_0 + // above, which replaces the load into %30. + // CHECK: arith.addf %[[A]], + "terminate"() : () -> () + } + } + return +}