diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -224,9 +224,9 @@ /// Returns a map of codomain to domain dimensions such that the first codomain /// dimension for a particular domain dimension is selected. -/// Returns an empty map if the input map is empty or if `map` is not invertible -/// (i.e. `map` does not contain a subset that is a permutation of full domain -/// rank). +/// Returns an empty map if the input map is empty. +/// Returns a null map (not an empty map) if `map` is not invertible (i.e. `map` +/// does not contain a subset that is a permutation of full domain rank). /// /// Prerequisites: /// 1. `map` has no symbols. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -416,6 +416,8 @@ AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); AffineMap invProducerResultIndexMap = inversePermutation(producerOp.getOutputIndexingMap(0)); + if (!invProducerResultIndexMap) + return {}; // Compute the fused op operandslist by replacing the operand corresponding to // the result of the producer, with the operands of the producer. 
@@ -559,6 +561,10 @@ if (!fusedOp) continue; rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults()); + // The producer is dead if the replaced consumer was its only user. + if (llvm::all_of(definingOp.getResults(), + [](Value val) -> bool { return val.use_empty(); })) + rewriter.eraseOp(definingOp); return success(); } return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -652,6 +652,8 @@ auto maps = llvm::to_vector<8>( llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); AffineMap invertedMap = inversePermutation(concatAffineMaps(maps)); + if (!invertedMap) + return {}; if (invertedMap.isEmpty()) { LinalgScopedEmitter::emitScalarImplementation( {}, linalgOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -291,11 +291,13 @@ auto linOp = cast(op); auto permutationMap = inversePermutation( AffineMap::getPermutationMap(permutation, rewriter.getContext())); + assert(permutationMap && "expected permutation to be invertible"); SmallVector newIndexingMap; auto indexingMaps = linOp.indexing_maps().getValue(); for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) { - AffineMap m = indexingMaps[i].cast().getValue().compose( - permutationMap); + AffineMap m = indexingMaps[i].cast().getValue(); + if (!permutationMap.isEmpty()) + m = m.compose(permutationMap); newIndexingMap.push_back(m); } auto itTypes = linOp.iterator_types().getValue(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -349,6 +349,9 @@ if (!permutation.empty()) invPermutationMap = inversePermutation( AffineMap::getPermutationMap(permutation, ScopedContext::getContext())); + // Bail out only if a requested interchange is not invertible. + if (!permutation.empty() && !invPermutationMap) + return llvm::None; OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); @@ -361,7 +364,8 @@ auto maps = llvm::to_vector<8>( llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps)); - assert(viewSizesToLoopsMap && "expected invertible map"); + if (!viewSizesToLoopsMap) + return llvm::None; SmallVector loopRanges; LoopIndexToRangeIndexMap loopIndexToRangeIndex; diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -83,7 +83,6 @@ // CHECK-DAG: [[MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: [[MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0)> -// CHECK-DAG: [[MAP2:#[a-zA-Z0-9_]*]] = affine_map<(d0) -> (d0)> #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> #map2 = affine_map<(d0) -> (d0)> @@ -119,7 +118,7 @@ %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 }: tensor, tensor -> tensor - // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64 + // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 // CHECK: addf // CHECK: mulf %1 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %0, %arg2 {