diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -119,8 +119,9 @@
   if (!memRefType.getAffineMaps().empty())
     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
   if (op.symbolOperands().size() != numSymbols)
-    return op.emitOpError(
-        "symbol operand count does not equal memref symbol count");
+    return op.emitOpError("symbol operand count does not equal memref symbol "
+                          "count: expected ")
+           << numSymbols << ", got " << op.symbolOperands().size();
 
   return success();
 }
@@ -146,7 +147,7 @@
                                 PatternRewriter &rewriter) const override {
     // Check to see if any dimensions operands are constants. If so, we can
     // substitute and drop them.
-    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
+    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
         }))
       return failure();
@@ -157,7 +158,7 @@
     // and keep track of the resultant memref type to build.
     SmallVector<int64_t, 4> newShapeConstants;
     newShapeConstants.reserve(memrefType.getRank());
-    SmallVector<Value, 4> newOperands;
+    SmallVector<Value, 4> dynamicSizes;
 
     unsigned dynamicDimPos = 0;
     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
@@ -167,14 +168,15 @@
         newShapeConstants.push_back(dimSize);
         continue;
       }
-      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
+      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
+      auto *defOp = dynamicSize.getDefiningOp();
       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
         // Dynamic shape dimension will be folded.
         newShapeConstants.push_back(constantIndexOp.getValue());
       } else {
-        // Dynamic shape dimension not folded; copy operand from old memref.
+        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
         newShapeConstants.push_back(-1);
-        newOperands.push_back(alloc.getOperand(dynamicDimPos));
+        dynamicSizes.push_back(dynamicSize);
       }
       dynamicDimPos++;
     }
@@ -182,12 +184,13 @@
     // Create new memref type (which will have fewer dynamic dimensions).
     MemRefType newMemRefType =
         MemRefType::Builder(memrefType).setShape(newShapeConstants);
-    assert(static_cast<int64_t>(newOperands.size()) ==
+    assert(static_cast<int64_t>(dynamicSizes.size()) ==
            newMemRefType.getNumDynamicDims());
 
     // Create and insert the alloc op for the new memref.
     auto newAlloc = rewriter.create<AllocLikeOp>(
-        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
+        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
+        alloc.alignmentAttr());
     // Insert a cast so we have the same type as the old alloc.
     auto resultCast =
         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -391,3 +391,56 @@
 // CHECK-SAME:   tensor<2x?xi32> into tensor<?x?xi32>
 // CHECK:        %[[CAST:.+]] = tensor.cast %[[UPDATED]]
 // CHECK:        return %[[CAST]]
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold
+func @alloc_const_fold() -> memref<?xf32> {
+  // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
+  %c4 = constant 4 : index
+  %a = memref.alloc(%c4) : memref<?xf32>
+
+  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
+  // CHECK-NEXT: return %1 : memref<?xf32>
+  return %a : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_alignment_const_fold
+func @alloc_alignment_const_fold() -> memref<?xf32> {
+  // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
+  %c4 = constant 4 : index
+  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>
+
+  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
+  // CHECK-NEXT: return %1 : memref<?xf32>
+  return %a : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold_with_symbols1(
+// CHECK: %[[c1:.+]] = constant 1 : index
+// CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, #map>
+// CHECK: return %[[mem1]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @alloc_const_fold_with_symbols1(%arg0 : index) -> memref<?xi32, #map0> {
+  %c1 = constant 1 : index
+  %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, #map0>
+  return %0 : memref<?xi32, #map0>
+}
+
+// -----
+
+// CHECK-LABEL: func @alloc_const_fold_with_symbols2(
+// CHECK: %[[c1:.+]] = constant 1 : index
+// CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, #map>
+// CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, #map> to memref<?xi32, #map>
+// CHECK: return %[[mem2]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
+  %c1 = constant 1 : index
+  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>
+  return %0 : memref<?xi32, #map0>
+}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -357,28 +357,6 @@
   return
 }
 
-// CHECK-LABEL: func @alloc_const_fold
-func @alloc_const_fold() -> memref<?xf32> {
-  // CHECK-NEXT: %0 = memref.alloc() : memref<4xf32>
-  %c4 = constant 4 : index
-  %a = memref.alloc(%c4) : memref<?xf32>
-
-  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
-  // CHECK-NEXT: return %1 : memref<?xf32>
-  return %a : memref<?xf32>
-}
-
-// CHECK-LABEL: func @alloc_alignment_const_fold
-func @alloc_alignment_const_fold() -> memref<?xf32> {
-  // CHECK-NEXT: %0 = memref.alloc() {alignment = 4096 : i64} : memref<4xf32>
-  %c4 = constant 4 : index
-  %a = memref.alloc(%c4) {alignment = 4096 : i64} : memref<?xf32>
-
-  // CHECK-NEXT: %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
-  // CHECK-NEXT: return %1 : memref<?xf32>
-  return %a : memref<?xf32>
-}
-
 // CHECK-LABEL: func @dead_alloc_fold
 func @dead_alloc_fold() {
   // CHECK-NEXT:  return
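Illustration (not part of the patch): a minimal before/after sketch of the rewrite
that the new @alloc_const_fold_with_symbols2 test locks in. With the fix, a constant
dynamic size is folded into the memref type while the layout-map symbol operands are
carried over to the new alloc unchanged:

  #map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

  // Before canonicalization: one dynamic size (%c1) and two symbol operands.
  %c1 = constant 1 : index
  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>

  // After canonicalization: the size is folded into the type, the symbol
  // operands [%c1, %c1] are preserved, and a cast restores the original type.
  %1 = memref.alloc()[%c1, %c1] : memref<1xi32, #map0>
  %2 = memref.cast %1 : memref<1xi32, #map0> to memref<?xi32, #map0>

Before this change, the pattern indexed alloc.getOperands() (dynamic sizes and
symbol operands interleaved in one list) with dynamicDimPos and rebuilt the alloc
without its symbol operands; iterating over dynamicSizes() and re-passing
symbolOperands() keeps the two operand groups separate.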