[mlir] Pass an explicit MLIRContext to concatAffineMaps

concatAffineMaps built its result with `maps.front().getContext()` and
therefore crashed when `maps` was empty. Take the context as a parameter
instead.

LinalgOp callers pass the op's own context (`$_op.getContext()` /
`linalgOp.getContext()`), so the LinalgOp interface methods
(getLoopsToShapesMap, getShapesToLoopsMap, canOpOperandsBeDropped,
getStaticLoopRanges, computeStaticLoopSizes) keep their existing signatures
and no out-of-tree caller has to change.

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -610,11 +610,13 @@
       }],
       /*retTy=*/"AffineMap",
       /*methodName=*/"getLoopsToShapesMap",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         auto maps = $_op.getIndexingMapsArray();
-        return concatAffineMaps(maps);
+        // Use the op's own context: `maps` may be empty, in which case there
+        // is no map to take the context from.
+        return concatAffineMaps(maps, $_op.getContext());
       }]
     >,
     InterfaceMethod<
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -535,7 +535,10 @@
 /// ```mlir
 /// (i, j, k) -> (i, k, k, j, i, j)
 /// ```
-AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
+///
+/// `context` is used to build the result, so that an empty `maps` list yields
+/// the empty map `() -> ()` instead of dereferencing a non-existent map.
+AffineMap concatAffineMaps(ArrayRef<AffineMap> maps, MLIRContext *context);
 
 /// Returns the map that results from projecting out the dimensions specified in
 /// `projectedDimensions`. The projected dimensions are set to 0.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -41,7 +41,9 @@
     // if the op has no loops.
     return linalgOp.getNumLoops() == 0;
   }
-  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
+  return inversePermutation(
+             concatAffineMaps(indexingMaps, linalgOp.getContext())) !=
+         AffineMap();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -141,7 +141,7 @@
 
   // Check that the new index maps are invertible. If not, something went
   // wrong, so abort.
-  if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
+  if (!inversePermutation(concatAffineMaps(newIndexingMaps, context)))
     return nullptr;
   return ArrayAttr::get(context,
                         llvm::to_vector<4>(llvm::map_range(
@@ -184,7 +184,8 @@
     // Check if any of the iteration dimensions are unit-trip count. They will
     // end up being unit-trip count if they are used to index into a unit-dim
     // tensor/memref.
-    AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
+    AffineMap invertedMap = inversePermutation(
+        concatAffineMaps(indexingMaps, rewriter.getContext()));
     if (!invertedMap)
       return failure();
     SmallVector<int64_t> dims = genericOp.getStaticShape();
@@ -404,8 +405,9 @@
 
     // If the indexing maps of the result operation are not invertible (i.e. not
     // legal), abort.
-    if (!doCanonicalization ||
-        !inversePermutation(concatAffineMaps(newIndexingMaps)))
+    if (!doCanonicalization ||
+        !inversePermutation(
+            concatAffineMaps(newIndexingMaps, rewriter.getContext())))
       return failure();
 
     // If any operand type change, insert a reshape to convert from the original
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1655,7 +1655,8 @@
         genericOp.getMatchingIndexingMap(outputOperand));
 
     // Check if the operation shapes to loops map is computable.
-    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+    if (!inversePermutation(
+            concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
       return rewriter.notifyMatchFailure(
           genericOp, "fused op loop bound computation failed");
     }
diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp
--- a/mlir/lib/IR/AffineMap.cpp
+++ b/mlir/lib/IR/AffineMap.cpp
@@ -707,7 +707,11 @@
   return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
 }
 
-AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
+AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps,
+                                 MLIRContext *context) {
+  // Concatenating zero maps yields the empty map `() -> ()`.
+  if (maps.empty())
+    return AffineMap::get(context);
   unsigned numResults = 0, numDims = 0, numSymbols = 0;
   for (auto m : maps)
     numResults += m.getNumResults();
@@ -720,8 +724,7 @@
     numSymbols += m.getNumSymbols();
     numDims = std::max(m.getNumDims(), numDims);
   }
-  return AffineMap::get(numDims, numSymbols, results,
-                        maps.front().getContext());
+  return AffineMap::get(numDims, numSymbols, results, context);
 }
 
 AffineMap mlir::getProjectedMap(AffineMap map,