diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -162,6 +162,9 @@
 
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
+/// If the second part is statically known to be empty, do not create it
+/// and return nullptr instead. Error state is signalled by returning
+/// a pair of nullptrs.
 ///
 /// For example, the following op:
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1043,6 +1043,8 @@
 
   // Split each target operation.
   SmallVector<Operation *> first, second;
+  // First target, if any, for which splitting produced no second part.
+  Operation *noSecondPart = nullptr;
   for (const auto &pair : llvm::zip(payload, splitPoints)) {
     Operation *target = std::get<0>(pair);
     auto linalgOp = dyn_cast<LinalgOp>(target);
@@ -1067,6 +1069,34 @@
     std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
         rewriter, cast<TilingInterface>(linalgOp.getOperation()),
         getDimension(), std::get<1>(pair));
+
+    // Propagate errors.
+    if (!first.back() && !second.back()) {
+      auto diag = emitDefiniteFailure() << "internal failure in splitting";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+
+    // Do not add null second parts. Only remember the first target without a
+    // second part: the diagnostic below reports it as such.
+    if (!second.back()) {
+      if (!noSecondPart)
+        noSecondPart = target;
+      second.pop_back();
+    }
+  }
+
+  if (second.size() != first.size() && !second.empty()) {
+    results.set(getFirst().cast<OpResult>(), {});
+    results.set(getSecond().cast<OpResult>(), {});
+    auto diag =
+        emitSilenceableError()
+        << "splitting does not produce the second part for a subset of targets";
+    diag.attachNote() << "expected splitting to produce the second part of all "
+                         "or none of the targets";
+    diag.attachNote(noSecondPart->getLoc())
+        << "first target with no second part";
+    return diag;
   }
 
   results.set(getFirst().cast<OpResult>(), first);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp
@@ -128,6 +128,10 @@
       createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults,
                       dimension, remainingSize, totalOffset, secondResults);
 
+  // Propagate any errors in part creation.
+  if (!firstPart || !secondPart)
+    return {TilingInterface(), TilingInterface()};
+
   // Replace the original op with the results of the two newly created ops.
   rewriter.replaceOp(op, secondResults);
   return {firstPart, secondPart};
diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-split.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir
@@ -46,6 +46,16 @@
   return %0 : tensor<100xf32>
 }
 
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1:2 = transform.structured.split %0 after 42 { dimension = 0 }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
 // CHECK-LABEL: @one_d_static_overflow
 // CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
 func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
@@ -268,3 +278,45 @@
   return %0 : tensor<100xf32>
 }
 
+// -----
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  // expected-error @below {{splitting does not produce the second part for a subset of targets}}
+  // expected-note @below {{expected splitting to produce the second part of all or none of the targets}}
+  %1:2 = transform.structured.split %0 after 142 { dimension = 0 }
+}
+
+func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
+
+func.func @split_one_but_not_other(
+    %arg0: tensor<100xf32>, %arg1: tensor<100xf32>,
+    %arg2: tensor<200xf32>, %arg3: tensor<200xf32>)
+    -> (tensor<100xf32>, tensor<200xf32>) {
+  // expected-note @below {{first target with no second part}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<100xf32>
+
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
+    iterator_types = ["parallel"]
+  }
+  ins(%arg2: tensor<200xf32>) outs(%arg3: tensor<200xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):
+    %i = linalg.index 0 : index
+    %call_res = func.call @elem(%arg4, %i, %i) : (f32, index, index) -> f32
+    linalg.yield %call_res : f32
+  } -> tensor<200xf32>
+
+  return %0, %1 : tensor<100xf32>, tensor<200xf32>
+}
+