diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -263,16 +263,23 @@
   // type at the caller site.
   Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
   for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
-    Operation *callOp = symbolUse.getUser();
-    OpBuilder builder(callOp);
-    StringRef callee = cast<CallOp>(callOp).getCallee();
+    Operation *userOp = symbolUse.getUser();
+    // When `userOp` cannot be cast to `CallOp`, it is skipped. This assumes
+    // that the non-CallOp has no memrefs to be replaced.
+    // TODO: Handle cases where a non-CallOp symbol use of a function deals with
+    // memrefs.
+    auto callOp = dyn_cast<CallOp>(userOp);
+    if (!callOp)
+      continue;
+    OpBuilder builder(userOp);
+    StringRef callee = callOp.getCallee();
     Operation *newCallOp = builder.create<CallOp>(
-        callOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
-        callOp->getOperands());
+        userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee),
+        userOp->getOperands());
     bool replacingMemRefUsesFailed = false;
     bool returnTypeChanged = false;
-    for (unsigned resIndex : llvm::seq<unsigned>(0, callOp->getNumResults())) {
-      OpResult oldResult = callOp->getResult(resIndex);
+    for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
+      OpResult oldResult = userOp->getResult(resIndex);
       OpResult newResult = newCallOp->getResult(resIndex);
       // This condition ensures that if the result is not of type memref or if
       // the resulting memref was already having a trivial map layout then we
@@ -302,8 +309,8 @@
     if (replacingMemRefUsesFailed)
       continue;
     // Replace all uses for other non-memref result types.
-    callOp->replaceAllUsesWith(newCallOp);
-    callOp->erase();
+    userOp->replaceAllUsesWith(newCallOp);
+    userOp->erase();
     if (returnTypeChanged) {
       // Since the return type changed it might lead to a change in function's
       // signature.
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -89,3 +89,8 @@
   // CHECK: dealloc %[[v1]] : memref<1x16x14x14xf32>
   return
 }
+
+// Test with an arbitrary op that references the function symbol.
+
+"test.op_funcref"() {func = @test_norm_mix} : () -> ()
+// CHECK: "test.op_funcref"() {func = @test_norm_mix} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/RegionKindInterface.h"
@@ -29,7 +30,6 @@
 
 #include "TestOpEnums.h.inc"
 
-
 #include "TestOpStructs.h.inc"
 
 #include "TestOpsDialect.h.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -629,6 +629,17 @@
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
 }
 
+// Test for memrefs normalization of an op with a reference to a function
+// symbol.
+def OpFuncRef : TEST_Op<"op_funcref"> {
+  let summary = "Test op with a reference to a function symbol";
+  let description = [{
+    The "test.op_funcref" is a test op with a reference to a function symbol.
+  }];
+  let builders = [OpBuilder<[{OpBuilder &builder, OperationState &state,
+                             FuncOp function}]>];
+}
+
 // Pattern add the argument plus a increasing static number hidden in
 // OpMTest function. That value is set into the optional argument.
 // That way, we will know if operations is called once or twice.