[mlir][memref] Respect symbol operands in alloc canonicalization.

SimplifyAllocConst used to scan every operand of the alloc (dynamic sizes
and symbol operands alike) for constants, and it rebuilt the alloc without
its symbol operands. It now inspects only dynamicSizes(), reads the folded
operand from dynamicSizes(), and forwards symbolOperands() to the new
alloc. The verifier error for a symbol-count mismatch now reports the
expected and actual counts. New canonicalize tests cover allocs with and
without symbol operands.

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -119,8 +119,9 @@
   if (!memRefType.getAffineMaps().empty())
     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
   if (op.symbolOperands().size() != numSymbols)
-    return op.emitOpError(
-        "symbol operand count does not equal memref symbol count");
+    return op.emitOpError("symbol operand count does not equal memref symbol "
+                          "count: expected ")
+           << numSymbols << ", got " << op.symbolOperands().size();
 
   return success();
 }
@@ -146,7 +147,7 @@
                                 PatternRewriter &rewriter) const override {
     // Check to see if any dimensions operands are constants.  If so, we can
     // substitute and drop them.
-    if (llvm::none_of(alloc.getOperands(), [](Value operand) {
+    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
           return matchPattern(operand, matchConstantIndex());
         }))
       return failure();
@@ -167,7 +168,7 @@
         newShapeConstants.push_back(dimSize);
         continue;
       }
-      auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
+      auto *defOp = alloc.dynamicSizes()[dynamicDimPos].getDefiningOp();
       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
         // Dynamic shape dimension will be folded.
         newShapeConstants.push_back(constantIndexOp.getValue());
@@ -187,7 +188,8 @@
 
     // Create and insert the alloc op for the new memref.
     auto newAlloc = rewriter.create<AllocLikeOp>(
-        alloc.getLoc(), newMemRefType, newOperands, alloc.alignmentAttr());
+        alloc.getLoc(), newMemRefType, newOperands, alloc.symbolOperands(),
+        alloc.alignmentAttr());
     // Insert a cast so we have the same type as the old alloc.
     auto resultCast =
         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -217,3 +217,41 @@
   return
 }
 
+// -----
+
+// CHECK-LABEL: func @simplify_alloc(
+// CHECK: %[[mem1:.+]] = memref.alloc() : memref<1xi32>
+// CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32> to memref<?xi32>
+// CHECK: return %[[mem2]] : memref<?xi32>
+func @simplify_alloc() -> memref<?xi32> {
+  %c1 = constant 1 : index
+  %0 = memref.alloc(%c1) : memref<?xi32>
+  return %0 : memref<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @simplify_alloc_ignore_symbols1(
+// CHECK: %[[c1:.+]] = constant 1 : index
+// CHECK: %[[mem1:.+]] = memref.alloc({{.*}})[%[[c1]], %[[c1]]] : memref<?xi32, #map>
+// CHECK: return %[[mem1]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @simplify_alloc_ignore_symbols1(%arg0 : index) -> memref<?xi32, #map0> {
+  %c1 = constant 1 : index
+  %0 = memref.alloc(%arg0)[%c1, %c1] : memref<?xi32, #map0>
+  return %0 : memref<?xi32, #map0>
+}
+
+// -----
+
+// CHECK-LABEL: func @simplify_alloc_ignore_symbols2(
+// CHECK: %[[c1:.+]] = constant 1 : index
+// CHECK: %[[mem1:.+]] = memref.alloc()[%[[c1]], %[[c1]]] : memref<1xi32, #map>
+// CHECK: %[[mem2:.+]] = memref.cast %[[mem1]] : memref<1xi32, #map> to memref<?xi32, #map>
+// CHECK: return %[[mem2]] : memref<?xi32, #map>
+#map0 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+func @simplify_alloc_ignore_symbols2() -> memref<?xi32, #map0> {
+  %c1 = constant 1 : index
+  %0 = memref.alloc(%c1)[%c1, %c1] : memref<?xi32, #map0>
+  return %0 : memref<?xi32, #map0>
+}