diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -39,7 +39,7 @@ RewritePatternSet &patterns); /// Creates an instance of tensor constant bufferization pass. -std::unique_ptr createTensorConstantBufferizePass(); +std::unique_ptr createTensorConstantBufferizePass(unsigned alignment = 0); /// Creates an instance of the StdExpand pass that legalizes Std /// dialect ops to be convertible to LLVM. For example, diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -62,6 +62,10 @@ }]; let constructor = "mlir::createTensorConstantBufferizePass()"; let dependentDialects = ["memref::MemRefDialect"]; + let options = [ + Option<"alignment", "alignment", "unsigned", /*default=*/"0", + "Create global memrefs with a specified alignment">, + ]; } #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h --- a/mlir/include/mlir/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Transforms/BufferUtils.h @@ -125,11 +125,13 @@ // names. Duplicates are avoided. class GlobalCreator { public: - explicit GlobalCreator(ModuleOp module) : moduleOp(module) {} + GlobalCreator(ModuleOp module, unsigned alignment = 0) + : moduleOp(module), alignment(alignment) {} memref::GlobalOp getGlobalFor(ConstantOp constantOp); private: ModuleOp moduleOp; + unsigned alignment; // This could use memref::GlobalOp key but we avoid introducing a new // dependence to the memref dialect for this. DenseMap globals; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -43,13 +43,18 @@ interleave(type.getShape(), os, "x"); os << "x" << type.getElementType(); + // Add an optional alignment to the global memref. + IntegerAttr memrefAlignment = + alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) + : IntegerAttr(); + auto global = globalBuilder.create( constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), /*type=*/typeConverter.convertType(type).cast(), /*initial_value=*/constantOp.getValue().cast(), /*constant=*/true, - /*alignment=*/IntegerAttr()); + /*alignment=*/memrefAlignment); symbolTable.insert(global); // The symbol table inserts at the end of the module, but globals are a bit // nicer if they are at the beginning. @@ -90,11 +95,17 @@ } namespace { -struct TensorConstantBufferizePass +class TensorConstantBufferizePass : public TensorConstantBufferizeBase { +public: + explicit TensorConstantBufferizePass(unsigned alignment) { + if (alignment) + this->alignment = alignment; + } + void runOnOperation() override { auto module = getOperation(); - GlobalCreator globals(module); + GlobalCreator globals(module, alignment); auto *context = &getContext(); BufferizeTypeConverter typeConverter; @@ -111,6 +122,7 @@ }; } // namespace -std::unique_ptr mlir::createTensorConstantBufferizePass() { - return std::make_unique(); +std::unique_ptr +mlir::createTensorConstantBufferizePass(unsigned alignment) { + return std::make_unique(alignment); } diff --git a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir --- a/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir +++ b/mlir/test/Dialect/Standard/tensor-constant-bufferize.mlir @@ -1,9 +1,17 @@ // RUN: mlir-opt %s -tensor-constant-bufferize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -tensor-constant-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s // CHECK-LABEL: module { + // We check the debug name too since we put some effort into making that readable. // The name isn't load-bearing though. + // CHECK: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00> +// CHECK-NOT: alignment + +// ALIGNED: memref.global "private" constant @__constant_3x4xf32 : memref<3x4xf32> = dense<7.000000e+00> +// ALIGNED-SAME: {alignment = 64 : i64} + // CHECK: @basic func @basic() -> tensor<3x4xf32> { // CHECK: %[[MEMREF:.*]] = memref.get_global @__constant_3x4xf32 : memref<3x4xf32>