diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -263,16 +263,24 @@ // type at the caller site. Optional symbolUses = funcOp.getSymbolUses(moduleOp); for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - Operation *callOp = symbolUse.getUser(); - OpBuilder builder(callOp); - StringRef callee = cast(callOp).getCallee(); + Operation *userOp = symbolUse.getUser(); + OpBuilder builder(userOp); + // When `userOp` cannot be cast to `CallOp`, it is skipped. This assumes + // that the `userOp` has no memrefs to be replaced. This works when + // the `userOp` is `test.op_entrypoint()` in `normalize-memrefs-ops.mlir`. + // If `userOp` can't be cast but contains memrefs, this code needs to be + // updated. + auto callOp = dyn_cast(userOp); + if (!callOp) + continue; + StringRef callee = callOp.getCallee(); Operation *newCallOp = builder.create( - callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee), - callOp->getOperands()); + userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee), + userOp->getOperands()); bool replacingMemRefUsesFailed = false; bool returnTypeChanged = false; - for (unsigned resIndex : llvm::seq(0, callOp->getNumResults())) { - OpResult oldResult = callOp->getResult(resIndex); + for (unsigned resIndex : llvm::seq(0, userOp->getNumResults())) { + OpResult oldResult = userOp->getResult(resIndex); OpResult newResult = newCallOp->getResult(resIndex); // This condition ensures that if the result is not of type memref or if // the resulting memref was already having a trivial map layout then we @@ -302,8 +310,8 @@ if (replacingMemRefUsesFailed) continue; // Replace all uses for other non-memref result types. 
- callOp->replaceAllUsesWith(newCallOp); - callOp->erase(); + userOp->replaceAllUsesWith(newCallOp); + userOp->erase(); if (returnTypeChanged) { // Since the return type changed it might lead to a change in function's // signature. diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir --- a/mlir/test/Transforms/normalize-memrefs-ops.mlir +++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir @@ -55,3 +55,6 @@ // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32> return } + +// Test with entrypoint +"test.op_entrypoint"() {func = @test_norm_mix} : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/Traits.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/RegionKindInterface.h" @@ -29,7 +30,6 @@ #include "TestOpEnums.h.inc" - #include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -628,6 +628,18 @@ let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); } +// Test for memrefs normalization including entrypoint +def OpEntryPoint : TEST_Op<"op_entrypoint"> { + let summary = "Indicate TEST entry point"; + let description = [{ + The "test.op_entrypoint" function indicates the main entry point. + }]; + + let builders = [OpBuilder<[{OpBuilder &builder, OperationState &state, + FuncOp function}]>]; +} + + // Pattern add the argument plus a increasing static number hidden in // OpMTest function. That value is set into the optional argument. // That way, we will know if operations is called once or twice.