diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -384,6 +384,97 @@
     funcOp.front().eraseArgument(argIndex + 1);
   }
 
+  // Walk over normalizable operations to normalize memrefs of the operation
+  // results. For each candidate op with at least one memref result whose
+  // layout map can be normalized, a replacement op with normalized result
+  // types is created, uses of the old memref results are rewritten through
+  // the old layout map, and the old op is erased.
+  funcOp.walk([&](Operation *op) {
+    // Candidates carry the MemRefsNormalizable trait and produce results.
+    // CallOp results are not handled by this walk. External functions have no
+    // body to rewrite. Ops with regions are skipped because the replacement op
+    // is built from operands, attributes and result types only, so erasing the
+    // old op would drop its region bodies.
+    // TODO: Move regions over to the replacement op.
+    if (funcOp.isExternal() || op->getNumResults() == 0 ||
+        !op->hasTrait<OpTrait::MemRefsNormalizable>() || isa<CallOp>(op) ||
+        op->getNumRegions() != 0)
+      return;
+
+    // Compute the result types of the replacement op. Non-memref results keep
+    // their type; memref results get the normalized type, which equals the
+    // old type when the layout map already is the identity or cannot be
+    // turned into one.
+    SmallVector<Type, 4> resultTypes;
+    resultTypes.reserve(op->getNumResults());
+    bool resultTypeNormalized = false;
+    for (Type resultType : op->getResultTypes()) {
+      MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+      if (!memrefType) {
+        resultTypes.push_back(resultType);
+        continue;
+      }
+      MemRefType newMemRefType =
+          normalizeMemRefType(memrefType, b, /*numSymbolicOperands=*/0);
+      if (newMemRefType != memrefType)
+        resultTypeNormalized = true;
+      resultTypes.push_back(newMemRefType);
+    }
+    if (!resultTypeNormalized)
+      return;
+
+    // Insert the replacement op, identical to `op` except for its result
+    // types, right before `op`.
+    OperationState result(op->getLoc(), op->getName());
+    result.addOperands(op->getOperands());
+    result.addAttributes(op->getAttrs());
+    result.addTypes(resultTypes);
+    OpBuilder bb(op);
+    Operation *newOp = bb.createOperation(result);
+
+    bool replacingMemRefUsesFailed = false;
+    for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
+      // Replace all uses of the old memrefs.
+      Value oldMemRef = op->getResult(resIndex);
+      Value newMemRef = newOp->getResult(resIndex);
+      MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+      // Non-memref results and unchanged memref results are redirected by the
+      // final replaceAllUsesWith below.
+      if (!oldMemRefType || oldMemRefType == newMemRef.getType())
+        continue;
+      // TODO: Multiple maps not supported; only the first layout map is used.
+      AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
+      if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                          /*newMemRef=*/newMemRef,
+                                          /*extraIndices=*/{},
+                                          /*indexRemap=*/layoutMap,
+                                          /*extraOperands=*/{},
+                                          /*symbolOperands=*/{},
+                                          /*domInstFilter=*/nullptr,
+                                          /*postDomInstFilter=*/nullptr,
+                                          /*allowNonDereferencingOps=*/true,
+                                          /*replaceInDeallocOp=*/true))) {
+        // Stop at the first failure: the replacement is abandoned and
+        // `newOp` may be erased below, so none of its results may be
+        // touched after this point.
+        replacingMemRefUsesFailed = true;
+        break;
+      }
+    }
+
+    if (replacingMemRefUsesFailed) {
+      // Keep `op` and drop the replacement op. If uses of an earlier result
+      // were already rewritten to `newOp`, erasing it would leave dangling
+      // uses, so in that case both ops are kept.
+      if (newOp->use_empty())
+        newOp->erase();
+      return;
+    }
+
+    // Redirect the remaining uses to `newOp` and delete the old op.
+    op->replaceAllUsesWith(newOp);
+    op->erase();
+  });
+
   // In a normal function, memrefs in the return type signature gets normalized
   // as a result of normalization of functions arguments, AllocOps or CallOps'
   // result types. Since an external function doesn't have a body, memrefs in
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -90,6 +90,27 @@
   return
 }
 
+// Test with op_norm_ret, a normalizable op with maps in its memref results.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)>
+
+// CHECK-LABEL: test_norm_ret
+// CHECK-SAME: , %[[ARG1:[a-z0-9]*]]: memref<1x16x1x1x32x32xf32>) -> memref<1x16x1x1x32x32xf32> {
+func @test_norm_ret(%arg0: memref<1x32768xf32>, %arg1: memref<1x16x14x14xf32, #map1>) -> memref<1x16x14x14xf32, #map1> {
+  %0 = alloc() : memref<1x16x14x14xf32, #map1>
+  // CHECK-NEXT: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x32xf32>
+  %1 = "test.op_norm_ret"(%arg0, %arg1) : (memref<1x32768xf32>, memref<1x16x14x14xf32, #map1>) -> (memref<1x16x14x14xf32, #map1>)
+  // CHECK-NEXT: %[[v1:[a-z0-9]*]] = "test.op_norm_ret"
+  // CHECK-SAME: , memref<1x16x1x1x32x32xf32>) -> memref<1x16x1x1x32x32xf32>
+  "test.op_norm"(%1, %0) : (memref<1x16x14x14xf32, #map1>, memref<1x16x14x14xf32, #map1>) -> ()
+  // CHECK-NEXT: "test.op_norm"
+  // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> ()
+  dealloc %0 : memref<1x16x14x14xf32, #map1>
+  // CHECK-NEXT: dealloc %[[v0]] : memref<1x16x1x1x32x32xf32>
+  return %1 : memref<1x16x14x14xf32, #map1>
+  // CHECK-NEXT: return %[[v1]] : memref<1x16x1x1x32x32xf32>
+}
+
 // Test with an arbitrary op that references the function symbol.
 "test.op_funcref"() {func = @test_norm_mix} : () -> ()
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -627,6 +627,11 @@
 def OpNonNorm : TEST_Op<"op_nonnorm"> {
   let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
 }
+// Test for memrefs normalization of an op with normalizable memref results.
+def OpNormRet : TEST_Op<"op_norm_ret", [MemRefsNormalizable]> {
+  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
+  let results = (outs AnyMemRef:$Z);
+}
 
 // Test for memrefs normalization of an op with a reference to a function
 // symbol.