diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -36,6 +36,7 @@
   void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
   void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
                                            DenseSet<FuncOp> &normalizableFuncs);
+  Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
 };
 
 } // end anonymous namespace
@@ -384,6 +385,67 @@
     funcOp.front().eraseArgument(argIndex + 1);
   }
 
+  // Walk over normalizable operations to normalize memrefs of the operation
+  // results. When `op` has memrefs with affine map in the operation results,
+  // a new operation containing normalized memrefs is created. Then, the
+  // memrefs are replaced. `CallOp` is skipped here because it is handled in
+  // `updateFunctionSignature()`.
+  funcOp.walk([&](Operation *op) {
+    if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
+        op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) {
+      // Create newOp containing normalized memref in the operation result.
+      Operation *newOp = createOpResultsNormalized(funcOp, op);
+      // When all of the operation results have no memrefs or memrefs without
+      // affine map, `newOp` is the same with `op` and following process is
+      // skipped.
+      if (op != newOp) {
+        bool replacingMemRefUsesFailed = false;
+        for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
+          // Replace all uses of the old memrefs.
+          Value oldMemRef = op->getResult(resIndex);
+          Value newMemRef = newOp->getResult(resIndex);
+          MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
+          // Check whether the operation result is MemRef type.
+          if (!oldMemRefType)
+            continue;
+          MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
+          if (oldMemRefType == newMemRefType)
+            continue;
+          // TODO: Assume single layout map. Multiple maps not supported.
+          AffineMap layoutMap = oldMemRefType.getAffineMaps().front();
+          if (failed(replaceAllMemRefUsesWith(oldMemRef,
+                                              /*newMemRef=*/newMemRef,
+                                              /*extraIndices=*/{},
+                                              /*indexRemap=*/layoutMap,
+                                              /*extraOperands=*/{},
+                                              /*symbolOperands=*/{},
+                                              /*domInstFilter=*/nullptr,
+                                              /*postDomInstFilter=*/nullptr,
+                                              /*allowDereferencingOps=*/true,
+                                              /*replaceInDeallocOp=*/true))) {
+            // `op` stays in the IR: give back the regions that
+            // createOpResultsNormalized() moved into `newOp`.
+            for (auto regions :
+                 llvm::zip(op->getRegions(), newOp->getRegions()))
+              std::get<0>(regions).takeBody(std::get<1>(regions));
+            // TODO: Uses of earlier results that were already redirected to
+            // `newOp` are not rolled back here.
+            newOp->erase();
+            replacingMemRefUsesFailed = true;
+            // `newOp` is gone; stop instead of reading its remaining results.
+            break;
+          }
+        }
+        if (!replacingMemRefUsesFailed) {
+          // Replace other ops with new op and delete the old op when the
+          // replacement succeeded.
+          op->replaceAllUsesWith(newOp);
+          op->erase();
+        }
+      }
+    }
+  });
+
   // In a normal function, memrefs in the return type signature gets normalized
   // as a result of normalization of functions arguments, AllocOps or CallOps'
   // result types. Since an external function doesn't have a body, memrefs in
@@ -417,3 +479,54 @@
   }
   updateFunctionSignature(funcOp, moduleOp);
 }
+
+/// Create an operation containing normalized memrefs in the operation results.
+/// When the results of `oldOp` have memrefs with affine map, the memrefs are
+/// normalized, and new operation containing them in the operation results is
+/// returned; the regions of `oldOp` are moved into the new operation. If all
+/// of the results of `oldOp` have no memrefs or memrefs without affine map,
+/// `oldOp` is returned without modification.
+Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
+                                                       Operation *oldOp) {
+  // Prepare OperationState to create newOp containing normalized memref in
+  // the operation results.
+  OperationState result(oldOp->getLoc(), oldOp->getName());
+  result.addOperands(oldOp->getOperands());
+  result.addAttributes(oldOp->getAttrs());
+  // Add normalized MemRefType to the OperationState.
+  SmallVector<Type, 4> resultTypes;
+  OpBuilder b(funcOp);
+  bool resultTypeNormalized = false;
+  for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
+    auto resultType = oldOp->getResult(resIndex).getType();
+    MemRefType memrefType = resultType.dyn_cast<MemRefType>();
+    // Check whether the operation result is MemRef type.
+    if (!memrefType) {
+      resultTypes.push_back(resultType);
+      continue;
+    }
+    // Fetch a new memref type after normalizing the old memref.
+    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
+                                                   /*numSymbolicOperands=*/0);
+    if (newMemRefType == memrefType) {
+      // Either memrefType already had an identity map or the map couldn't
+      // be transformed to an identity map.
+      resultTypes.push_back(memrefType);
+      continue;
+    }
+    resultTypes.push_back(newMemRefType);
+    resultTypeNormalized = true;
+  }
+  // When all of the results of `oldOp` have no memrefs or memrefs without
+  // affine map, `oldOp` is returned without modification.
+  if (!resultTypeNormalized)
+    return oldOp;
+  result.addTypes(resultTypes);
+  // Move the regions of `oldOp` into the new operation. Otherwise the new
+  // operation would be created without them and their contents would be
+  // lost once `oldOp` is erased.
+  for (Region &oldRegion : oldOp->getRegions())
+    result.addRegion()->takeBody(oldRegion);
+  OpBuilder bb(oldOp);
+  return bb.createOperation(result);
+}
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
--- a/mlir/test/Transforms/normalize-memrefs-ops.mlir
+++ b/mlir/test/Transforms/normalize-memrefs-ops.mlir
@@ -14,13 +14,13 @@
 
 // Test with op_norm and maps in arguments and in the operations in the function.
