diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1625,7 +1625,10 @@ // continue fusing based on new operands. for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { auto loadMemRef = cast<AffineLoadOp>(loadOpInst).getMemRef(); - if (visitedMemrefs.count(loadMemRef) == 0 + // NOTE: Change 'loads' to a hash set in case efficiency is an + // issue. We still use a vector since it's expected to be small. + if (visitedMemrefs.count(loadMemRef) == 0 && + !llvm::is_contained(loads, loadOpInst)) loads.push_back(loadOpInst); } diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -2422,5 +2422,45 @@ // CHECK-NEXT: affine.store %{{.*}}, %[[A]] // CHECK-NEXT: affine.load %[[B]] // CHECK-NOT: affine.for %{{.*}} + // CHECK: return return } + +// ----- + +// MAXIMAL-LABEL: func @reshape_into_matmul +func @reshape_into_matmul(%lhs : memref<1024x1024xf32>, + %R: memref<16x64x1024xf32>, %out: memref<1024x1024xf32>) { + %rhs = alloc() : memref<1024x1024xf32> + + // Reshape from 3-d to 2-d. + affine.for %i0 = 0 to 16 { + affine.for %i1 = 0 to 64 { + affine.for %k = 0 to 1024 { + %v = affine.load %R[%i0, %i1, %k] : memref<16x64x1024xf32> + affine.store %v, %rhs[64*%i0 + %i1, %k] : memref<1024x1024xf32> + } + } + } + + // Matmul. 
+ affine.for %i = 0 to 1024 { + affine.for %j = 0 to 1024 { + affine.for %k = 0 to 1024 { + %0 = affine.load %rhs[%k, %j] : memref<1024x1024xf32> + %1 = affine.load %lhs[%i, %k] : memref<1024x1024xf32> + %2 = mulf %1, %0 : f32 + %3 = affine.load %out[%i, %j] : memref<1024x1024xf32> + %4 = addf %3, %2 : f32 + affine.store %4, %out[%i, %j] : memref<1024x1024xf32> + } + } + } + return +} +// MAXIMAL-NEXT: alloc +// MAXIMAL-NEXT: affine.for +// MAXIMAL-NEXT: affine.for +// MAXIMAL-NEXT: affine.for +// MAXIMAL-NOT: affine.for +// MAXIMAL: return