diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -28,8 +28,8 @@ /// Adds patterns to emulate wide Arith and Function ops over integer /// types into supported ones. This is done by splitting original power-of-two /// i2N integer types into two iN halves. -void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter, - RewritePatternSet &patterns); +void populateArithWideIntEmulationPatterns( + WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns); /// Add patterns to expand Arith ceil/floor division ops. void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -52,9 +52,9 @@ def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> { let summary = "Emulate 2*N-bit integer operations using N-bit operations"; let description = [{ - Emulate integer operations that use too wide integer types with equivalent - operations on supported narrow integer types. This is done by splitting - original integer values into two halves. + Emulate arith integer operations that use too wide integer types with + equivalent operations on supported narrow integer types. This is done by + splitting original integer values into two halves. This pass is intended preserve semantics but not necessarily provide the most efficient implementation. 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -20,6 +20,10 @@ class AffineDialect; class ModuleOp; +namespace arith { +class WideIntEmulationConverter; +} // namespace arith + namespace func { class FuncDialect; } // namespace func @@ -60,6 +64,17 @@ void populateSimplifyExtractStridedMetadataOpPatterns( RewritePatternSet &patterns); +/// Appends patterns for emulating wide integer memref operations with ops over +/// narrower integer types. +void populateMemRefWideIntEmulationPatterns( + arith::WideIntEmulationConverter &typeConverter, + RewritePatternSet &patterns); + +/// Appends type conversions for emulating wide integer memref operations with +/// ops over narrower integer types. +void populateMemRefWideIntEmulationConversions( + arith::WideIntEmulationConverter &typeConverter); + /// Transformation to do multi-buffering/array expansion to remove dependencies /// on the temporary allocation between consecutive loop iterations. /// It returns the new allocation if the original allocation was multi-buffered diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -28,6 +28,22 @@ ]; } +def MemRefEmulateWideInt : Pass<"memref-emulate-wide-int"> { + let summary = "Emulate 2*N-bit integer operations using N-bit operations"; + let description = [{ + Emulate memref integer operations that use too wide integer types with + equivalent operations on supported narrow integer types. This is done by + splitting original integer values into two halves. + + Currently, only power-of-two integer bitwidths are supported. 
+ }]; + let options = [ + Option<"widestIntSupported", "widest-int-supported", "unsigned", + /*default=*/"32", "Widest integer type supported by the target">, + ]; + let dependentDialects = ["vector::VectorDialect"]; +} + def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> { let summary = "Normalize memrefs"; let description = [{ diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -745,7 +745,7 @@ opLegalCallback); RewritePatternSet patterns(ctx); - arith::populateWideIntEmulationPatterns(typeConverter, patterns); + arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); @@ -817,7 +817,7 @@ }); } -void arith::populateWideIntEmulationPatterns( +void arith::populateArithWideIntEmulationPatterns( WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) { // Populate `func.*` conversion patterns. 
populateFunctionOpInterfaceTypeConversionPattern(patterns, diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms ComposeSubView.cpp ExpandOps.cpp + EmulateWideInt.cpp FoldMemRefAliasOps.cpp MultiBuffer.cpp NormalizeMemRefs.cpp @@ -17,6 +18,7 @@ MLIRAffineDialect MLIRAffineUtils MLIRArithDialect + MLIRArithTransforms MLIRFuncDialect MLIRInferTypeOpInterface MLIRLoopLikeInterface diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -0,0 +1,162 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" +#include + +namespace mlir::memref { +#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" +} // namespace mlir::memref + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// ConvertMemRefAlloc +//===----------------------------------------------------------------------===// + +struct 
ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert memref type: {0}", op.getType())); + + rewriter.replaceOpWithNewOp<memref::AllocOp>( + op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), + adaptor.getAlignmentAttr()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertMemRefLoad +//===----------------------------------------------------------------------===// + +struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newResTy = getTypeConverter()->convertType(op.getType()); + if (!newResTy) + return rewriter.notifyMatchFailure( + op->getLoc(), llvm::formatv("failed to convert memref type: {0}", + op.getMemRefType())); + + rewriter.replaceOpWithNewOp<memref::LoadOp>( + op, newResTy, adaptor.getMemref(), adaptor.getIndices()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertMemRefStore +//===----------------------------------------------------------------------===// + +struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getMemRefType()); + if (!newTy) + return rewriter.notifyMatchFailure( + op->getLoc(), llvm::formatv("failed to convert memref type: {0}", + op.getMemRefType())); + + 
rewriter.replaceOpWithNewOp<memref::StoreOp>( + op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct EmulateWideIntPass final + : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> { + using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase; + + void runOnOperation() override { + if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { + signalPassFailure(); + return; + } + + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + + arith::WideIntEmulationConverter typeConverter(widestIntSupported); + memref::populateMemRefWideIntEmulationConversions(typeConverter); + ConversionTarget target(*ctx); + target.addDynamicallyLegalDialect< + arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>( + [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); + + RewritePatternSet patterns(ctx); + // Add common patterns to support constants, functions, etc. + arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); + + memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Public Interface Definition +//===----------------------------------------------------------------------===// + +void memref::populateMemRefWideIntEmulationPatterns( + arith::WideIntEmulationConverter &typeConverter, + RewritePatternSet &patterns) { + // Populate `memref.*` conversion patterns. 
+ patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>( + typeConverter, patterns.getContext()); +} + +void memref::populateMemRefWideIntEmulationConversions( + arith::WideIntEmulationConverter &typeConverter) { + typeConverter.addConversion( + [&typeConverter](MemRefType ty) -> Optional<Type> { + auto intTy = ty.getElementType().dyn_cast<IntegerType>(); + if (!intTy) + return ty; + + if (intTy.getIntOrFloatBitWidth() <= + typeConverter.getMaxTargetIntBitWidth()) + return ty; + + Type newElemTy = typeConverter.convertType(intTy); + if (!newElemTy) + return None; + + return ty.cloneWith(None, newElemTy); + }); +} diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s + +// Expect no conversions, i32 is supported. +// CHECK-LABEL: func @memref_i32 +// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1> +// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1> +// CHECK-NEXT: return +func.func @memref_i32() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i32 + %m = memref.alloc() : memref<4xi32, 1> + %v = memref.load %m[%c0] : memref<4xi32, 1> + memref.store %c1, %m[%c0] : memref<4xi32, 1> + return +} + +// Expect no conversions, f32 is not an integer type. 
+// CHECK-LABEL: func @memref_f32 +// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1> +// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1> +// CHECK-NEXT: return +func.func @memref_f32() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1.0 : f32 + %m = memref.alloc() : memref<4xf32, 1> + %v = memref.load %m[%c0] : memref<4xf32, 1> + memref.store %c1, %m[%c0] : memref<4xf32, 1> + return +} + +// CHECK-LABEL: func @alloc_load_store_i64 +// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32> +// CHECK-NEXT: [[M:%.+]] = memref.alloc() : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: memref.store [[C1]], [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1> +// CHECK-NEXT: return +func.func @alloc_load_store_i64() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i64 + %m = memref.alloc() : memref<4xi64, 1> + %v = memref.load %m[%c0] : memref<4xi64, 1> + memref.store %c1, %m[%c0] : memref<4xi64, 1> + return +} diff --git a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp --- a/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp +++ b/mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp @@ -74,7 +74,7 @@ }); RewritePatternSet patterns(ctx); - arith::populateWideIntEmulationPatterns(typeConverter, patterns); + arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); }