// CHECK-LABEL: test_norm -// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>) +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32>) func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { %0 = alloc() : memref<1x16x14x14xf32, #map0> "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () dealloc %0 : memref<1x16x14x14xf32, #map0> - // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32> + // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x64xf32> // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> () // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32> return @@ -29,13 +29,13 @@ // Same test with op_nonnorm, with maps in the arguments and the operations in the function. // CHECK-LABEL: test_nonnorm -// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map>) +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32, #map>) func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () { %0 = alloc() : memref<1x16x14x14xf32, #map0> "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> () dealloc %0 : memref<1x16x14x14xf32, #map0> - // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map> + // CHECK: %[[v0:.*]] = alloc() : memref<1x16x14x14xf32, #map> // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map>, memref<1x16x14x14xf32, #map>) -> () // CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map> return @@ -44,13 +44,13 @@ // Test with op_norm, with maps in the operations in the function.
// CHECK-LABEL: test_norm_mix -// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32> +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x64xf32> func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () { %0 = alloc() : memref<1x16x14x14xf32, #map0> "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> () dealloc %0 : memref<1x16x14x14xf32, #map0> - // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32> + // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x64xf32> // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> () // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32> return @@ -61,12 +61,12 @@ #map_tile = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 32, d2 mod 32, d3 mod 32)> // CHECK-LABEL: test_load_store -// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32> +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x14x14xf32> func @test_load_store(%arg0 : memref<1x16x14x14xf32>) -> () { %0 = alloc() : memref<1x16x14x14xf32, #map_tile> - // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x32xf32> + // CHECK: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x32xf32> %1 = alloc() : memref<1x16x14x14xf32> - // CHECK: %[[v1:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32> + // CHECK: %[[v1:.*]] = alloc() : memref<1x16x14x14xf32> "test.op_norm"(%0, %1) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) -> () // CHECK: "test.op_norm"(%[[v0]], %[[v1]]) : (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) -> () %cst = constant 3.0 : f32 @@ -90,6 +90,25 @@ return } +// Test with op_norm_ret, a normalizable op whose results mix a memref with a layout map and one without: only the result carrying the layout map is normalized.
+ // CHECK-LABEL: test_norm_ret +// CHECK-SAME: (%[[ARG0:.*]]: memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) { +func @test_norm_ret(%arg0: memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) { + %0 = alloc() : memref<1x16x14x14xf32, #map_tile> + // CHECK-NEXT: %[[v0:.*]] = alloc() : memref<1x16x1x1x32x32xf32> + %1, %2 = "test.op_norm_ret"(%arg0) : (memref<1x16x14x14xf32, #map_tile>) -> (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32>) + // CHECK-NEXT: %[[v1:.*]], %[[v2:.*]] = "test.op_norm_ret" + // CHECK-SAME: (memref<1x16x1x1x32x32xf32>) -> (memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32>) + "test.op_norm"(%1, %0) : (memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32, #map_tile>) -> () + // CHECK-NEXT: "test.op_norm" + // CHECK-SAME: : (memref<1x16x1x1x32x32xf32>, memref<1x16x1x1x32x32xf32>) -> () + dealloc %0 : memref<1x16x14x14xf32, #map_tile> + // CHECK-NEXT: dealloc %[[v0]] : memref<1x16x1x1x32x32xf32> + return %1, %2 : memref<1x16x14x14xf32, #map_tile>, memref<1x16x14x14xf32> + // CHECK-NEXT: return %[[v1]], %[[v2]] : memref<1x16x1x1x32x32xf32>, memref<1x16x14x14xf32> +} // Test with an arbitrary op that references the function symbol. "test.op_funcref"() {func = @test_norm_mix} : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -659,6 +659,11 @@ def OpNonNorm : TEST_Op<"op_nonnorm"> { let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y); } +// Test for memrefs normalization of an op that has normalizable memref results; the MemRefsNormalizable trait lets the pass rewrite its memref result types. +def OpNormRet : TEST_Op<"op_norm_ret", [MemRefsNormalizable]> { + let arguments = (ins AnyMemRef:$X); + let results = (outs AnyMemRef:$Y, AnyMemRef:$Z); +} // Test for memrefs normalization of an op with a reference to a function // symbol.