diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -108,6 +108,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
@@ -117,6 +118,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseSet.h"
@@ -1491,9 +1493,7 @@
            << "cannot bufferize bodiless function that returns a tensor";
   } else {
     ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    if (!returnOp)
-      return funcOp->emitError() << "cannot bufferize a FuncOp with tensors "
-                                    "and without a unique ReturnOp";
+    assert(returnOp && "expected func with single return op");
 
     // For each FuncOp result, keep track of which inplace argument it reuses.
     for (OpOperand &returnOperand : returnOp->getOpOperands()) {
@@ -2474,9 +2474,7 @@
 
   // Support only single return-terminated block in the function.
   ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  if (!returnOp)
-    return funcOp->emitError() << "cannot bufferize a FuncOp with tensors and "
-                                  "without a unique ReturnOp";
+  assert(returnOp && "expected func with single return op");
 
   // 1. For each FuncOp result, keep track of which inplace argument it reuses.
   SmallVector<Value> returnValues;
@@ -2574,7 +2572,15 @@
   DenseMap<FuncOp, DenseSet<Operation *>> calledBy;
   // For each FuncOp, the number of CallOpInterface it contains.
   DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-  WalkResult res = moduleOp.walk([&](FuncOp funcOp) {
+  WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult {
+    if (!funcOp.body().empty()) {
+      ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+      if (!returnOp)
+        return funcOp->emitError()
+               << "cannot bufferize a FuncOp with tensors and "
+                  "without a unique ReturnOp";
+    }
+
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
       // Only support CallOp for now.
@@ -2622,8 +2628,17 @@
 };
 } // end namespace
 
+static void applyEnablingTransformations(ModuleOp moduleOp) {
+  RewritePatternSet patterns(moduleOp.getContext());
+  patterns.add<GeneralizePadTensorOpPattern>(moduleOp.getContext());
+  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
+}
+
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
   ModuleOp moduleOp = getOperation();
+  applyEnablingTransformations(moduleOp);
+
+  moduleOp.dump();
 
   SmallVector<FuncOp> orderedFuncOps;
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -485,9 +485,11 @@
   //      %r0 must be out of place because one use of %t in the subsequent production
   //      of %r1 is read.
   // CHECK:      scf.for
+  // CHECK-NEXT: call
   // CHECK-NEXT: scf.yield
   // CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
   %r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
+    call @some_use(%t) : (tensor<?xf32>) -> ()
     scf.yield %t : tensor<?xf32>
   }
 
@@ -504,11 +506,13 @@
   //      %r2 must be out of place because one use of %t in the subsequent production
   //      of %r3 is read.
   // CHECK:      linalg.tiled_loop
+  // CHECK-NEXT: call
   // CHECK-NEXT: linalg.yield
   // CHECK-NEXT: {__inplace_results_attr__ = ["false"]}
   %r2 = linalg.tiled_loop (%i) = (%lb) to (%ub) step (%step)
         ins()
         outs(%t = %B: tensor<?xf32>) {
+    call @some_use(%t) : (tensor<?xf32>) -> ()
     linalg.yield %t : tensor<?xf32>
   }
 
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -17,18 +17,20 @@
 // -----
 
 // expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
-func @switch(%flag : i32, %caseOperand : i32, %t1 : tensor<f32>, %t2 : tensor<f32>)
-    -> (tensor<f32>)
+func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
+    -> (tensor<f32>, tensor<f32>)
 {
-  switch %flag : i32, [
-    default: ^bb1(%caseOperand : i32),
-    42: ^bb2(%caseOperand : i32)
-  ]
-
-  ^bb1(%bb1arg : i32):
-    return %t1 : tensor<f32>
-  ^bb2(%bb2arg : i32):
-    return %t2 : tensor<f32>
+  cond_br %cond1, ^bb1, ^bb2
+
+  ^bb1:
+    %T:2 = scf.if %cond2 -> (tensor<f32>, tensor<f32>) {
+      scf.yield %t1, %t2 : tensor<f32>, tensor<f32>
+    } else {
+      scf.yield %t2, %t1 : tensor<f32>, tensor<f32>
+    }
+    return %T#0, %T#1 : tensor<f32>, tensor<f32>
+  ^bb2:
+    return %t2, %t1 : tensor<f32>, tensor<f32>
 }
 
 // -----
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
@@ -6,15 +6,73 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
 // RUN: FileCheck %s
 
-func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
-  %v0 = constant 0.0 : f32
+#map0 = affine_map<(d0, d1)[s0] -> ((d1 - d0) ceildiv s0)>
+#map1 = affine_map<(d0, d1)[s0] -> ((d0 - d1) ceildiv s0)>
+
+func @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: tensor<f32> {linalg.inplaceable = true}) -> tensor<f32> {
+  %c64 = constant 64 : index
+  %cst = constant 0.000000e+00 : f32
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %0 = linalg.fill(%cst, %arg2) : f32, tensor<f32> -> tensor<f32>
+  %1 = affine.apply #map0(%c0, %c64)[%c2]
+  %2 = linalg.init_tensor [%1, 2] : tensor<?x2xf32>
+  %3 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %2) -> (tensor<?x2xf32>) {
+    %8 = affine.apply #map1(%arg3, %c0)[%c2]
+    %9 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+    %10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
+    %11 = linalg.pad_tensor %10 low[%c0] high[%c0] {
+    ^bb0(%arg5: index):  // no predecessors
+      linalg.yield %cst : f32
+    } : tensor<?xf32> to tensor<2xf32>
+    %12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>
+    scf.yield %12 : tensor<?x2xf32>
+  }
+
+  // %B = tensor.cast %3 : tensor<?x2xf32> to tensor<*xf32>
+  // call @print_memref_f32(%B) : (tensor<*xf32>) -> ()
+
+  %4 = affine.apply #map0(%c0, %c64)[%c2]
+  %5 = linalg.init_tensor [%4, 2] : tensor<?x2xf32>
+  %6 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %5) -> (tensor<?x2xf32>) {
+    %8 = affine.apply #map1(%arg3, %c0)[%c2]
+    %9 = tensor.extract_slice %arg0[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+    %10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
+    %11 = linalg.pad_tensor %10 low[%c0] high[%c0] {
+    ^bb0(%arg5: index):  // no predecessors
+      linalg.yield %cst : f32
+    } : tensor<?xf32> to tensor<2xf32>
+    %12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>
+    scf.yield %12 : tensor<?x2xf32>
+  }
+
+  // %A = tensor.cast %6 : tensor<?x2xf32> to tensor<*xf32>
+  // call @print_memref_f32(%A) : (tensor<*xf32>) -> ()
+
+  // %C = tensor.cast %0 : tensor<f32> to tensor<*xf32>
+  // call @print_memref_f32(%C) : (tensor<*xf32>) -> ()
 
-  %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+  %7 = scf.for %arg3 = %c0 to %c64 step %c2 iter_args(%arg4 = %0) -> (tensor<f32>) {
+    %8 = tensor.extract_slice %arg0[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+    %9 = tensor.cast %8 : tensor<2xf32> to tensor<?xf32>
+    %10 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
+    %11 = tensor.cast %10 : tensor<2xf32> to tensor<?xf32>
+    %12 = affine.apply #map1(%arg3, %c0)[%c2]
+    %13 = tensor.extract_slice %6[%12, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+    %14 = affine.apply #map1(%arg3, %c0)[%c2]
+    %15 = tensor.extract_slice %3[%14, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+    %16 = linalg.dot ins(%13, %15 : tensor<2xf32>, tensor<2xf32>) outs(%arg4 : tensor<f32>) -> tensor<f32>
 
-  %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
-                  outs(%d: tensor<f32>) -> tensor<f32>
+    // %AA = tensor.cast %13 : tensor<2xf32> to tensor<*xf32>
+    // call @print_memref_f32(%AA) : (tensor<*xf32>) -> ()
+    // %BB = tensor.cast %15 : tensor<2xf32> to tensor<*xf32>
+    // call @print_memref_f32(%BB) : (tensor<*xf32>) -> ()
+    // %CC = tensor.cast %16 : tensor<f32> to tensor<*xf32>
+    // call @print_memref_f32(%CC) : (tensor<*xf32>) -> ()
 
-  return %e : tensor<f32>
+    scf.yield %16 : tensor<f32>
+  }
+  return %7 : tensor<f32>
 }
 
 func @main() {