diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -583,6 +583,7 @@
     Value value, const BufferizationOptions &options,
     const DenseMap<Value, BaseMemRefType> &fixedTypes) {
   assert(value.getType().isa<TensorType>() && "expected tensor type");
+  auto tensorType = value.getType().cast<TensorType>();
 
   // No further analysis is possible for a block argument.
   if (value.isa<BlockArgument>())
@@ -593,16 +594,34 @@
   auto opResult = value.cast<OpResult>();
   auto bufferizableOp = cast<BufferizableOpInterface>(op);
   AnalysisState state(options);
+
+  // Case 1: If the OpResult has an equivalent OpOperand, both OpResult and
+  // OpOperand bufferize to the exact same buffer type.
   auto aliasingOperands = bufferizableOp.getAliasingOpOperand(opResult, state);
   if (!aliasingOperands.empty() &&
       bufferizableOp.bufferRelation(opResult, state) ==
           BufferRelation::Equivalent) {
-    // If the OpResult has an equivalent OpOperand, both OpResult and
-    // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliasingOperands.front()->get();
     return getBufferType(equivalentOperand, options, fixedTypes);
   }
 
+  // Case 2: If the OpResult bufferizes to a new allocation (and never aliases
+  // any OpOperand), the result buffer type has a static identity layout.
+  if (aliasingOperands.empty() &&
+      bufferizableOp.bufferizesToAllocation(opResult)) {
+    // Compute the memory space of this allocation.
+    unsigned memorySpace;
+    if (auto maybeMemorySpace = getMemorySpaceAttr(opResult)) {
+      memorySpace = *maybeMemorySpace;
+    } else if (options.defaultMemorySpace.has_value()) {
+      memorySpace = *options.defaultMemorySpace;
+    } else {
+      return op->emitError("could not infer memory space");
+    }
+    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
+  }
+
+  // Case 3: Generate a buffer type according to the bufferization options.
   // If we do not know the memory space and there is no default memory space,
   // report a failure.
   if (!options.defaultMemorySpace.has_value())
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -433,10 +433,6 @@
     bool dealloc = shouldDeallocateOpResult(
         fromElementsOp.getResult().cast<OpResult>(), options);
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != static_cast<unsigned>(0))
-      return op->emitError("memory space not implemented yet");
-
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
    auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
@@ -448,8 +444,11 @@
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> maybeMemrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(maybeMemrefType))
+      return failure();
+    auto memrefType = maybeMemrefType->cast<MemRefType>();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
         op->getLoc(), memrefType, *tensorAlloc);
 
@@ -502,11 +501,6 @@
     bool dealloc = shouldDeallocateOpResult(
         generateOp.getResult().cast<OpResult>(), options);
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != static_cast<unsigned>(0))
-      return op->emitError("memory space not implemented yet");
-
-    auto tensorType = generateOp.getType().cast<RankedTensorType>();
     // Allocate memory.
     Location loc = op->getLoc();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
@@ -516,8 +510,11 @@
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+    FailureOr<BaseMemRefType> maybeMemrefType =
+        bufferization::getBufferType(*tensorAlloc, options);
+    if (failed(maybeMemrefType))
+      return failure();
+    auto memrefType = maybeMemrefType->cast<MemRefType>();
     Value buffer = rewriter.create<bufferization::ToMemrefOp>(
         op->getLoc(), memrefType, *tensorAlloc);
 
@@ -555,9 +552,7 @@
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         elementYield, elementYield->getOperands()[0], buffer,
         parallelBody->getArguments());
-
     replaceOpWithBufferizedValues(rewriter, op, buffer);
-
     return success();
   }
 };
@@ -833,12 +828,31 @@
     return {};
   }
 
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    // Infer the memory space from the source tensor.
+    auto padOp = cast<tensor::PadOp>(op);
+    auto maybeSrcBufferType =
+        bufferization::getBufferType(padOp.getSource(), options, fixedTypes);
+    if (failed(maybeSrcBufferType))
+      return failure();
+    MemRefLayoutAttrInterface layout;
+    return MemRefType::get(padOp.getResultType().getShape(),
+                           padOp.getResultType().getElementType(), layout,
+                           maybeSrcBufferType->getMemorySpace());
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto padOp = cast<tensor::PadOp>(op);
     Location loc = padOp.getLoc();
     RankedTensorType resultType = padOp.getResultType();
     RankedTensorType srcType = padOp.getSourceType();
+    auto resultBufferType =
+        bufferization::getBufferType(padOp.getResult(), options);
+    if (failed(resultBufferType))
+      return failure();
 
     auto toValue = [&](OpFoldResult ofr) {
       if (ofr.is<Value>())
@@ -869,12 +883,13 @@
     // Create tensor::GenerateOp.
     auto generateOp =
         rewriter.create<tensor::GenerateOp>(loc, resultType, dynamicSizes);
+    setMemorySpaceAttr(generateOp->getOpResult(0),
+                       resultBufferType->getMemorySpaceAsInt());
     // Move over "escape" attribute if present.
     if (padOp->hasAttr(BufferizationDialect::kEscapeAttrName))
       generateOp->setAttr(
           BufferizationDialect::kEscapeAttrName,
           padOp->getAttr(BufferizationDialect::kEscapeAttrName));
-    // TODO: Memory space
 
     rewriter.inlineRegionBefore(padOp.getRegion(), generateOp.getBody(),
                                 generateOp.getBody().begin());
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -253,3 +253,65 @@
   %1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
   return %1 : tensor<10xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @generate_memory_space
+func.func @generate_memory_space(%sz: index, %idx: index) -> index {
+  // CHECK: memref.alloc{{.*}} : memref<?xindex, 3>
+  // CHECK: scf.parallel
+  // CHECK: memref.store {{.*}} : memref<?xindex, 3>
+  // CHECK: memref.load {{.*}} : memref<?xindex, 3>
+  // CHECK: memref.dealloc {{.*}} : memref<?xindex, 3>
+  %0 = tensor.generate %sz {
+  ^bb0(%i : index):
+    tensor.yield %sz : index
+  } { bufferization.memory_space = [3] } : tensor<?xindex>
+  %r = tensor.extract %0[%idx] : tensor<?xindex>
+  return %r : index
+}
+
+// -----
+
+// CHECK-LABEL: func @from_elements_memory_space
+func.func @from_elements_memory_space(%val: index, %idx: index) -> index {
+  // CHECK: memref.alloc() {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.store {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.store {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.store {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.store {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.store {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.store {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.load {{.*}} : memref<2x3xindex, 3>
+  // CHECK: memref.dealloc {{.*}} : memref<2x3xindex, 3>
+  %0 = tensor.from_elements %val, %val, %val, %val, %val, %val
+      { bufferization.memory_space = [3] } : tensor<2x3xindex>
+  %r = tensor.extract %0[%idx, %idx] : tensor<2x3xindex>
+  return %r : index
+}
+
+// -----
+
+// CHECK-LABEL: func @pad_memory_space(
+//  CHECK-SAME:     %[[t:.*]]: memref<?xf32, strided<[?], offset: ?>>
+func.func @pad_memory_space(%t: tensor<?xf32>, %h1: index, %f: f32, %pos: index) -> f32
+{
+  // CHECK: %[[alloc_tensor:.*]] = memref.alloc{{.*}} : memref<?xf32, 3>
+  // CHECK: memref.copy %[[t]], %[[alloc_tensor]]
+  %0 = bufferization.alloc_tensor() copy(%t)
+      {bufferization.memory_space = [3]} : tensor<?xf32>
+  // CHECK: %[[padded_alloc:.*]] = memref.alloc() {{.*}} : memref<15xf32, 3>
+  // CHECK: scf.parallel
+  // CHECK:   memref.store {{.*}} : memref<15xf32, 3>
+  // CHECK: %[[subview:.*]] = memref.subview {{.*}} : memref<15xf32, 3> to memref<?xf32, strided<[1], offset: 2>, 3>
+  // CHECK: memref.copy %[[alloc_tensor]], %[[subview]]
+  %1 = tensor.pad %0 low[2] high[%h1] {
+  ^bb0(%arg0: index):
+    tensor.yield %f : f32
+  } : tensor<?xf32> to tensor<15xf32>
+  // CHECK: memref.load {{.*}} : memref<15xf32, 3>
+  %2 = tensor.extract %1[%pos] : tensor<15xf32>
+  // CHECK-DAG: memref.dealloc %[[alloc_tensor]]
+  // CHECK-DAG: memref.dealloc %[[padded_alloc]]
+  return %2 : f32
+}
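Note: a minimal usage sketch (not part of the patch) of the new Case 2 inference in getBufferType, assuming the bufferization.memory_space attribute convention exercised by the tests above. An op that bufferizes to a new allocation and aliases no OpOperand now gets a buffer type with a static identity layout in the inferred memory space:

  // Memory space 3 is requested via the attribute. Without the attribute,
  // the result falls back to options.defaultMemorySpace, or the op emits
  // "could not infer memory space" if no default is set.
  %t = bufferization.alloc_tensor() {bufferization.memory_space = [3]} : tensor<10xf32>
  // After one-shot-bufferize, %t is backed by an identity-layout allocation:
  //   %m = memref.alloc() : memref<10xf32, 3>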