This is an archive of the discontinued LLVM Phabricator instance.

[mlir][arith] Fold redundant broadcasts
Abandoned · Public

Authored by awarzynski on Apr 12 2023, 12:22 PM.

Details

Summary

This revision folds addi(vector.broadcast(x), vector.broadcast(y)) into
vector.broadcast(addi(x, y)).
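To make the rewrite concrete, here is a minimal, self-contained C++ sketch over a hypothetical toy IR. The names Type, Op, value, broadcast, addi, foldAddIOfBroadcasts and print are illustrative assumptions. This is not MLIR's OpRewritePattern API and not the code in this revision. The sketch highlights the one precondition the fold relies on: x and y must have the same type, because addi requires matching operand types.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is NOT MLIR's C++ API.
struct Type {
  std::vector<int> shape;  // empty shape means a scalar
  std::string elem;        // element type, e.g. "i32"
  bool operator==(const Type &o) const { return shape == o.shape && elem == o.elem; }
};

enum class Kind { Value, Broadcast, AddI };

struct Op {
  Kind kind;
  Type type;
  std::string name;  // only used by Kind::Value (an opaque SSA value)
  std::vector<std::shared_ptr<Op>> operands;
};
using OpRef = std::shared_ptr<Op>;

OpRef value(std::string name, Type t) {
  return std::make_shared<Op>(Op{Kind::Value, t, std::move(name), {}});
}
OpRef broadcast(OpRef src, Type resultTy) {
  return std::make_shared<Op>(Op{Kind::Broadcast, resultTy, "", {src}});
}
OpRef addi(OpRef lhs, OpRef rhs) {
  assert(lhs->type == rhs->type && "addi operands must have the same type");
  return std::make_shared<Op>(Op{Kind::AddI, lhs->type, "", {lhs, rhs}});
}

// addi(broadcast(x), broadcast(y)) -> broadcast(addi(x, y)).
// Returns the original op unchanged when the pattern does not match.
OpRef foldAddIOfBroadcasts(const OpRef &op) {
  if (op->kind != Kind::AddI) return op;
  const OpRef &lhs = op->operands[0];
  const OpRef &rhs = op->operands[1];
  if (lhs->kind != Kind::Broadcast || rhs->kind != Kind::Broadcast) return op;
  const OpRef &x = lhs->operands[0];
  const OpRef &y = rhs->operands[0];
  // The narrower addi is only well-formed if both sources share a type.
  if (!(x->type == y->type)) return op;
  return broadcast(addi(x, y), op->type);
}

// Renders an expression tree as text, e.g. "broadcast(addi(x, y))".
std::string print(const OpRef &op) {
  switch (op->kind) {
    case Kind::Value: return op->name;
    case Kind::Broadcast: return "broadcast(" + print(op->operands[0]) + ")";
    case Kind::AddI:
      return "addi(" + print(op->operands[0]) + ", " + print(op->operands[1]) + ")";
  }
  return "";
}
```

With two i32 scalars broadcast to a 4-element vector, the fold yields broadcast(addi(x, y)); if one source is instead a 1-element vector, the types differ and the op is returned untouched.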


Event Timeline

awarzynski created this revision. Apr 12 2023, 12:22 PM
Herald added a project: Restricted Project. Apr 12 2023, 12:22 PM
awarzynski requested review of this revision. Apr 12 2023, 12:22 PM

The idea for this canonicalization came up in https://github.com/openxla/iree/issues/12991. I wasn't entirely sure whether it belongs in the Arith or the Vector dialect; probably a bit of both.

I'm also not sure about the CMake dependencies (it's a bit tricky without a good architectural overview), so I'd appreciate suggestions on how to improve this. For context, here's the build error I get without the CMake changes:

Undefined symbols for architecture arm64:
  "mlir::detail::TypeIDResolver<mlir::vector::BroadcastOp, void>::id", referenced from:
      llvm::DefaultDoCastIfPossible<mlir::vector::BroadcastOp, mlir::Operation*, llvm::CastInfo<mlir::vector::BroadcastOp, mlir::Operation*, void>>::doCastIfPossible(mlir::Operation*) in libMLIRArithDialect.a(ArithOps.cpp.o)
  "mlir::vector::BroadcastOp::getODSResults(unsigned int)", referenced from:
      (anonymous namespace)::AddIVectorBroadcast::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const in libMLIRArithDialect.a(ArithOps.cpp.o)
  "mlir::vector::BroadcastOp::getODSOperands(unsigned int)", referenced from:
      (anonymous namespace)::AddIVectorBroadcast::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&) const in libMLIRArithDialect.a(ArithOps.cpp.o)
  "mlir::vector::BroadcastOp::build(mlir::OpBuilder&, mlir::OperationState&, mlir::TypeRange, mlir::ValueRange, llvm::ArrayRef<mlir::NamedAttribute>)", referenced from:
      mlir::vector::BroadcastOp mlir::OpBuilder::create<mlir::vector::BroadcastOp, llvm::SmallVector<mlir::Type, 4u>&, llvm::SmallVector<mlir::Value, 4u>&, llvm::SmallVector<mlir::NamedAttribute, 4u>&>(mlir::Location, llvm::SmallVector<mlir::Type, 4u>&, llvm::SmallVector<mlir::Value, 4u>&, llvm::SmallVector<mlir::NamedAttribute, 4u>&) in libMLIRArithDialect.a(ArithOps.cpp.o)

Thanks for taking a look!

kuhar added a comment. Apr 12 2023, 6:57 PM

I'm not sure this belongs in the general arith canonicalizations; the vector dialect depends on arith, so I'd expect this transform to be anchored at vector.broadcast.
Could we have a more general pattern that folds any elementwise op supporting both scalars and vectors when the vector operand(s) are either broadcasts or splat constants?
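One possible shape for that more general pattern, sketched in C++ over a hypothetical toy IR rather than MLIR's API: an elementwise op whose operands are all either broadcasts of scalars or splat constants is rewritten to compute on the scalars and broadcast the result. The names Kind, make, scalarSource, foldElementwiseOfSplats and print are assumptions for illustration, and the sketch only handles scalar broadcast sources.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only; this is NOT MLIR's C++ API.
struct Type {
  std::vector<int> shape;  // empty shape means a scalar
  std::string elem;        // element type, e.g. "i32"
  bool isScalar() const { return shape.empty(); }
  bool operator==(const Type &o) const { return shape == o.shape && elem == o.elem; }
};

enum class Kind { Value, Constant, Broadcast, Elementwise };

struct Op {
  Kind kind;
  Type type;
  std::string name;  // value name, or elementwise op name such as "addi"
  long splat;        // constant payload (a splat when the type is a vector)
  std::vector<std::shared_ptr<Op>> operands;
};
using OpRef = std::shared_ptr<Op>;

OpRef make(Kind k, Type t, std::string name, long splat, std::vector<OpRef> ops) {
  return std::make_shared<Op>(Op{k, std::move(t), std::move(name), splat, std::move(ops)});
}

// Returns the scalar equivalent of a broadcast-of-scalar or splat-constant
// operand, or nullptr if the operand is neither.
OpRef scalarSource(const OpRef &v) {
  if (v->kind == Kind::Broadcast && v->operands[0]->type.isScalar())
    return v->operands[0];
  if (v->kind == Kind::Constant)
    return make(Kind::Constant, Type{{}, v->type.elem}, "", v->splat, {});
  return nullptr;
}

// elementwise(splat-like operands...) -> broadcast(elementwise(scalars...))
OpRef foldElementwiseOfSplats(const OpRef &op) {
  if (op->kind != Kind::Elementwise || op->type.isScalar()) return op;
  std::vector<OpRef> scalars;
  bool sawBroadcast = false;
  for (const OpRef &operand : op->operands) {
    OpRef s = scalarSource(operand);
    if (!s) return op;  // some operand is a genuine vector: no fold
    sawBroadcast |= operand->kind == Kind::Broadcast;
    scalars.push_back(s);
  }
  if (!sawBroadcast) return op;  // all splat constants: leave to constant folding
  Type scalarTy{{}, op->type.elem};
  OpRef scalarOp = make(Kind::Elementwise, scalarTy, op->name, 0, scalars);
  return make(Kind::Broadcast, op->type, "", 0, {scalarOp});
}

// Renders an expression tree as text, e.g. "broadcast(muli(x, 7))".
std::string print(const OpRef &op) {
  switch (op->kind) {
    case Kind::Value: return op->name;
    case Kind::Constant: return std::to_string(op->splat);
    case Kind::Broadcast: return "broadcast(" + print(op->operands[0]) + ")";
    case Kind::Elementwise: {
      std::string s = op->name + "(";
      for (std::size_t i = 0; i < op->operands.size(); ++i)
        s += (i ? ", " : "") + print(op->operands[i]);
      return s + ")";
    }
  }
  return "";
}
```

Requiring at least one broadcast operand is a design choice of this sketch: when every operand is a splat constant, ordinary constant folding already covers the case.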

I'm also not sure about the CMake dependencies (it's a bit tricky without a good architectural overview), so I'd appreciate suggestions on how to improve this.

I typically check with Bazel in cases like this. It's pickier than CMake about implicit/circular dependencies, so it will tell you when some are missing.

awarzynski abandoned this revision. Jun 13 2023, 6:26 AM

@kuhar, thanks for taking a look, and apologies for the delay in getting back to this. I will abandon this in favor of a more general approach: https://reviews.llvm.org/D152812. Cheers for the feedback!