diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -224,9 +224,9 @@ /// Returns a map of codomain to domain dimensions such that the first codomain /// dimension for a particular domain dimension is selected. -/// Returns an empty map if the input map is empty or if `map` is not invertible -/// (i.e. `map` does not contain a subset that is a permutation of full domain -/// rank). +/// Returns an empty map if the input map is empty. +/// Returns a null map (not an empty map) if `map` is not invertible (i.e. +/// `map` does not contain a subset that is a permutation of full domain rank). /// /// Prerequisites: /// 1. `map` has no symbols. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -416,6 +416,8 @@ AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); AffineMap invProducerResultIndexMap = inversePermutation(producerOp.getOutputIndexingMap(0)); + if (!invProducerResultIndexMap) // Null map: producer output map is not invertible, so do not fuse. + return {}; // Compute the fused op operandslist by replacing the operand corresponding to // the result of the producer, with the operands of the producer. 
@@ -559,6 +561,9 @@ if (!fusedOp) continue; rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults()); + if (llvm::all_of(definingOp.getResults(), + [](Value val) -> bool { return val.use_empty(); })) + rewriter.eraseOp(definingOp); // Producer is dead once none of its results has a remaining use. return success(); } return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -652,8 +652,10 @@ linalgOp.indexing_maps().template getAsRange(); auto maps = functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange); - auto invertedMap = inversePermutation(concatAffineMaps(maps)); - if (!invertedMap) { + AffineMap invertedMap = inversePermutation(concatAffineMaps(maps)); + if (!invertedMap) // Null map: the indexing maps are not invertible, bail out. + return {}; + if (invertedMap.isEmpty()) { // No loop dimensions: emit only the scalar body. LinalgScopedEmitter::emitScalarImplementation( {}, linalgOp); return LinalgLoops(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -291,11 +291,13 @@ auto linOp = cast(op); auto permutationMap = inversePermutation( AffineMap::getPermutationMap(permutation, rewriter.getContext())); + assert(permutationMap && "expected permutation to be invertible"); SmallVector newIndexingMap; auto indexingMaps = linOp.indexing_maps().getValue(); for (unsigned i = 0, e = linOp.getNumInputsAndOutputs(); i != e; ++i) { - AffineMap m = indexingMaps[i].cast().getValue().compose( - permutationMap); + AffineMap m = indexingMaps[i].cast().getValue(); + if (!permutationMap.isEmpty()) // Empty permutation map: keep the indexing map unchanged. + m = m.compose(permutationMap); newIndexingMap.push_back(m); } auto itTypes = linOp.iterator_types().getValue(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- 
a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -350,6 +350,8 @@ if (!permutation.empty()) invPermutationMap = inversePermutation( AffineMap::getPermutationMap(permutation, ScopedContext::getContext())); + if (!invPermutationMap) // Null map: interchange permutation is not invertible, bail out. + return llvm::None; OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); @@ -362,7 +364,8 @@ auto maps = functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange); auto viewSizesToLoopsMap = inversePermutation(concatAffineMaps(maps)); - assert(viewSizesToLoopsMap && "expected invertible map"); + if (!viewSizesToLoopsMap) // Non-invertible indexing maps: bail out instead of asserting. + return llvm::None; SmallVector loopRanges; LoopIndexToRangeIndexMap loopIndexToRangeIndex; diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir --- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir +++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir @@ -83,7 +83,6 @@ // CHECK-DAG: [[MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: [[MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0)> -// CHECK-DAG: [[MAP2:#[a-zA-Z0-9_]*]] = affine_map<(d0) -> (d0)> #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> #map2 = affine_map<(d0) -> (d0)> @@ -119,7 +118,7 @@ %1 = addf %arg3, %arg4 : f32 linalg.yield %1 : f32 }: tensor, tensor -> tensor - // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64 + // CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64 // CHECK: addf // CHECK: mulf %1 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = []} %0, %arg2 { diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -913,3 +913,46 @@ // CHECKPARALLEL: %[[CONST:.*]] = constant 1.000000e+00 : f32 // CHECKPARALLEL: loop.parallel (%[[i:.*]]) // CHECKPARALLEL: store %[[CONST]], %[[ARG0]] + 
+#scalar_access = [
+  affine_map<() -> ()>,
+  affine_map<() -> ()>,
+  affine_map<() -> ()>
+]
+#scalar_trait = { // 0-d op: no iterators, so lowering to loops must emit no loop at all.
+  args_in = 2,
+  args_out = 1,
+  iterator_types = [],
+  indexing_maps = #scalar_access,
+  library_call = "some_external_fn"
+}
+func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
+{
+  linalg.generic #scalar_trait %arg0, %arg1, %arg2 {
+  ^bb(%a : f32, %b : f32, %c : f32) :
+    %0 = addf %a, %b : f32
+    linalg.yield %0 : f32
+  } : memref<f32>, memref<f32>, memref<f32>
+  return
+}
+// CHECKLOOP-LABEL: @scalar_code
+//  CHECKLOOP-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKLOOP-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKLOOP-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+//   CHECKLOOP-NOT: loop.for
+//   CHECKLOOP-DAG: load %[[ARG0]][]
+//   CHECKLOOP-DAG: load %[[ARG1]][]
+//   CHECKLOOP-DAG: load %[[ARG2]][]
+//       CHECKLOOP: addf
+//       CHECKLOOP: store %{{.*}}, %[[ARG2]][]
+
+// CHECKPARALLEL-LABEL: @scalar_code
+//  CHECKPARALLEL-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKPARALLEL-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<f32>
+//  CHECKPARALLEL-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<f32>
+//   CHECKPARALLEL-NOT: loop.{{(for|parallel)}}
+//   CHECKPARALLEL-DAG: load %[[ARG0]][]
+//   CHECKPARALLEL-DAG: load %[[ARG1]][]
+//   CHECKPARALLEL-DAG: load %[[ARG2]][]
+//       CHECKPARALLEL: addf
+//       CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]