diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -240,12 +240,11 @@ /// escape (while leaving the IR in a valid state). LogicalResult normalizeMemRef(memref::AllocOp *op); -/// Uses the old memref type map layout and computes the new memref type to have -/// a new shape and a layout map, where the old layout map has been normalized -/// to an identity layout map. It returns the old memref in case no -/// normalization was needed or a failure occurs while transforming the old map -/// layout to an identity layout map. -MemRefType normalizeMemRefType(MemRefType memrefType, OpBuilder builder, +/// Normalizes `memrefType` so that the affine layout map of the memref is +/// transformed to an identity map with a new shape being computed for the +/// normalized memref type and returns it. The old memref type is simply +/// returned if no normalization was needed or if normalization failed. +MemRefType normalizeMemRefType(MemRefType memrefType, unsigned numSymbolicOperands); /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1691,7 +1691,7 @@ // Fetch a new memref type after normalizing the old memref to have an // identity map layout. MemRefType newMemRefType = - normalizeMemRefType(memrefType, b, allocOp->getSymbolOperands().size()); + normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size()); if (newMemRefType == memrefType) // Either memrefType already had an identity map or the map couldn't be // transformed to an identity map.
@@ -1742,7 +1742,7 @@ return success(); } -MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b, +MemRefType mlir::normalizeMemRefType(MemRefType memrefType, unsigned numSymbolicOperands) { unsigned rank = memrefType.getRank(); if (rank == 0) @@ -1790,10 +1790,11 @@ // Project out the old data dimensions. fac.projectOut(newRank, fac.getNumVars() - newRank - fac.getNumLocalVars()); SmallVector newShape(newRank); + MLIRContext *context = memrefType.getContext(); for (unsigned d = 0; d < newRank; ++d) { // Check if each dimension of normalized memrefType is dynamic. - bool isDynDim = isNormalizedMemRefDynamicDim( - d, layoutMap, memrefTypeDynDims, b.getContext()); + bool isDynDim = + isNormalizedMemRefDynamicDim(d, layoutMap, memrefTypeDynDims, context); if (isDynDim) { newShape[d] = -1; } else { @@ -1814,8 +1815,8 @@ MemRefType newMemRefType = MemRefType::Builder(memrefType) .setShape(newShape) - .setLayout(AffineMapAttr::get(b.getMultiDimIdentityMap(newRank))); - + .setLayout(AffineMapAttr::get( + AffineMap::getMultiDimIdentityMap(newRank, context))); return newMemRefType; } @@ -1844,12 +1845,12 @@ FailureOr> mlir::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, - ArrayRef dimSizes) { - unsigned numDims = dimSizes.size(); + ArrayRef basis) { + unsigned numDims = basis.size(); SmallVector divisors; for (unsigned i = 1; i < numDims; i++) { - ArrayRef slice = dimSizes.drop_front(i); + ArrayRef slice = basis.drop_front(i); FailureOr prod = getIndexProduct(b, loc, slice); if (failed(prod)) return failure(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -364,7 +364,7 @@ } // Fetch a new memref type after normalizing the old memref to have an // identity map layout.
- MemRefType newMemRefType = normalizeMemRefType(memrefType, b, + MemRefType newMemRefType = normalizeMemRefType(memrefType, /*numSymbolicOperands=*/0); if (newMemRefType == memrefType || funcOp.isExternal()) { // Either memrefType already had an identity map or the map couldn't be @@ -472,7 +472,7 @@ } // Computing a new memref type after normalizing the old memref to have an // identity map layout. - MemRefType newMemRefType = normalizeMemRefType(memrefType, b, + MemRefType newMemRefType = normalizeMemRefType(memrefType, /*numSymbolicOperands=*/0); resultTypes.push_back(newMemRefType); } @@ -511,7 +511,7 @@ continue; } // Fetch a new memref type after normalizing the old memref. - MemRefType newMemRefType = normalizeMemRefType(memrefType, b, + MemRefType newMemRefType = normalizeMemRefType(memrefType, /*numSymbolicOperands=*/0); if (newMemRefType == memrefType) { // Either memrefType already had an identity map or the map couldn't