diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
@@ -61,13 +61,23 @@
                                          "expected single use of linalg op");
     }
 
-    if (!linalgOp.hasTensorSemantics())
+    if (!linalgOp.hasTensorSemantics()) {
       return rewriter.notifyMatchFailure(sliceOp,
                                          "expected tensor of linalg op");
+    }
 
     if (!sliceOp.hasUnitStride())
       return rewriter.notifyMatchFailure(sliceOp, "expected unit stride");
 
+    // Check that the indexing maps of all operands are projected permutations.
+    for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
+      AffineMap indexingMap = linalgOp.getTiedIndexingMap(operand);
+      if (!indexingMap.isProjectedPermutation()) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "expected a projected permutation for all operands");
+      }
+    }
+
     auto resultNumber = source.cast<OpResult>().getResultNumber();
     AffineMap outputMap =
         linalgOp.getTiedIndexingMap(linalgOp.getOutputOperand(resultNumber));
@@ -83,7 +93,7 @@
 
     AffineMap inverseMap = inversePermutation(outputMap);
 
-    // bubble up extract slice for each operand.
+    // Bubble up extract slice for each operand.
     auto sliceLoc = sliceOp.getLoc();
     auto sliceOffsets = getValueOrCreateConstantIndexOp(
         rewriter, sliceLoc, sliceOp.getMixedOffsets());